Merge branch 'ggml-org:master' into patch-2
This commit is contained in:
commit
605fb08c65
|
|
@ -104,3 +104,20 @@ OpenCL:
|
|||
- any-glob-to-any-file:
|
||||
- ggml/include/ggml-opencl.h
|
||||
- ggml/src/ggml-opencl/**
|
||||
- docs/backend/OPENCL.md
|
||||
Hexagon:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- ggml/include/ggml-hexagon.h
|
||||
- ggml/src/ggml-hexagon/**
|
||||
WebGPU:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- ggml/include/ggml-webgpu.h
|
||||
- ggml/src/ggml-webgpu/**
|
||||
OpenVINO:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- ggml/include/ggml-openvino.h
|
||||
- ggml/src/ggml-openvino/**
|
||||
- docs/backend/OPENVINO.md
|
||||
|
|
|
|||
|
|
@ -0,0 +1,87 @@
|
|||
name: AI review (issues)
|
||||
|
||||
on:
|
||||
issues:
|
||||
types: [opened]
|
||||
|
||||
jobs:
|
||||
find-related:
|
||||
if: github.event.action == 'opened'
|
||||
runs-on: [self-hosted, opencode]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
- name: Find related
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
OPENCODE_PERMISSION: |
|
||||
{
|
||||
"bash": {
|
||||
"*": "deny",
|
||||
"gh issue*": "allow",
|
||||
"gh search issues*": "allow"
|
||||
},
|
||||
"webfetch": "deny"
|
||||
}
|
||||
run: |
|
||||
rm AGENTS.md
|
||||
rm CLAUDE.md
|
||||
|
||||
timeout 5m opencode run -m llama.cpp-dgx/ai-review-issues-find-similar --thinking "A new issue has been created:
|
||||
|
||||
Issue number: ${{ github.event.issue.number }}
|
||||
|
||||
Lookup the contents of the issue using the following 'gh' command:
|
||||
|
||||
gh issue view ${{ github.event.issue.number }} --json title,body,url,number
|
||||
|
||||
Next, perform the following task and then post a SINGLE comment (if needed).
|
||||
|
||||
---
|
||||
|
||||
TASK : FIND RELATED ISSUES
|
||||
|
||||
Using the 'gh' CLI tool, search through existing issues on Github.
|
||||
Find related or similar issues to the newly created one and list them.
|
||||
Do not list the new issue itself (it is #${{ github.event.issue.number }}).
|
||||
|
||||
Consider:
|
||||
1. Similar titles or descriptions
|
||||
2. Same error messages or symptoms
|
||||
3. Related functionality or components
|
||||
4. Similar feature requests
|
||||
|
||||
---
|
||||
|
||||
POSTING YOUR COMMENT:
|
||||
|
||||
Based on your findings, post a SINGLE comment on issue #${{ github.event.issue.number }}. Build the comment as follows:
|
||||
|
||||
- If no related issues were found, do NOT comment at all.
|
||||
- If related issues were found, include a section listing them with links using the following format:
|
||||
|
||||
[comment]
|
||||
This issue might be similar or related to the following issue(s):
|
||||
|
||||
- #[related_issue_number]: [brief description of how they are related]
|
||||
- #[related_issue_number]: [brief description of how they are related]
|
||||
...
|
||||
|
||||
_This comment was auto-generated locally using **$GA_ENGINE** on **$GA_MACHINE**_
|
||||
[/comment]
|
||||
|
||||
Remember:
|
||||
- Do not include the comment tags in your actual comment.
|
||||
- Post at most ONE comment combining all findings.
|
||||
- If you didn't find issues that are related enough, post nothing.
|
||||
- You have access only to the 'gh' CLI tool - don't try to use other tools.
|
||||
- If the output from a tool call is too long, try to limit down the search.
|
||||
"
|
||||
|
|
@ -46,7 +46,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: macOS-latest-ios
|
||||
evict-old-files: 1d
|
||||
|
|
@ -124,7 +124,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: macOS-latest-tvos
|
||||
evict-old-files: 1d
|
||||
|
|
@ -186,7 +186,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: macOS-latest-swift
|
||||
evict-old-files: 1d
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ubuntu-latest-sanitizer-${{ matrix.sanitizer }}
|
||||
evict-old-files: 1d
|
||||
|
|
|
|||
|
|
@ -97,19 +97,21 @@ jobs:
|
|||
vulkaninfo --summary
|
||||
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-cpu-amx:
|
||||
runs-on: [self-hosted, Linux, CPU, AMX]
|
||||
# TODO: provision AMX-compatible machine
|
||||
#ggml-ci-cpu-amx:
|
||||
# runs-on: [self-hosted, Linux, CPU, AMX]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v6
|
||||
# steps:
|
||||
# - name: Clone
|
||||
# id: checkout
|
||||
# uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
# - name: Test
|
||||
# id: ggml-ci
|
||||
# run: |
|
||||
# bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
# TODO: provision AMD GPU machine
|
||||
# ggml-ci-amd-vulkan:
|
||||
# runs-on: [self-hosted, Linux, AMD]
|
||||
|
||||
|
|
@ -124,6 +126,7 @@ jobs:
|
|||
# vulkaninfo --summary
|
||||
# GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
# TODO: provision AMD GPU machine
|
||||
# ggml-ci-amd-rocm:
|
||||
# runs-on: [self-hosted, Linux, AMD]
|
||||
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ubuntu-24-vulkan-llvmpipe
|
||||
evict-old-files: 1d
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: macOS-latest-arm64
|
||||
evict-old-files: 1d
|
||||
|
|
@ -105,7 +105,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: macOS-latest-x64
|
||||
evict-old-files: 1d
|
||||
|
|
@ -141,7 +141,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: macOS-latest-arm64-webgpu
|
||||
evict-old-files: 1d
|
||||
|
|
@ -195,7 +195,8 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
if: ${{ matrix.build != 's390x' && matrix.build != 'ppc64le' }}
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ubuntu-cpu-${{ matrix.build }}
|
||||
evict-old-files: 1d
|
||||
|
|
@ -324,7 +325,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ubuntu-24-webgpu
|
||||
evict-old-files: 1d
|
||||
|
|
@ -436,7 +437,7 @@ jobs:
|
|||
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libssl-dev rocwmma-dev
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ubuntu-22-hip
|
||||
evict-old-files: 1d
|
||||
|
|
@ -467,7 +468,7 @@ jobs:
|
|||
apt-get install -y build-essential git cmake libssl-dev
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ubuntu-22-musa
|
||||
evict-old-files: 1d
|
||||
|
|
@ -513,7 +514,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ubuntu-22-sycl
|
||||
evict-old-files: 1d
|
||||
|
|
@ -562,7 +563,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ubuntu-22-sycl-fp16
|
||||
evict-old-files: 1d
|
||||
|
|
@ -605,7 +606,7 @@ jobs:
|
|||
|
||||
- name: ccache
|
||||
if: runner.environment == 'github-hosted'
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ubuntu-24-openvino-${{ matrix.variant }}-no-preset-v1
|
||||
evict-old-files: 1d
|
||||
|
|
@ -692,7 +693,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: windows-latest-${{ matrix.build }}
|
||||
variant: ccache
|
||||
|
|
@ -798,7 +799,7 @@ jobs:
|
|||
apt install -y cmake build-essential ninja-build libgomp1 git libssl-dev
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ubuntu-latest-cuda
|
||||
evict-old-files: 1d
|
||||
|
|
@ -830,7 +831,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: Install ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: windows-cuda-${{ matrix.cuda }}
|
||||
variant: ccache
|
||||
|
|
@ -883,7 +884,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: windows-latest-sycl
|
||||
variant: ccache
|
||||
|
|
@ -944,7 +945,7 @@ jobs:
|
|||
& $clangPath.FullName --version
|
||||
|
||||
- name: Install ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ${{ github.job }}
|
||||
evict-old-files: 1d
|
||||
|
|
@ -1068,7 +1069,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ggml-ci-x64-cpu-low-perf
|
||||
evict-old-files: 1d
|
||||
|
|
@ -1094,7 +1095,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ggml-ci-arm64-cpu-low-perf
|
||||
evict-old-files: 1d
|
||||
|
|
@ -1120,7 +1121,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ggml-ci-x64-cpu-high-perf
|
||||
evict-old-files: 1d
|
||||
|
|
@ -1146,7 +1147,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ggml-ci-arm64-cpu-high-perf
|
||||
evict-old-files: 1d
|
||||
|
|
@ -1172,7 +1173,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ggml-ci-arm64-cpu-high-perf-sve
|
||||
evict-old-files: 1d
|
||||
|
|
@ -1198,7 +1199,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ggml-ci-arm64-cpu-kleidiai
|
||||
evict-old-files: 1d
|
||||
|
|
@ -1250,7 +1251,7 @@ jobs:
|
|||
sudo apt-get install -y cmake
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ggml-ci-arm64-cpu-kleidiai-graviton4
|
||||
evict-old-files: 1d
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: copilot-setup-steps
|
||||
evict-old-files: 1d
|
||||
|
|
@ -52,6 +52,6 @@ jobs:
|
|||
- name: Install Python dependencies
|
||||
run: |
|
||||
python3 -m venv .venv
|
||||
.venv/bin/activate
|
||||
source .venv/bin/activate
|
||||
pip install -r requirements/requirements-all.txt -r tools/server/tests/requirements.txt
|
||||
pip install flake8 pyright pre-commit
|
||||
|
|
|
|||
|
|
@ -0,0 +1,80 @@
|
|||
name: HIP quality check
|
||||
|
||||
on:
|
||||
workflow_dispatch: # allows manual triggering
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths: [
|
||||
'.github/workflows/hip-quality-check.yml',
|
||||
'**/*.cu',
|
||||
'**/*.cuh'
|
||||
]
|
||||
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
paths: [
|
||||
'.github/workflows/hip-quality-check.yml',
|
||||
'**/*.cu',
|
||||
'**/*.cuh'
|
||||
]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
GGML_NLOOP: 3
|
||||
GGML_N_THREADS: 1
|
||||
LLAMA_LOG_COLORS: 1
|
||||
LLAMA_LOG_PREFIX: 1
|
||||
LLAMA_LOG_TIMESTAMPS: 1
|
||||
|
||||
jobs:
|
||||
ubuntu-22-hip-quality-check:
|
||||
runs-on: ubuntu-22.04
|
||||
container: rocm/dev-ubuntu-22.04:7.2
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libssl-dev python3
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ubuntu-22-hip-quality-check
|
||||
evict-old-files: 1d
|
||||
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
|
||||
|
||||
- name: Build with Werror
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -S . \
|
||||
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
|
||||
-DGPU_TARGETS=gfx908 \
|
||||
-DGGML_HIP=ON \
|
||||
-DGGML_HIP_EXPORT_METRICS=Off \
|
||||
-DCMAKE_HIP_FLAGS="-Werror -Wno-tautological-compare" \
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
cd build
|
||||
make -j $(nproc)
|
||||
|
||||
- name: Check for major VGPR spills
|
||||
id: vgpr_check
|
||||
run: |
|
||||
cmake -B build -S . \
|
||||
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
|
||||
-DGPU_TARGETS=gfx908 \
|
||||
-DGGML_HIP=ON \
|
||||
-DGGML_HIP_EXPORT_METRICS=On \
|
||||
-DCMAKE_HIP_FLAGS="" \
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
cd build
|
||||
make -j $(nproc) 2>&1 | tee metrics.log | grep -v 'Rpass-analysis=kernel-resource-usage\|remark:\|^$'
|
||||
python3 ../scripts/hip/gcn-cdna-vgpr-check.py metrics.log
|
||||
|
|
@ -4,15 +4,17 @@ on:
|
|||
push:
|
||||
paths:
|
||||
- '.github/workflows/python-type-check.yml'
|
||||
- 'pyrightconfig.json'
|
||||
- 'ty.toml'
|
||||
- '**.py'
|
||||
- '**/requirements*.txt'
|
||||
# - 'pyrightconfig.json'
|
||||
pull_request:
|
||||
paths:
|
||||
- '.github/workflows/python-type-check.yml'
|
||||
- 'pyrightconfig.json'
|
||||
- 'ty.toml'
|
||||
- '**.py'
|
||||
- '**/requirements*.txt'
|
||||
# - 'pyrightconfig.json'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
|
||||
|
|
@ -20,8 +22,8 @@ concurrency:
|
|||
|
||||
jobs:
|
||||
python-type-check:
|
||||
runs-on: ubuntu-latest
|
||||
name: pyright type-check
|
||||
runs-on: ubuntu-slim
|
||||
name: python type-check
|
||||
steps:
|
||||
- name: Check out source repository
|
||||
uses: actions/checkout@v6
|
||||
|
|
@ -29,10 +31,13 @@ jobs:
|
|||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.11"
|
||||
pip-install: -r requirements/requirements-all.txt
|
||||
- name: Type-check with Pyright
|
||||
uses: jakebailey/pyright-action@v2
|
||||
with:
|
||||
version: 1.1.382
|
||||
level: warning
|
||||
warnings: true
|
||||
pip-install: -r requirements/requirements-all.txt ty==0.0.24
|
||||
# - name: Type-check with Pyright
|
||||
# uses: jakebailey/pyright-action@v2
|
||||
# with:
|
||||
# version: 1.1.382
|
||||
# level: warning
|
||||
# warnings: true
|
||||
- name: Type-check with ty
|
||||
run: |
|
||||
ty check --output-format=github
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ jobs:
|
|||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: macOS-latest-arm64
|
||||
evict-old-files: 1d
|
||||
|
|
@ -94,7 +94,7 @@ jobs:
|
|||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: macOS-latest-x64
|
||||
evict-old-files: 1d
|
||||
|
|
@ -153,7 +153,8 @@ jobs:
|
|||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
if: ${{ matrix.build != 's390x' }}
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ubuntu-cpu-${{ matrix.build }}
|
||||
evict-old-files: 1d
|
||||
|
|
@ -204,7 +205,7 @@ jobs:
|
|||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ubuntu-22-vulkan
|
||||
evict-old-files: 1d
|
||||
|
|
@ -269,7 +270,7 @@ jobs:
|
|||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ubuntu-24-openvino-release-no-preset-v1
|
||||
evict-old-files: 1d
|
||||
|
|
@ -342,7 +343,7 @@ jobs:
|
|||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: windows-latest-cpu-${{ matrix.arch }}
|
||||
variant: ccache
|
||||
|
|
@ -403,7 +404,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: windows-latest-${{ matrix.backend }}-${{ matrix.arch }}
|
||||
variant: ccache
|
||||
|
|
@ -473,7 +474,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: Install ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: windows-cuda-${{ matrix.cuda }}
|
||||
variant: ccache
|
||||
|
|
@ -549,7 +550,7 @@ jobs:
|
|||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: windows-latest-sycl
|
||||
variant: ccache
|
||||
|
|
@ -629,7 +630,7 @@ jobs:
|
|||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}
|
||||
evict-old-files: 1d
|
||||
|
|
@ -739,7 +740,7 @@ jobs:
|
|||
key: rocm-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ runner.os }}
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: windows-latest-hip-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ matrix.name }}-x64
|
||||
evict-old-files: 1d
|
||||
|
|
|
|||
|
|
@ -178,6 +178,8 @@ Maintainers reserve the right to decline review or close pull requests for any r
|
|||
- New code should follow the guidelines (coding, naming, etc.) outlined in this document. Exceptions are allowed in isolated, backend-specific parts of the code that do not interface directly with the `ggml` interfaces.
|
||||
_(NOTE: for legacy reasons, existing code is not required to follow this guideline)_
|
||||
|
||||
- For changes in server, please make sure to refer to the [server development documentation](./tools/server/README-dev.md)
|
||||
|
||||
# Documentation
|
||||
|
||||
- Documentation is a community effort
|
||||
|
|
|
|||
|
|
@ -24,9 +24,9 @@ Fri Mar 6 11:39:45 2026
|
|||
+-----------------------------------------+------------------------+----------------------+
|
||||
```
|
||||
|
||||
## ggml-org/nemotron-3-super-120b-GGUF
|
||||
## ggml-org/Nemotron-3-Super-120B-GGUF
|
||||
|
||||
Model: https://huggingface.co/ggml-org/nemotron-3-super-120b-GGUF
|
||||
Model: https://huggingface.co/ggml-org/Nemotron-3-Super-120B-GGUF
|
||||
|
||||
- `llama-batched-bench`
|
||||
|
||||
|
|
@ -53,7 +53,6 @@ main: n_kv_max = 303104, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_
|
|||
| 8192 | 32 | 16 | 131584 | 171.066 | 766.21 | 10.774 | 47.52 | 181.840 | 723.62 |
|
||||
| 8192 | 32 | 32 | 263168 | 342.140 | 766.19 | 18.969 | 53.98 | 361.109 | 728.78 |
|
||||
|
||||
|
||||
- `llama-bench`
|
||||
|
||||
| model | size | params | backend | n_ubatch | fa | test | t/s |
|
||||
|
|
@ -70,3 +69,49 @@ main: n_kv_max = 303104, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_
|
|||
| nemotron 120B.A12B Q4_K | 65.10 GiB | 120.67 B | CUDA | 2048 | 1 | tg32 @ d32768 | 19.45 ± 0.18 |
|
||||
|
||||
build: 04a65daab (8268)
|
||||
|
||||
## ggml-org/Nemotron-3-Nano-4B-GGUF
|
||||
|
||||
Model: https://huggingface.co/ggml-org/Nemotron-3-Nano-4B-GGUF
|
||||
|
||||
- `llama-batched-bench`
|
||||
|
||||
main: n_kv_max = 303104, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, is_tg_separate = 0, n_gpu_layers = 99, n_threads = 20, n_threads_batch = 20
|
||||
|
||||
| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|
||||
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|
||||
| 512 | 32 | 1 | 544 | 0.152 | 3371.61 | 0.597 | 53.64 | 0.748 | 726.90 |
|
||||
| 512 | 32 | 2 | 1088 | 0.319 | 3208.68 | 0.857 | 74.66 | 1.176 | 924.89 |
|
||||
| 512 | 32 | 4 | 2176 | 0.720 | 2843.56 | 1.323 | 96.78 | 2.043 | 1065.18 |
|
||||
| 512 | 32 | 8 | 4352 | 1.428 | 2867.96 | 2.311 | 110.76 | 3.739 | 1163.82 |
|
||||
| 512 | 32 | 16 | 8704 | 2.857 | 2866.94 | 4.203 | 121.82 | 7.060 | 1232.82 |
|
||||
| 512 | 32 | 32 | 17408 | 5.709 | 2869.76 | 7.964 | 128.58 | 13.673 | 1273.14 |
|
||||
| 4096 | 32 | 1 | 4128 | 1.458 | 2809.76 | 0.605 | 52.92 | 2.062 | 2001.52 |
|
||||
| 4096 | 32 | 2 | 8256 | 2.905 | 2819.95 | 0.875 | 73.12 | 3.780 | 2183.95 |
|
||||
| 4096 | 32 | 4 | 16512 | 5.790 | 2829.74 | 1.361 | 94.07 | 7.151 | 2309.17 |
|
||||
| 4096 | 32 | 8 | 33024 | 11.598 | 2825.32 | 2.378 | 107.65 | 13.976 | 2362.89 |
|
||||
| 4096 | 32 | 16 | 66048 | 23.208 | 2823.88 | 4.348 | 117.76 | 27.556 | 2396.89 |
|
||||
| 4096 | 32 | 32 | 132096 | 46.515 | 2817.85 | 8.279 | 123.69 | 54.794 | 2410.79 |
|
||||
| 8192 | 32 | 1 | 8224 | 2.950 | 2776.95 | 0.617 | 51.89 | 3.567 | 2305.75 |
|
||||
| 8192 | 32 | 2 | 16448 | 5.921 | 2767.32 | 0.896 | 71.45 | 6.816 | 2413.05 |
|
||||
| 8192 | 32 | 4 | 32896 | 11.842 | 2767.21 | 1.401 | 91.34 | 13.243 | 2484.03 |
|
||||
| 8192 | 32 | 8 | 65792 | 23.726 | 2762.17 | 2.461 | 104.03 | 26.187 | 2512.38 |
|
||||
| 8192 | 32 | 16 | 131584 | 47.777 | 2743.43 | 4.577 | 111.86 | 52.354 | 2513.36 |
|
||||
| 8192 | 32 | 32 | 263168 | 96.691 | 2711.16 | 8.772 | 116.73 | 105.463 | 2495.36 |
|
||||
|
||||
- `llama-bench`
|
||||
|
||||
| model | size | params | backend | n_ubatch | fa | test | t/s |
|
||||
| ----------------------- | ---------: | ---------: | ---------- | -------: | -: | --------------: | -------------------: |
|
||||
| nemotron 4B Q8_0 | 3.94 GiB | 3.97 B | CUDA | 2048 | 1 | pp2048 | 2761.90 ± 19.31 |
|
||||
| nemotron 4B Q8_0 | 3.94 GiB | 3.97 B | CUDA | 2048 | 1 | tg32 | 52.85 ± 0.12 |
|
||||
| nemotron 4B Q8_0 | 3.94 GiB | 3.97 B | CUDA | 2048 | 1 | pp2048 @ d4096 | 2687.07 ± 21.84 |
|
||||
| nemotron 4B Q8_0 | 3.94 GiB | 3.97 B | CUDA | 2048 | 1 | tg32 @ d4096 | 52.32 ± 0.23 |
|
||||
| nemotron 4B Q8_0 | 3.94 GiB | 3.97 B | CUDA | 2048 | 1 | pp2048 @ d8192 | 2564.52 ± 57.69 |
|
||||
| nemotron 4B Q8_0 | 3.94 GiB | 3.97 B | CUDA | 2048 | 1 | tg32 @ d8192 | 51.27 ± 0.34 |
|
||||
| nemotron 4B Q8_0 | 3.94 GiB | 3.97 B | CUDA | 2048 | 1 | pp2048 @ d16384 | 2334.02 ± 37.83 |
|
||||
| nemotron 4B Q8_0 | 3.94 GiB | 3.97 B | CUDA | 2048 | 1 | tg32 @ d16384 | 49.71 ± 0.14 |
|
||||
| nemotron 4B Q8_0 | 3.94 GiB | 3.97 B | CUDA | 2048 | 1 | pp2048 @ d32768 | 2041.46 ± 40.45 |
|
||||
| nemotron 4B Q8_0 | 3.94 GiB | 3.97 B | CUDA | 2048 | 1 | tg32 @ d32768 | 46.71 ± 0.13 |
|
||||
|
||||
build: 1bbec6a75 (8382)
|
||||
|
|
|
|||
12
ci/run.sh
12
ci/run.sh
|
|
@ -25,7 +25,13 @@
|
|||
# # with KLEIDIAI support
|
||||
# GG_BUILD_KLEIDIAI=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
|
||||
#
|
||||
# # with OPENVINO support
|
||||
# # with BLAS support
|
||||
# GG_BUILD_BLAS=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
|
||||
#
|
||||
# with BLAS support (custom vendor)
|
||||
# GG_BUILD_BLAS=1 GG_BUILD_BLAS_VENDOR=Intel10_64lp bash ./ci/run.sh ./tmp/results ./tmp/mnt
|
||||
#
|
||||
# with OPENVINO support
|
||||
# GG_BUILD_OPENVINO=1 GG_BUILD_LOW_PERF=1 GGML_OPENVINO_DEVICE=CPU bash ./ci/run.sh ./tmp/results ./tmp/mnt
|
||||
#
|
||||
|
||||
|
|
@ -169,6 +175,10 @@ if [ -n "${GG_BUILD_KLEIDIAI}" ]; then
|
|||
-DBUILD_SHARED_LIBS=OFF"
|
||||
fi
|
||||
|
||||
if [ ! -z ${GG_BUILD_BLAS} ]; then
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=${GG_BUILD_BLAS_VENDOR:-OpenBLAS}"
|
||||
fi
|
||||
|
||||
if [ ! -z ${GG_BUILD_OPENVINO} ]; then
|
||||
if [ -z ${OpenVINO_DIR} ]; then
|
||||
echo "OpenVINO_DIR not found, please install OpenVINO via archives and enable it by:"
|
||||
|
|
|
|||
|
|
@ -1830,23 +1830,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--grammar"}, "GRAMMAR",
|
||||
string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()),
|
||||
"BNF-like grammar to constrain generations (see samples in grammars/ dir)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.grammar = value;
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, value};
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--grammar-file"}, "FNAME",
|
||||
"file to read grammar from",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.grammar = read_file(value);
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, read_file(value)};
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"-j", "--json-schema"}, "SCHEMA",
|
||||
"JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.grammar = json_schema_to_grammar(json::parse(value));
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, json_schema_to_grammar(json::parse(value))};
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
|
|
@ -1863,7 +1863,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
std::istreambuf_iterator<char>(),
|
||||
std::back_inserter(schema)
|
||||
);
|
||||
params.sampling.grammar = json_schema_to_grammar(json::parse(schema));
|
||||
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, json_schema_to_grammar(json::parse(schema))};
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
|
|
@ -3115,6 +3115,17 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.chat_template = read_file(value);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
|
||||
add_opt(common_arg(
|
||||
{"--skip-chat-parsing"},
|
||||
{"--no-skip-chat-parsing"},
|
||||
string_format(
|
||||
"force a pure content parser, even if a Jinja template is specified; model will output everything "
|
||||
"in the content section, including any reasoning and/or tool calls (default: disabled)"
|
||||
),
|
||||
[](common_params & params, bool value) {
|
||||
params.force_pure_content_parser = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SKIP_CHAT_PARSING"));
|
||||
add_opt(common_arg(
|
||||
{"--prefill-assistant"},
|
||||
{"--no-prefill-assistant"},
|
||||
|
|
@ -3483,7 +3494,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
throw std::invalid_argument("unknown speculative decoding type without draft model");
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SPEC_TYPE"));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-size-n"}, "N",
|
||||
string_format("ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)", params.speculative.ngram_size_n),
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
#include "chat-auto-parser-helpers.h"
|
||||
#include "chat-auto-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "chat.h"
|
||||
|
|
@ -23,13 +24,13 @@ static void foreach_function(const json & tools, const std::function<void(const
|
|||
|
||||
namespace autoparser {
|
||||
|
||||
parser_build_context::parser_build_context(common_chat_peg_builder & p, const templates_params & inputs) :
|
||||
parser_build_context::parser_build_context(common_chat_peg_builder & p, const generation_params & inputs) :
|
||||
p(p),
|
||||
inputs(inputs),
|
||||
reasoning_parser(p.eps()) {}
|
||||
|
||||
common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl,
|
||||
const struct templates_params & inputs) {
|
||||
const struct generation_params & inputs) {
|
||||
// Run differential analysis to extract template structure
|
||||
struct autoparser autoparser;
|
||||
autoparser.analyze_template(tmpl);
|
||||
|
|
@ -37,17 +38,16 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
|
|||
}
|
||||
|
||||
common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl,
|
||||
const struct templates_params & inputs,
|
||||
const struct generation_params & inputs,
|
||||
const autoparser & autoparser) {
|
||||
// Build the parser using the analysis results
|
||||
auto parser = autoparser.build_parser(inputs);
|
||||
|
||||
// Create the result structure
|
||||
common_chat_params data;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = autoparser.preserved_tokens;
|
||||
data.parser = parser.save();
|
||||
|
||||
auto parser = autoparser.build_parser(inputs);
|
||||
data.parser = parser.save();
|
||||
|
||||
// Build grammar if tools are present
|
||||
bool has_tools =
|
||||
|
|
@ -82,44 +82,38 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
|
|||
return data;
|
||||
}
|
||||
|
||||
common_peg_arena autoparser::build_parser(const templates_params & inputs) const {
|
||||
common_peg_arena autoparser::build_parser(const generation_params & inputs) const {
|
||||
if (!analysis_complete) {
|
||||
throw std::invalid_argument("Cannot call build_parser on autoparser without performing analysis first, call analyze_template(...)");
|
||||
}
|
||||
return build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
// If the template uses Python dict format (single-quoted strings in JSON structures),
|
||||
// pre-register a json-string rule that accepts both quote styles. This must happen
|
||||
// before any call to p.json() so that all JSON parsing inherits the flexible rule.
|
||||
if (tools.format.uses_python_dicts) {
|
||||
p.rule("json-string", p.quoted_string());
|
||||
}
|
||||
|
||||
parser_build_context ctx(p, inputs);
|
||||
bool extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
bool enable_thinking = inputs.enable_thinking;
|
||||
|
||||
ctx.extracting_reasoning = extract_reasoning && enable_thinking && reasoning.mode != reasoning_mode::NONE;
|
||||
ctx.extracting_reasoning = extract_reasoning && reasoning.mode != reasoning_mode::NONE;
|
||||
ctx.content = &content;
|
||||
|
||||
// Build reasoning parser
|
||||
ctx.reasoning_parser = reasoning.build_parser(ctx);
|
||||
|
||||
auto parser = p.eps();
|
||||
|
||||
bool has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
bool has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
|
||||
|
||||
if (has_response_format) {
|
||||
auto response_format = p.rule("response-format", p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
|
||||
return ctx.reasoning_parser + p.space() + p.choice({
|
||||
parser = ctx.reasoning_parser + p.space() + p.choice({
|
||||
p.literal("```json") + p.space() + response_format + p.space() + p.literal("```"),
|
||||
response_format
|
||||
}) + p.end();
|
||||
} else if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) {
|
||||
parser = tools.build_parser(ctx);
|
||||
} else {
|
||||
parser = content.build_parser(ctx);
|
||||
}
|
||||
|
||||
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) {
|
||||
return tools.build_parser(ctx);
|
||||
}
|
||||
|
||||
return content.build_parser(ctx);
|
||||
parser = wrap_for_generation_prompt(p, parser, inputs, reasoning.start);
|
||||
return parser;
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -130,24 +124,15 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co
|
|||
return p.eps();
|
||||
}
|
||||
|
||||
bool thinking_forced_open = (mode == reasoning_mode::FORCED_OPEN);
|
||||
bool thinking_forced_closed = (mode == reasoning_mode::FORCED_CLOSED);
|
||||
|
||||
if (thinking_forced_open || thinking_forced_closed) {
|
||||
// Thinking is forced open OR forced closed with enable_thinking=true
|
||||
// In both cases, expect only the closing tag (opening was in template)
|
||||
// However, since we might have incorrectly detected the open/close pattern,
|
||||
// we admit an optional starting marker
|
||||
return p.optional(p.literal(start)) + p.reasoning(p.until(end)) + end;
|
||||
}
|
||||
if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) {
|
||||
// Standard tag-based reasoning OR tools-only mode (reasoning appears with tools)
|
||||
// Both use the same tag-based pattern if markers are available
|
||||
if (!start.empty() && !end.empty()) {
|
||||
return p.optional(start + p.reasoning(p.until(end)) + end);
|
||||
if (!end.empty()) {
|
||||
if (!start.empty()) {
|
||||
// Standard tag-based: optional(<think>reasoning</think>)
|
||||
return p.optional(start + p.reasoning(p.until(end)) + end + p.space());
|
||||
}
|
||||
// Delimiter-style (empty start)
|
||||
return p.optional(p.reasoning(p.until(end)) + end + p.space());
|
||||
}
|
||||
} else if (mode == reasoning_mode::DELIMITER) {
|
||||
return p.optional(p.reasoning(p.until(end)) + end);
|
||||
}
|
||||
|
||||
return p.eps();
|
||||
|
|
@ -335,7 +320,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
|||
"tool-" + name + "-arg-" + param_name + "-schema",
|
||||
param_schema, true)) :
|
||||
p.tool_arg_json_value(p.schema(
|
||||
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, format.uses_python_dicts)) +
|
||||
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
|
||||
p.space()) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)));
|
||||
|
||||
|
|
@ -384,7 +369,9 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
|||
func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
|
||||
call_id_section) + p.space() + args_seq;
|
||||
matched_atomic = true;
|
||||
} else if (!arguments.name_prefix.empty() && properties.size() > 0) {
|
||||
} else if (!arguments.name_prefix.empty() && !required_parsers.empty()) {
|
||||
// Only peek for an arg tag when there are required args that must follow.
|
||||
// When all args are optional, the model may emit no arg tags at all (#20650).
|
||||
func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
|
||||
call_id_section + p.space() + p.peek(p.literal(arguments.name_prefix))) + args_seq;
|
||||
matched_atomic = true;
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
#include "chat-auto-parser-helpers.h"
|
||||
|
||||
#include "chat-auto-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "chat.h"
|
||||
#include "log.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "peg-parser.h"
|
||||
|
||||
#include <cctype>
|
||||
#include <numeric>
|
||||
|
|
@ -186,6 +188,21 @@ diff_split calculate_diff_split(const std::string & left, const std::string & ri
|
|||
result.suffix = "";
|
||||
// pick prefix = all as representation
|
||||
}
|
||||
|
||||
// When left has no unique content (result.left is empty), left is entirely
|
||||
// shared with right. The simultaneous prefix/suffix segment matching can
|
||||
// incorrectly consume trailing segments of left as suffix when those same
|
||||
// segments also appear at the end of right (e.g. "\n" at the end of both
|
||||
// the shared content and the generation prompt). This rotates the diff.
|
||||
// Fix: if left is a prefix of right, enforce that directly.
|
||||
if (result.left.empty() && !result.right.empty() &&
|
||||
left.size() <= right.size() &&
|
||||
right.substr(0, left.size()) == left) {
|
||||
result.prefix = left;
|
||||
result.suffix = "";
|
||||
result.right = right.substr(left.size());
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
@ -291,10 +308,26 @@ std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segm
|
|||
return result;
|
||||
}
|
||||
|
||||
common_peg_parser wrap_for_generation_prompt(common_chat_peg_builder & p,
|
||||
const common_peg_parser & prs,
|
||||
const autoparser::generation_params & inputs,
|
||||
const std::string & reasoning_start) {
|
||||
auto parser = prs;
|
||||
if (!inputs.generation_prompt.empty()) {
|
||||
size_t end_pos = inputs.generation_prompt.size();
|
||||
if (!reasoning_start.empty() && inputs.generation_prompt.find(reasoning_start) != std::string::npos) {
|
||||
end_pos = inputs.generation_prompt.find(reasoning_start);
|
||||
}
|
||||
std::string cut_genprompt = inputs.generation_prompt.substr(0, end_pos);
|
||||
parser = p.literal(cut_genprompt) + parser;
|
||||
}
|
||||
return parser;
|
||||
}
|
||||
|
||||
namespace autoparser {
|
||||
|
||||
std::string apply_template(const common_chat_template & tmpl, const template_params & params) {
|
||||
templates_params tmpl_params;
|
||||
generation_params tmpl_params;
|
||||
tmpl_params.messages = params.messages;
|
||||
tmpl_params.tools = params.tools;
|
||||
tmpl_params.add_generation_prompt = params.add_generation_prompt;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "chat-auto-parser.h"
|
||||
#include "peg-parser.h"
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
|
@ -57,6 +58,11 @@ std::vector<segment> segmentize_markers(const std::string & text);
|
|||
// (MARKER, "</function>"), (MARKER, "</tool_call>") ]
|
||||
std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segments);
|
||||
|
||||
// Wrap parser with generation prompt parser
|
||||
common_peg_parser wrap_for_generation_prompt(common_chat_peg_builder & p,
|
||||
const common_peg_parser & prs,
|
||||
const autoparser::generation_params & inputs,
|
||||
const std::string & reasoning_start = {});
|
||||
namespace autoparser {
|
||||
|
||||
// Apply a template with the given parameters, returning the rendered string (empty on failure)
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ namespace autoparser {
|
|||
// High-level params for parser generation
|
||||
// ============================================================================
|
||||
|
||||
struct templates_params {
|
||||
struct generation_params {
|
||||
json messages;
|
||||
json tools;
|
||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
|
|
@ -62,6 +62,7 @@ struct templates_params {
|
|||
bool add_generation_prompt = false;
|
||||
bool enable_thinking = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
std::string generation_prompt;
|
||||
json extra_context;
|
||||
bool add_bos = false;
|
||||
bool add_eos = false;
|
||||
|
|
@ -77,11 +78,7 @@ struct templates_params {
|
|||
// Reasoning handling mode (derived from R1-R3 comparisons)
|
||||
enum class reasoning_mode {
|
||||
NONE, // No reasoning markers detected
|
||||
TAG_BASED, // Standard tag-based: <think>...</think>
|
||||
DELIMITER, // Delimiter-based: [BEGIN FINAL RESPONSE] (reasoning ends at delimiter)
|
||||
FORCED_OPEN, // Template ends with open reasoning tag (empty start, non-empty end)
|
||||
FORCED_CLOSED, // Template ends with open reasoning tag on enabled thinking but
|
||||
// with both opened and closed tag for disabled thinking
|
||||
TAG_BASED, // Tag-based: <think>...</think> (start can be empty for delimiter-style)
|
||||
TOOLS_ONLY // Only reason on tool calls, not on normal content
|
||||
};
|
||||
|
||||
|
|
@ -91,12 +88,6 @@ inline std::ostream & operator<<(std::ostream & os, const reasoning_mode & mode)
|
|||
return os << "NONE";
|
||||
case reasoning_mode::TAG_BASED:
|
||||
return os << "TAG_BASED";
|
||||
case reasoning_mode::DELIMITER:
|
||||
return os << "DELIMITER";
|
||||
case reasoning_mode::FORCED_OPEN:
|
||||
return os << "FORCED_OPEN";
|
||||
case reasoning_mode::FORCED_CLOSED:
|
||||
return os << "FORCED_CLOSED";
|
||||
case reasoning_mode::TOOLS_ONLY:
|
||||
return os << "TOOLS_ONLY";
|
||||
default:
|
||||
|
|
@ -184,7 +175,6 @@ struct tool_format_analysis {
|
|||
|
||||
bool fun_name_is_key = false; // In JSON format function name is JSON key, i.e. { "<funname>": { ... arguments ... } }
|
||||
bool tools_array_wrapped = false; // Tool calls wrapped in JSON array [...]
|
||||
bool uses_python_dicts = false; // Tool call args use Python dict format (single-quoted strings)
|
||||
|
||||
std::string function_field = "function";
|
||||
std::string name_field = "name";
|
||||
|
|
@ -225,12 +215,12 @@ struct analyze_content;
|
|||
|
||||
struct parser_build_context {
|
||||
common_chat_peg_builder & p;
|
||||
const templates_params & inputs;
|
||||
const generation_params & inputs;
|
||||
common_peg_parser reasoning_parser;
|
||||
bool extracting_reasoning = false;
|
||||
const analyze_content * content = nullptr;
|
||||
|
||||
parser_build_context(common_chat_peg_builder & p, const templates_params & inputs);
|
||||
parser_build_context(common_chat_peg_builder & p, const generation_params & inputs);
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
|
|
@ -260,6 +250,7 @@ struct analyze_reasoning : analyze_base {
|
|||
|
||||
analyze_reasoning() = default;
|
||||
analyze_reasoning(const common_chat_template & tmpl, bool supports_tools);
|
||||
analyze_reasoning(std::string start_, std::string end_) : start(std::move(start_)), end(std::move(end_)) {}
|
||||
|
||||
common_peg_parser build_parser(parser_build_context & ctx) const override;
|
||||
|
||||
|
|
@ -381,7 +372,7 @@ struct autoparser {
|
|||
void analyze_template(const common_chat_template & tmpl);
|
||||
|
||||
// Build the PEG parser for this template
|
||||
common_peg_arena build_parser(const templates_params & inputs) const;
|
||||
common_peg_arena build_parser(const generation_params & inputs) const;
|
||||
|
||||
private:
|
||||
// Collect tokens from entire analysis to preserve
|
||||
|
|
@ -395,10 +386,10 @@ struct autoparser {
|
|||
class peg_generator {
|
||||
public:
|
||||
static common_chat_params generate_parser(const common_chat_template & tmpl,
|
||||
const struct templates_params & inputs);
|
||||
const struct generation_params & inputs);
|
||||
|
||||
static common_chat_params generate_parser(const common_chat_template & tmpl,
|
||||
const struct templates_params & inputs,
|
||||
const struct generation_params & inputs,
|
||||
const autoparser & autoparser);
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
#include "chat-auto-parser-helpers.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "peg-parser.h"
|
||||
|
|
@ -31,8 +32,9 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
|
|||
[](const common_chat_template & tmpl, autoparser & analysis) -> void {
|
||||
if (tmpl.src.find("content.split('</think>')") != std::string::npos &&
|
||||
tmpl.src.find("reasoning_content") == std::string::npos &&
|
||||
tmpl.src.find("<SPECIAL_12>") == std::string::npos &&
|
||||
analysis.reasoning.mode == reasoning_mode::NONE) {
|
||||
analysis.reasoning.mode = reasoning_mode::FORCED_OPEN;
|
||||
analysis.reasoning.mode = reasoning_mode::TAG_BASED;
|
||||
analysis.reasoning.start = "<think>";
|
||||
analysis.reasoning.end = "</think>";
|
||||
analysis.preserved_tokens.push_back("<think>");
|
||||
|
|
@ -185,7 +187,6 @@ void autoparser::analyze_template(const common_chat_template & tmpl) {
|
|||
LOG_DBG("func_name_prefix: '%s'\n", tools.function.name_prefix.c_str());
|
||||
LOG_DBG("func_name_suffix: '%s'\n", tools.function.name_suffix.c_str());
|
||||
LOG_DBG("func_close: '%s'\n", tools.function.close.c_str());
|
||||
LOG_DBG("python_dict_format: %s\n", tools.format.uses_python_dicts ? "true" : "false");
|
||||
LOG_DBG("arg_name_prefix: '%s'\n", tools.arguments.name_prefix.c_str());
|
||||
LOG_DBG("arg_name_suffix: '%s'\n", tools.arguments.name_suffix.c_str());
|
||||
LOG_DBG("arg_value_prefix: '%s'\n", tools.arguments.value_prefix.c_str());
|
||||
|
|
@ -295,16 +296,12 @@ void analyze_reasoning::compare_reasoning_presence() {
|
|||
}
|
||||
if (result.result.success()) {
|
||||
if (!result.tags["pre"].empty() && !result.tags["post"].empty()) {
|
||||
if (parser_wrapped.parse_anywhere_and_extract(diff.right).result.success()) { // both tags in the diff = no forced close
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
} else {
|
||||
mode = reasoning_mode::FORCED_CLOSED;
|
||||
}
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
start = trim_whitespace(result.tags["pre"]);
|
||||
end = result.tags["post"];
|
||||
end = trim_trailing_whitespace(result.tags["post"]);
|
||||
} else if (!result.tags["post"].empty()) {
|
||||
mode = reasoning_mode::DELIMITER;
|
||||
end = result.tags["post"];
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
end = trim_trailing_whitespace(result.tags["post"]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -331,53 +328,30 @@ void analyze_reasoning::compare_thinking_enabled() {
|
|||
const auto & diff = comparison->diff;
|
||||
|
||||
std::string left_trimmed = trim_whitespace(diff.left);
|
||||
std::string right_trimmed = trim_whitespace(diff.right);
|
||||
|
||||
if (left_trimmed.empty() && !diff.right.empty()) {
|
||||
std::string right_trimmed = trim_whitespace(diff.right);
|
||||
|
||||
if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) {
|
||||
if (start.empty()) {
|
||||
start = right_trimmed;
|
||||
mode = reasoning_mode::FORCED_OPEN;
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
}
|
||||
}
|
||||
} else if (right_trimmed.empty() && !diff.left.empty()) {
|
||||
if (!left_trimmed.empty() && string_ends_with(comparison->output_A, left_trimmed)) {
|
||||
if (end.empty()) {
|
||||
auto seg = prune_whitespace_segments(segmentize_markers(comparison->output_A));
|
||||
if (seg.size() >= 2 && seg[seg.size() - 1].value == left_trimmed && seg[seg.size() - 2].type == segment_type::MARKER) {
|
||||
start = seg[seg.size() - 2].value;
|
||||
}
|
||||
end = left_trimmed;
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (start.empty() && !end.empty()) {
|
||||
mode = reasoning_mode::DELIMITER;
|
||||
}
|
||||
|
||||
// Check for FORCED_CLOSED: when enable_thinking=false produces both start and end markers,
|
||||
// but enable_thinking=true produces only the start marker
|
||||
if (!comparison->output_A.empty() && !comparison->output_B.empty()) {
|
||||
auto parser_start = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
|
||||
return p.literal(start) + p.space() + p.literal(end) + p.rest();
|
||||
});
|
||||
auto parser_start_end = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
|
||||
return p.tag("pre", p.literal(start)) + p.space() + p.negate(p.literal(end)) + p.rest();
|
||||
});
|
||||
if (!start.empty() && parser_start_end.parse_anywhere_and_extract(comparison->output_A).result.success() &&
|
||||
parser_start.parse_anywhere_and_extract(comparison->output_B).result.success()) {
|
||||
mode = reasoning_mode::FORCED_CLOSED;
|
||||
} else if (!end.empty()) { // we extract the starting marker now since we didn't get it earlier
|
||||
auto result = parser_start_end.parse_anywhere_and_extract(comparison->output_A);
|
||||
if (result.result.success()) {
|
||||
start = result.tags["pre"];
|
||||
mode = reasoning_mode::FORCED_CLOSED;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (start.empty() && end.empty()) { // we might still have the case of "just open" and "just close"
|
||||
if (!diff.left.empty() && !diff.right.empty()) {
|
||||
auto seg_A = segmentize_markers(trim_trailing_whitespace(diff.left));
|
||||
auto seg_B = segmentize_markers(trim_trailing_whitespace(diff.right));
|
||||
if (seg_A.size() == 1 && seg_B.size() == 1) {
|
||||
mode = reasoning_mode::FORCED_CLOSED;
|
||||
start = seg_B[0].value;
|
||||
end = seg_A[0].value;
|
||||
}
|
||||
}
|
||||
if (mode == reasoning_mode::NONE && start.empty() && !end.empty()) {
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -426,16 +400,16 @@ void analyze_reasoning::compare_reasoning_scope() {
|
|||
auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B);
|
||||
if (result.result.success()) {
|
||||
start = result.tags["pre"];
|
||||
end = result.tags["post"];
|
||||
end = trim_trailing_whitespace(result.tags["post"]);
|
||||
} else {
|
||||
auto parser_delimiter = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
|
||||
return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space())));
|
||||
});
|
||||
result = parser_delimiter.parse_anywhere_and_extract(comparison->output_B);
|
||||
if (result.result.success()) {
|
||||
end = result.tags["post"];
|
||||
end = trim_trailing_whitespace(result.tags["post"]);
|
||||
} else {
|
||||
LOG_DBG(ANSI_ORANGE "%s: Unable to extracft reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__);
|
||||
LOG_DBG(ANSI_ORANGE "%s: Unable to extract reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__);
|
||||
mode = reasoning_mode::NONE;
|
||||
}
|
||||
}
|
||||
|
|
@ -600,33 +574,23 @@ void analyze_tools::analyze_tool_call_format(const std::string & haystack,
|
|||
return;
|
||||
}
|
||||
|
||||
enum class json_quote_style { NONE, DOUBLE_QUOTES, SINGLE_QUOTES };
|
||||
|
||||
auto in_json_haystack = [&haystack](const std::string & needle) -> json_quote_style {
|
||||
auto in_json_haystack = [&haystack](const std::string & needle) -> bool {
|
||||
auto parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
|
||||
return p.choice({ p.literal("{"), p.literal(":") }) << p.choice({
|
||||
p.tag("sq", p.literal("'") + p.literal(needle) + p.literal("'")),
|
||||
p.tag("dq", p.literal("\"") + p.literal(needle) + p.literal("\"")) });
|
||||
});
|
||||
auto result = parser.parse_anywhere_and_extract(haystack);
|
||||
if (!result.result.success()) {
|
||||
return json_quote_style::NONE;
|
||||
}
|
||||
return result.tags.count("sq") && !result.tags["sq"].empty()
|
||||
? json_quote_style::SINGLE_QUOTES
|
||||
: json_quote_style::DOUBLE_QUOTES;
|
||||
return result.result.success();
|
||||
};
|
||||
|
||||
auto fun_quote = in_json_haystack(fun_name_needle);
|
||||
auto arg_quote = in_json_haystack(arg_name_needle);
|
||||
|
||||
if (fun_quote != json_quote_style::NONE) {
|
||||
if (fun_quote) {
|
||||
// no need to check further, we're in JSON land
|
||||
format.mode = tool_format::JSON_NATIVE;
|
||||
format.uses_python_dicts = (fun_quote == json_quote_style::SINGLE_QUOTES);
|
||||
} else if (arg_quote != json_quote_style::NONE) {
|
||||
} else if (arg_quote) {
|
||||
format.mode = tool_format::TAG_WITH_JSON;
|
||||
format.uses_python_dicts = (arg_quote == json_quote_style::SINGLE_QUOTES);
|
||||
} else {
|
||||
format.mode = tool_format::TAG_WITH_TAGGED;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -229,6 +229,20 @@ void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena,
|
|||
result.tool_calls.push_back(pending_tool_call.value());
|
||||
pending_tool_call.reset();
|
||||
}
|
||||
|
||||
// Discard whitespace-only reasoning content (e.g. from <think></think> prefill)
|
||||
if (!result.reasoning_content.empty()) {
|
||||
bool all_whitespace = true;
|
||||
for (char c : result.reasoning_content) {
|
||||
if (c != ' ' && c != '\n' && c != '\r' && c != '\t') {
|
||||
all_whitespace = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (all_whitespace) {
|
||||
result.reasoning_content.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
|
||||
|
|
|
|||
405
common/chat.cpp
405
common/chat.cpp
|
|
@ -1,5 +1,6 @@
|
|||
#include "chat.h"
|
||||
|
||||
#include "chat-auto-parser-helpers.h"
|
||||
#include "chat-auto-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "common.h"
|
||||
|
|
@ -22,6 +23,7 @@
|
|||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
|
@ -760,7 +762,7 @@ static void foreach_parameter(const json &
|
|||
|
||||
std::string common_chat_template_direct_apply(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs,
|
||||
const autoparser::generation_params & inputs,
|
||||
const std::optional<json> & messages_override,
|
||||
const std::optional<json> & tools_override,
|
||||
const std::optional<json> & additional_context) {
|
||||
|
|
@ -811,7 +813,7 @@ std::string common_chat_template_direct_apply(
|
|||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs) {
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
// Build up messages to follow the format: https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512/blob/main/chat_template.jinja
|
||||
|
|
@ -876,8 +878,8 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
|||
// Response format parser
|
||||
if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
|
||||
// Ministral wants to emit json surrounded by code fences
|
||||
return reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema))
|
||||
<< "```";
|
||||
return wrap_for_generation_prompt(p, reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```",
|
||||
inputs, "[THINK]");
|
||||
}
|
||||
|
||||
// Tool call parser
|
||||
|
|
@ -897,12 +899,13 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
|||
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
|
||||
auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls));
|
||||
|
||||
return reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls;
|
||||
return wrap_for_generation_prompt(p, reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls,
|
||||
inputs, "[THINK]");
|
||||
}
|
||||
|
||||
// Content only parser
|
||||
include_grammar = false;
|
||||
return reasoning << p.content(p.rest());
|
||||
return wrap_for_generation_prompt(p, reasoning << p.content(p.rest()), inputs, "[THINK]");
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
|
@ -928,22 +931,19 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
|||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs) {
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
// Copy reasoning to the "thinking" field as expected by the gpt-oss template
|
||||
auto adjusted_messages = json::array();
|
||||
for (const auto & msg : inputs.messages) {
|
||||
auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
|
||||
auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();
|
||||
|
||||
if (has_reasoning_content && has_tool_calls) {
|
||||
auto adjusted_message = msg;
|
||||
adjusted_message["thinking"] = msg.at("reasoning_content");
|
||||
adjusted_messages.push_back(adjusted_message);
|
||||
} else {
|
||||
adjusted_messages.push_back(msg);
|
||||
for (auto msg : inputs.messages) {
|
||||
if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
|
||||
msg["thinking"] = msg.at("reasoning_content");
|
||||
if (msg.contains("tool_calls") && msg.at("tool_calls").is_array() && !msg.at("tool_calls").empty()) {
|
||||
msg.erase("content");
|
||||
}
|
||||
}
|
||||
adjusted_messages.push_back(msg);
|
||||
}
|
||||
|
||||
auto prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
|
|
@ -969,45 +969,32 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
|||
"<|channel|>", "<|constrain|>", "<|message|>", "<|start|>", "<|end|>",
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && has_tools;
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object();
|
||||
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
const std::string END = "<|end|>";
|
||||
const std::string START = "<|start|>";
|
||||
const std::string MESSAGE = "<|message|>";
|
||||
const std::string CHANNEL = "<|channel|>";
|
||||
const std::string CONSTRAIN = "<|constrain|>";
|
||||
const std::string START_ASSISTANT = START + "assistant";
|
||||
const std::string CHANNEL_ANALYSIS = CHANNEL + "analysis";
|
||||
const std::string CHANNEL_COMMENTARY = CHANNEL + "commentary";
|
||||
const std::string CHANNEL_FINAL = CHANNEL + "final";
|
||||
auto start = p.rule("start", p.literal("<|start|>assistant"));
|
||||
auto end = p.rule("end", p.literal("<|end|>"));
|
||||
auto content = p.rule("message-content", p.until("<|end|>"));
|
||||
auto channel = p.literal("<|channel|>") + (p.literal("commentary") | p.literal("analysis"));
|
||||
auto constrain_type = p.chars("[A-Za-z0-9_-]", 1, -1);
|
||||
|
||||
auto the_end = END | p.end();
|
||||
auto analysis = p.rule("analysis", p.literal("<|channel|>analysis<|message|>") + p.reasoning(content) + end);
|
||||
auto preamble = p.rule("preamble", p.literal("<|channel|>commentary<|message|>") + p.content(content) + end);
|
||||
auto final_msg = p.rule("final", p.literal("<|channel|>final<|message|>") + p.content(content));
|
||||
auto any = p.rule("any", preamble | analysis);
|
||||
|
||||
const std::string analysis_header = CHANNEL_ANALYSIS + MESSAGE;
|
||||
auto segment_content = p.until(END);
|
||||
auto analysis_segment = extract_reasoning ?
|
||||
p.literal(analysis_header) + p.reasoning(segment_content) + p.until(END) + the_end :
|
||||
p.content(analysis_header + p.until(END) + the_end);
|
||||
if (has_response_format) {
|
||||
auto constraint = p.optional(p.space() + p.literal("<|constrain|>") + constrain_type);
|
||||
auto response_format = p.rule("response-format",
|
||||
p.literal("<|channel|>final") + constraint + p.literal("<|message|>") +
|
||||
p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
|
||||
|
||||
auto channel_header_content = p.until_one_of({ " to=functions.", MESSAGE });
|
||||
auto content_header = p.choice({ p.literal(CHANNEL_COMMENTARY), p.literal(CHANNEL_FINAL) });
|
||||
auto content_segment = p.rule("content-segment", content_header + channel_header_content + MESSAGE +
|
||||
p.content(segment_content) + the_end);
|
||||
|
||||
if (!inputs.json_schema.is_null()) {
|
||||
auto final_header = p.literal(CHANNEL_FINAL);
|
||||
auto constraint = p.optional(p.space() + p.literal(CONSTRAIN) + channel_header_content);
|
||||
return p.optional(analysis_segment) + final_header + constraint + MESSAGE +
|
||||
p.content(p.schema(p.json(), "response-format", inputs.json_schema));
|
||||
return wrap_for_generation_prompt(p, response_format | (analysis + p.zero_or_more(start + analysis) + start + response_format),
|
||||
inputs, "<|channel|>");
|
||||
}
|
||||
|
||||
auto segment = p.optional(START_ASSISTANT + p.space()) + p.choice({ content_segment, analysis_segment });
|
||||
auto contents = p.optional(segment + p.repeat(p.optional(p.space()) + segment, 0, -1)) + p.end();
|
||||
|
||||
// Tool call parser
|
||||
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
auto tool_choice = p.choice();
|
||||
|
||||
|
|
@ -1016,42 +1003,39 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
|||
std::string name = function.at("name");
|
||||
const auto & params = function.at("parameters");
|
||||
|
||||
// Tool call can appear as:
|
||||
// 1. In role header: " to=functions.NAME<|channel|>..."
|
||||
// 2. In channel: "<|channel|>(analysis|commentary) to=functions.NAME..."
|
||||
auto func_name = p.literal(" to=functions.") + p.tool_name(p.literal(name));
|
||||
|
||||
auto channel = p.literal(CHANNEL_COMMENTARY) | p.literal(CHANNEL_ANALYSIS);
|
||||
auto constraint = p.space() + p.optional(p.literal(CONSTRAIN) + channel_header_content);
|
||||
auto func_name = p.literal(" to=functions.") + p.tool_name(p.literal(name));
|
||||
auto constraint = p.optional(p.space() + p.literal("<|constrain|>") + constrain_type);
|
||||
auto args = p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", params));
|
||||
|
||||
// Pattern 1: recipient in role header
|
||||
// " to=functions.NAME<|channel|>(analysis|commentary)[constraint]<|message|>ARGS"
|
||||
auto tool_in_role = p.tool(p.tool_open(func_name + channel) + constraint + MESSAGE + args);
|
||||
// recipient in role header
|
||||
// <|start|>assistant to=functions.NAME<|channel|>(commentary|analysis)[constraint]<|message|>ARGS
|
||||
auto tool_in_role = p.tool(p.tool_open(func_name + channel + constraint + p.literal("<|message|>")) + args);
|
||||
|
||||
// Pattern 2: recipient in channel header
|
||||
// "<|channel|>(analysis|commentary) to=functions.NAME[constraint]<|message|>ARGS"
|
||||
auto tool_in_channel = p.tool(channel + p.tool_open(func_name + constraint + MESSAGE) + args);
|
||||
// recipient in channel header
|
||||
// <|channel|>(commentary|analysis) to=functions.NAME[constraint]<|message|>ARGS
|
||||
auto tool_in_channel = p.tool(p.tool_open(channel + func_name + constraint + p.literal("<|message|>")) + args);
|
||||
|
||||
tool_choice |= tool_in_role | tool_in_channel;
|
||||
tool_choice |= p.rule("tool-" + name, tool_in_role | tool_in_channel);
|
||||
});
|
||||
|
||||
auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0;
|
||||
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
|
||||
auto tool_call = p.trigger_rule("tool-call", tool_choice);
|
||||
|
||||
auto role_start = p.optional(p.space() + p.literal(START_ASSISTANT));
|
||||
auto tool_call = p.rule("tool-call", p.repeat(role_start + tool_choice, min_calls, max_calls) + p.end());
|
||||
if (inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED) {
|
||||
return tool_call | ( any + p.zero_or_more(start + any) + start + tool_call);
|
||||
}
|
||||
|
||||
return p.choice({ p.trigger_rule("single-tool", tool_call), p.trigger_rule("tools", p.one_or_more(segment) + tool_call) });
|
||||
return wrap_for_generation_prompt(p, tool_call | final_msg | (any + p.zero_or_more(start + any) + start + (tool_call | final_msg)),
|
||||
inputs, "<|channel|>");
|
||||
}
|
||||
|
||||
return contents;
|
||||
return wrap_for_generation_prompt(p, final_msg | (any + p.zero_or_more(start + any) + start + final_msg),
|
||||
inputs, "<|channel|>");
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
||||
if (include_grammar) {
|
||||
data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
|
|
@ -1062,10 +1046,9 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
|||
});
|
||||
|
||||
data.grammar_triggers = {
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "^(?:<\\|start\\|>assistant\\s*)?(\\s+to=functions)" },
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "(?:<\\|end\\|>)(?:<\\|start\\|>assistant\\s*)?(\\s+to=functions)" },
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
||||
"(?:<\\|start\\|>assistant\\s*)?(<\\|channel\\|>(?:commentary|analysis)\\s+to=functions)" }
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "^\\s+to$" },
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "<\\|start\\|>assistant(\\s+to)" },
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "<\\|start\\|>assistant(<\\|channel\\|>(?:commentary|analysis)\\s+to)" }
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -1074,7 +1057,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
|||
|
||||
// Functionary v3.2 - uses recipient-based format: >>>recipient\n{content}
|
||||
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs) {
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
|
|
@ -1095,13 +1078,13 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
|||
// Build content parser for >>>all\n{content}
|
||||
// When tools are present, content stops before the next ">>>" (tool call)
|
||||
// When no tools, content goes until end
|
||||
auto content_until_tool = p.literal(">>>all\n") + p.content(p.until(">>>"));
|
||||
auto content_until_end = p.literal(">>>all\n") + p.content(p.rest());
|
||||
auto content_until_tool = p.literal("all\n") + p.content(p.until(">>>"));
|
||||
auto content_until_end = p.literal("all\n") + p.content(p.rest());
|
||||
|
||||
// If no tools or tool_choice is NONE, just parse content
|
||||
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
// When no tools, just match the prefix and capture everything after
|
||||
return content_until_end + p.end();
|
||||
return wrap_for_generation_prompt(p, content_until_end + p.end(), inputs);
|
||||
}
|
||||
|
||||
// Build tool call parsers for each available function
|
||||
|
|
@ -1113,7 +1096,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
|||
|
||||
// Tool format: >>>function_name\n{json_args}
|
||||
auto tool_parser = p.tool(
|
||||
p.tool_open(p.literal(">>>") + p.tool_name(p.literal(name)) + p.literal("\n")) +
|
||||
p.tool_open(p.tool_name(p.literal(name)) + p.literal("\n")) +
|
||||
p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema))
|
||||
);
|
||||
|
||||
|
|
@ -1124,17 +1107,20 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
|||
auto tools_only = p.trigger_rule("tools", p.one_or_more(tool_choice));
|
||||
auto content_and_tools = content_until_tool + tools_only;
|
||||
|
||||
auto ret = p.eps();
|
||||
if (inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED) {
|
||||
if (inputs.parallel_tool_calls) {
|
||||
return p.choice({ content_and_tools, tools_only }) + p.end();
|
||||
ret = p.choice({ content_and_tools, tools_only }) + p.end();
|
||||
} else {
|
||||
ret = p.choice({ content_until_tool + tool_choice, tools_only }) + p.end();
|
||||
}
|
||||
return p.choice({ content_until_tool + tool_choice, tools_only }) + p.end();
|
||||
} else if (inputs.parallel_tool_calls) {
|
||||
ret = p.choice({ content_and_tools, content_only, tools_only }) + p.end();
|
||||
} else {
|
||||
auto content_and_tool = content_until_tool + tool_choice;
|
||||
ret = p.choice({ content_and_tool, content_only, tool_choice }) + p.end();
|
||||
}
|
||||
if (inputs.parallel_tool_calls) {
|
||||
return p.choice({ content_and_tools, content_only, tools_only }) + p.end();
|
||||
}
|
||||
auto content_and_tool = content_until_tool + tool_choice;
|
||||
return p.choice({ content_and_tool, content_only, tool_choice }) + p.end();
|
||||
return wrap_for_generation_prompt(p, ret, inputs);
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
|
@ -1164,14 +1150,12 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
|||
// Kimi K2 Thinking - uses unique tool call ID format: functions.<name>:<index>
|
||||
// The ID contains both the function name and an incrementing counter
|
||||
static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs) {
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "<think>";
|
||||
data.thinking_end_tag = "</think>";
|
||||
data.preserved_tokens = {
|
||||
"<|tool_calls_section_begin|>",
|
||||
"<|tool_calls_section_end|>",
|
||||
|
|
@ -1186,6 +1170,18 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
|||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
|
||||
const std::string SECTION_BEGIN = "<|tool_calls_section_begin|>";
|
||||
const std::string SECTION_END = "<|tool_calls_section_end|>";
|
||||
const std::string CALL_BEGIN = "<|tool_call_begin|>";
|
||||
const std::string ARGS_BEGIN = "<|tool_call_argument_begin|>";
|
||||
const std::string CALL_END = "<|tool_call_end|>";
|
||||
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
|
||||
data.thinking_start_tag = THINK_START;
|
||||
data.thinking_end_tag = THINK_END;
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
// Kimi K2 Thinking format:
|
||||
// - Reasoning: <think>{reasoning}</think>
|
||||
|
|
@ -1197,16 +1193,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
|||
// <|tool_calls_section_end|>
|
||||
// The ID format is: functions.<function_name>:<counter> where counter is 0, 1, 2, ...
|
||||
|
||||
// Tool call markers
|
||||
const std::string SECTION_BEGIN = "<|tool_calls_section_begin|>";
|
||||
const std::string SECTION_END = "<|tool_calls_section_end|>";
|
||||
const std::string CALL_BEGIN = "<|tool_call_begin|>";
|
||||
const std::string ARGS_BEGIN = "<|tool_call_argument_begin|>";
|
||||
const std::string CALL_END = "<|tool_call_end|>";
|
||||
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
|
||||
// Tool call markers
|
||||
auto end = p.end();
|
||||
|
||||
// Note: this model is CRAZY. It can diverge from its supposed tool calling pattern in so many ways it's not funny.
|
||||
|
|
@ -1218,7 +1205,8 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
|||
|
||||
// Content only parser (no tools)
|
||||
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return reasoning + p.content(p.rest()) + end;
|
||||
return wrap_for_generation_prompt(p, reasoning + p.content(p.rest()) + end,
|
||||
inputs, THINK_START);
|
||||
}
|
||||
|
||||
// Build tool call parsers for each available function
|
||||
|
|
@ -1254,7 +1242,8 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
|||
|
||||
auto content_before_tools = p.content(p.until_one_of({ SECTION_BEGIN, CALL_BEGIN }));
|
||||
|
||||
return reasoning + content_before_tools + tool_calls + end;
|
||||
return wrap_for_generation_prompt(p, reasoning + content_before_tools + tool_calls + end,
|
||||
inputs, THINK_START);
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
|
@ -1395,7 +1384,7 @@ static common_chat_params common_chat_params_init_mirothinker(const common_chat_
|
|||
// - Tool calls: <|tool_call_start|>[function_name(arg1="value1", arg2="value2")]<|tool_call_end|>
|
||||
// Tool calls can appear multiple times (parallel tool calls)
|
||||
static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs) {
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
|
|
@ -1414,13 +1403,15 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
|||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
|
||||
|
||||
const std::string TOOL_CALL_START = "<|tool_call_start|>";
|
||||
const std::string TOOL_CALL_END = "<|tool_call_end|>";
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
|
||||
data.thinking_start_tag = THINK_START;
|
||||
data.thinking_end_tag = THINK_END;
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto end = p.end();
|
||||
|
||||
auto reasoning = p.eps();
|
||||
|
|
@ -1429,7 +1420,8 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
|||
}
|
||||
|
||||
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return reasoning + p.content(p.rest()) + end;
|
||||
return wrap_for_generation_prompt(p, reasoning + p.content(p.rest()) + end, inputs,
|
||||
THINK_START);
|
||||
}
|
||||
|
||||
auto tool_calls = p.rule("tool-calls",
|
||||
|
|
@ -1441,7 +1433,8 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
|||
|
||||
auto content = p.content(p.until(TOOL_CALL_START));
|
||||
|
||||
return reasoning + content + tool_calls + end;
|
||||
return wrap_for_generation_prompt(p, reasoning + content + tool_calls + end, inputs,
|
||||
THINK_START);
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
|
@ -1467,7 +1460,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
|||
|
||||
static common_chat_params common_chat_params_init_gigachat_v3(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs) {
|
||||
const autoparser::generation_params & inputs) {
|
||||
|
||||
common_chat_params data;
|
||||
|
||||
|
|
@ -1481,9 +1474,10 @@ static common_chat_params common_chat_params_init_gigachat_v3(
|
|||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
auto tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n";
|
||||
const auto *tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n";
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto ret = p.eps();
|
||||
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
// Build a choice of all available tools
|
||||
auto tool_choice = p.choice();
|
||||
|
|
@ -1506,13 +1500,14 @@ static common_chat_params common_chat_params_init_gigachat_v3(
|
|||
auto tool_call = p.rule("tool-call", p.literal(tool_call_start_prefix) + tool_choice);
|
||||
auto tool_calls = p.trigger_rule("tool-call-root", p.repeat(tool_call, /* min = */ min_calls, /* max = */ max_calls));
|
||||
|
||||
return p.content(p.until("<|message_sep|>\n\n")) << tool_calls;
|
||||
ret = p.content(p.until("<|message_sep|>\n\n")) << tool_calls;
|
||||
} else {
|
||||
// Content only parser
|
||||
include_grammar = false;
|
||||
ret = p.content(p.rest());
|
||||
}
|
||||
|
||||
// Content only parser
|
||||
include_grammar = false;
|
||||
return p.content(p.rest());
|
||||
|
||||
return wrap_for_generation_prompt(p, ret, inputs);
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
|
@ -1607,69 +1602,10 @@ static json common_chat_extra_context() {
|
|||
return ctx;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs) {
|
||||
autoparser::templates_params params;
|
||||
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
|
||||
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
|
||||
? *tmpls->template_tool_use
|
||||
: *tmpls->template_default;
|
||||
const auto & src = tmpl.source();
|
||||
const auto & caps = tmpl.original_caps();
|
||||
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
|
||||
params.add_generation_prompt = inputs.add_generation_prompt;
|
||||
params.tool_choice = inputs.tool_choice;
|
||||
params.reasoning_format = inputs.reasoning_format;
|
||||
params.enable_thinking = inputs.enable_thinking;
|
||||
params.grammar = inputs.grammar;
|
||||
params.now = inputs.now;
|
||||
params.add_bos = tmpls->add_bos;
|
||||
params.add_eos = tmpls->add_eos;
|
||||
|
||||
if (src.find("<|channel|>") == std::string::npos) {
|
||||
// map developer to system for all models except for GPT-OSS
|
||||
workaround::map_developer_role_to_system(params.messages);
|
||||
}
|
||||
workaround::func_args_not_string(params.messages);
|
||||
|
||||
if (!tmpl.original_caps().supports_system_role) {
|
||||
workaround::system_message_not_supported(params.messages);
|
||||
}
|
||||
|
||||
if (tmpl.original_caps().supports_tool_calls) {
|
||||
// some templates will require the content field in tool call messages
|
||||
// to still be non-null, this puts an empty string everywhere where the
|
||||
// content field is null
|
||||
workaround::requires_non_null_content(params.messages);
|
||||
}
|
||||
|
||||
params.extra_context = common_chat_extra_context();
|
||||
for (auto el : inputs.chat_template_kwargs) {
|
||||
params.extra_context[el.first] = json::parse(el.second);
|
||||
}
|
||||
|
||||
if (!inputs.json_schema.empty()) {
|
||||
params.json_schema = json::parse(inputs.json_schema);
|
||||
}
|
||||
|
||||
// if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
|
||||
// LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
|
||||
// params.parallel_tool_calls = false;
|
||||
// } else {
|
||||
params.parallel_tool_calls = inputs.parallel_tool_calls;
|
||||
//}
|
||||
|
||||
if (params.tools.is_array()) {
|
||||
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
|
||||
throw std::runtime_error("Cannot specify grammar with tools");
|
||||
}
|
||||
if (caps.supports_tool_calls && !caps.supports_tools) {
|
||||
LOG_WRN(
|
||||
"Template supports tool calls but does not natively describe tools. The fallback behaviour used may "
|
||||
"produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
|
||||
}
|
||||
}
|
||||
|
||||
static std::optional<common_chat_params> try_specialized_template(
|
||||
const common_chat_template & tmpl,
|
||||
const std::string & src,
|
||||
const autoparser::generation_params & params) {
|
||||
// Ministral/Mistral Large 3 - uses special reasoning structure fixes, can't use autoparser
|
||||
// Note: Mistral Small 3.2 uses [CALL_ID] which Ministral doesn't have, so we can distinguish them
|
||||
if (src.find("[SYSTEM_PROMPT]") != std::string::npos && src.find("[TOOL_CALLS]") != std::string::npos &&
|
||||
|
|
@ -1718,14 +1654,105 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
|||
// GigaChatV3 format detection
|
||||
if (src.find("<|role_sep|>") != std::string::npos &&
|
||||
src.find("<|message_sep|>") != std::string::npos &&
|
||||
src.find("<|function_call|>") == std::string::npos
|
||||
) {
|
||||
src.find("<|function_call|>") == std::string::npos) {
|
||||
LOG_DBG("Using specialized template: GigaChatV3\n");
|
||||
return common_chat_params_init_gigachat_v3(tmpl, params);
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs) {
|
||||
autoparser::generation_params params;
|
||||
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
|
||||
const auto & tmpl =
|
||||
params.tools.is_array() && tmpls->template_tool_use ? *tmpls->template_tool_use : *tmpls->template_default;
|
||||
const auto & src = tmpl.source();
|
||||
const auto & caps = tmpl.original_caps();
|
||||
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
|
||||
params.tool_choice = inputs.tool_choice;
|
||||
params.reasoning_format = inputs.reasoning_format;
|
||||
params.enable_thinking = inputs.enable_thinking;
|
||||
params.grammar = inputs.grammar;
|
||||
params.now = inputs.now;
|
||||
params.add_bos = tmpls->add_bos;
|
||||
params.add_eos = tmpls->add_eos;
|
||||
|
||||
if (src.find("<|channel|>") == std::string::npos) {
|
||||
// map developer to system for all models except for GPT-OSS
|
||||
workaround::map_developer_role_to_system(params.messages);
|
||||
}
|
||||
|
||||
if (!tmpl.original_caps().supports_system_role) {
|
||||
workaround::system_message_not_supported(params.messages);
|
||||
}
|
||||
|
||||
if (tmpl.original_caps().supports_tool_calls) {
|
||||
// some templates will require the content field in tool call messages
|
||||
// to still be non-null, this puts an empty string everywhere where the
|
||||
// content field is null
|
||||
workaround::requires_non_null_content(params.messages);
|
||||
}
|
||||
|
||||
if (tmpl.original_caps().supports_object_arguments) {
|
||||
workaround::func_args_not_string(params.messages);
|
||||
}
|
||||
|
||||
params.add_generation_prompt = false;
|
||||
std::string no_gen_prompt = common_chat_template_direct_apply(tmpl, params);
|
||||
params.add_generation_prompt = true;
|
||||
std::string gen_prompt = common_chat_template_direct_apply(tmpl, params);
|
||||
auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
|
||||
params.generation_prompt = diff.right;
|
||||
|
||||
params.add_generation_prompt = inputs.add_generation_prompt;
|
||||
|
||||
params.extra_context = common_chat_extra_context();
|
||||
for (auto el : inputs.chat_template_kwargs) {
|
||||
params.extra_context[el.first] = json::parse(el.second);
|
||||
}
|
||||
|
||||
if (!inputs.json_schema.empty()) {
|
||||
params.json_schema = json::parse(inputs.json_schema);
|
||||
}
|
||||
|
||||
params.parallel_tool_calls = inputs.parallel_tool_calls;
|
||||
|
||||
if (params.tools.is_array()) {
|
||||
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
|
||||
throw std::runtime_error("Cannot specify grammar with tools");
|
||||
}
|
||||
if (caps.supports_tool_calls && !caps.supports_tools) {
|
||||
LOG_WRN(
|
||||
"Template supports tool calls but does not natively describe tools. The fallback behaviour used may "
|
||||
"produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
|
||||
}
|
||||
}
|
||||
|
||||
if (inputs.force_pure_content) {
|
||||
LOG_WRN("Forcing pure content template, will not render reasoning or tools separately.");
|
||||
// Create the result structure
|
||||
common_chat_params data;
|
||||
auto params_copy = params;
|
||||
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, params_copy);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.generation_prompt = params.generation_prompt;
|
||||
auto parser = build_chat_peg_parser([¶ms](common_chat_peg_builder &p) {
|
||||
return wrap_for_generation_prompt(p, p.content(p.rest()), params);
|
||||
});
|
||||
data.parser = parser.save();
|
||||
return data;
|
||||
}
|
||||
|
||||
if (auto result = try_specialized_template(tmpl, src, params)) {
|
||||
result->generation_prompt = params.generation_prompt;
|
||||
return *result;
|
||||
}
|
||||
|
||||
try {
|
||||
LOG_DBG("Using differential autoparser\n");
|
||||
LOG_DBG("%s: using differential autoparser\n", __func__);
|
||||
struct autoparser::autoparser autoparser;
|
||||
autoparser.analyze_template(tmpl);
|
||||
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
|
||||
|
|
@ -1733,13 +1760,11 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
|||
if (auto_params.supports_thinking) {
|
||||
auto_params.thinking_start_tag = autoparser.reasoning.start;
|
||||
auto_params.thinking_end_tag = autoparser.reasoning.end;
|
||||
// FORCED_OPEN and FORCED_CLOSED both put <think> in the generation prompt
|
||||
// (FORCED_CLOSED forces empty <think></think> when thinking is disabled,
|
||||
// but forces <think> open when thinking is enabled)
|
||||
auto_params.thinking_forced_open =
|
||||
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_OPEN ||
|
||||
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_CLOSED;
|
||||
}
|
||||
auto_params.generation_prompt = params.generation_prompt;
|
||||
common_peg_arena arena;
|
||||
arena.load(auto_params.parser);
|
||||
LOG_DBG("%s: generated parser:\n%s\n\nparser generation prompt: %s\n", __func__, arena.dump(arena.root()).c_str(), auto_params.generation_prompt.c_str());
|
||||
return auto_params;
|
||||
} catch (const std::exception & e) {
|
||||
throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what());
|
||||
|
|
@ -1837,14 +1862,18 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
|||
LOG_DBG("No parser definition detected, assuming pure content parser.");
|
||||
}
|
||||
|
||||
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), input.c_str());
|
||||
const std::string effective_input = params.generation_prompt.empty()
|
||||
? input
|
||||
: params.generation_prompt + input;
|
||||
|
||||
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), effective_input.c_str());
|
||||
|
||||
common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_LENIENT;
|
||||
if (params.debug) {
|
||||
flags |= COMMON_PEG_PARSE_FLAG_DEBUG;
|
||||
}
|
||||
|
||||
common_peg_parse_context ctx(input, flags);
|
||||
common_peg_parse_context ctx(effective_input, flags);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
if (result.fail()) {
|
||||
|
|
@ -1864,7 +1893,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
|||
return msg;
|
||||
}
|
||||
throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(result.end) + ": " +
|
||||
input.substr(result.end));
|
||||
effective_input.substr(result.end));
|
||||
}
|
||||
|
||||
common_chat_msg msg;
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ using json = nlohmann::ordered_json;
|
|||
struct common_chat_templates;
|
||||
|
||||
namespace autoparser {
|
||||
struct templates_params;
|
||||
struct generation_params;
|
||||
} // namespace autoparser
|
||||
|
||||
struct common_chat_tool_call {
|
||||
|
|
@ -204,6 +204,7 @@ struct common_chat_templates_inputs {
|
|||
std::map<std::string, std::string> chat_template_kwargs;
|
||||
bool add_bos = false;
|
||||
bool add_eos = false;
|
||||
bool force_pure_content = false;
|
||||
};
|
||||
|
||||
struct common_chat_params {
|
||||
|
|
@ -211,7 +212,7 @@ struct common_chat_params {
|
|||
std::string prompt;
|
||||
std::string grammar;
|
||||
bool grammar_lazy = false;
|
||||
bool thinking_forced_open = false;
|
||||
std::string generation_prompt;
|
||||
bool supports_thinking = false;
|
||||
std::string thinking_start_tag; // e.g., "<think>"
|
||||
std::string thinking_end_tag; // e.g., "</think>"
|
||||
|
|
@ -228,14 +229,14 @@ struct common_chat_parser_params {
|
|||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
|
||||
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
|
||||
bool reasoning_in_content = false;
|
||||
bool thinking_forced_open = false;
|
||||
std::string generation_prompt;
|
||||
bool parse_tool_calls = true;
|
||||
bool debug = false; // Enable debug output for PEG parser
|
||||
common_peg_arena parser = {};
|
||||
common_chat_parser_params() = default;
|
||||
common_chat_parser_params(const common_chat_params & chat_params) {
|
||||
format = chat_params.format;
|
||||
thinking_forced_open = chat_params.thinking_forced_open;
|
||||
format = chat_params.format;
|
||||
generation_prompt = chat_params.generation_prompt;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -301,7 +302,7 @@ std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_tem
|
|||
|
||||
std::string common_chat_template_direct_apply(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs,
|
||||
const autoparser::generation_params & inputs,
|
||||
const std::optional<json> & messages_override = std::nullopt,
|
||||
const std::optional<json> & tools_override = std::nullopt,
|
||||
const std::optional<json> & additional_context = std::nullopt);
|
||||
|
|
|
|||
|
|
@ -1067,7 +1067,7 @@ common_init_result::common_init_result(common_params & params) :
|
|||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
// load and optionally apply lora adapters (must be loaded before context creation)
|
||||
// load and optionally apply lora adapters
|
||||
for (auto & la : params.lora_adapters) {
|
||||
llama_adapter_lora_ptr lora;
|
||||
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
|
||||
|
|
|
|||
|
|
@ -3,12 +3,14 @@
|
|||
#pragma once
|
||||
|
||||
#include "ggml-opt.h"
|
||||
#include "ggml.h"
|
||||
#include "llama-cpp.h"
|
||||
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
|
|
@ -178,6 +180,43 @@ enum common_speculative_type {
|
|||
COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type
|
||||
};
|
||||
|
||||
// Grammar type enumeration
|
||||
enum common_grammar_type {
|
||||
COMMON_GRAMMAR_TYPE_NONE, // no grammar set
|
||||
COMMON_GRAMMAR_TYPE_USER, // user-provided GBNF (--grammar / "grammar" API field)
|
||||
COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, // auto-generated from JSON schema (--json-schema / "json_schema" API field)
|
||||
COMMON_GRAMMAR_TYPE_TOOL_CALLS, // auto-generated by chat template parser for function calling
|
||||
};
|
||||
|
||||
// Grammar variant struct with type and grammar string
|
||||
struct common_grammar {
|
||||
common_grammar_type type = COMMON_GRAMMAR_TYPE_NONE;
|
||||
std::string grammar;
|
||||
|
||||
// Default constructor - no grammar
|
||||
common_grammar() = default;
|
||||
|
||||
// Constructor with type and grammar string
|
||||
common_grammar(common_grammar_type t, std::string g) : type(t), grammar(std::move(g)) {
|
||||
GGML_ASSERT(type != COMMON_GRAMMAR_TYPE_NONE || !grammar.empty());
|
||||
}
|
||||
|
||||
// Check if a grammar is set
|
||||
bool empty() const { return type == COMMON_GRAMMAR_TYPE_NONE || grammar.empty(); }
|
||||
};
|
||||
|
||||
// Returns the raw grammar string, or empty string if no grammar is set.
|
||||
inline const std::string & common_grammar_value(const common_grammar & g) {
|
||||
return g.grammar;
|
||||
}
|
||||
|
||||
// Returns true when the generation_prompt should be prefilled into the grammar sampler.
|
||||
// Only output-format and tool-call grammars need prefill; user-supplied grammars must not be prefilled.
|
||||
inline bool common_grammar_needs_prefill(const common_grammar & g) {
|
||||
return g.type == COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT
|
||||
|| g.type == COMMON_GRAMMAR_TYPE_TOOL_CALLS;
|
||||
}
|
||||
|
||||
// sampling parameters
|
||||
struct common_params_sampling {
|
||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
|
||||
|
|
@ -228,7 +267,7 @@ struct common_params_sampling {
|
|||
COMMON_SAMPLER_TYPE_TEMPERATURE,
|
||||
};
|
||||
|
||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||
common_grammar grammar; // optional grammar constraint (user / output-format / tool-calls)
|
||||
bool grammar_lazy = false;
|
||||
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
|
||||
std::set<llama_token> preserved_tokens;
|
||||
|
|
@ -236,10 +275,15 @@ struct common_params_sampling {
|
|||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
|
||||
|
||||
// The assistant generation prompt already prefilled into the prompt.
|
||||
// Fed to the grammar sampler (to advance past pre-existing tokens) and used
|
||||
// to determine the reasoning budget sampler's initial state.
|
||||
// Only applied when the grammar is of output-format or tool-calls type.
|
||||
std::string generation_prompt;
|
||||
|
||||
// reasoning budget sampler parameters
|
||||
// these are populated by the server/CLI based on chat template params
|
||||
int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget
|
||||
bool reasoning_budget_activate_immediately = false;
|
||||
std::vector<llama_token> reasoning_budget_start; // start tag token sequence
|
||||
std::vector<llama_token> reasoning_budget_end; // end tag token sequence
|
||||
std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag)
|
||||
|
|
@ -544,6 +588,7 @@ struct common_params {
|
|||
std::string chat_template = ""; // NOLINT
|
||||
bool use_jinja = true; // NOLINT
|
||||
bool enable_chat_template = true;
|
||||
bool force_pure_content_parser = false;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
int enable_reasoning = -1; // -1 = auto, 0 = disable, 1 = enable
|
||||
int reasoning_budget = -1;
|
||||
|
|
|
|||
|
|
@ -75,6 +75,7 @@ std::map<std::string, bool> caps::to_map() const {
|
|||
{"supports_parallel_tool_calls", supports_parallel_tool_calls},
|
||||
{"supports_system_role", supports_system_role},
|
||||
{"supports_preserve_reasoning", supports_preserve_reasoning},
|
||||
{"supports_object_arguments", supports_object_arguments},
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -158,9 +159,9 @@ caps caps_get(jinja::program & prog) {
|
|||
}
|
||||
);
|
||||
|
||||
JJ_DEBUG("%s\n", ">>> Running capability check: single tool support");
|
||||
JJ_DEBUG("%s\n", ">>> Running capability check: single tool with object arguments support");
|
||||
|
||||
// case: tools support: single call
|
||||
// case: tools support: single call with object arguments
|
||||
caps_try_execute(
|
||||
prog,
|
||||
[&]() {
|
||||
|
|
@ -226,9 +227,7 @@ caps caps_get(jinja::program & prog) {
|
|||
},
|
||||
[&](bool success, value & messages, value & tools) {
|
||||
if (!success) {
|
||||
result.supports_tool_calls = false;
|
||||
result.supports_tools = false;
|
||||
return;
|
||||
return; // Nothing can be inferred
|
||||
}
|
||||
|
||||
auto & tool_name = tools->at(0)->at("function")->at("name");
|
||||
|
|
@ -242,16 +241,117 @@ caps caps_get(jinja::program & prog) {
|
|||
caps_print_stats(tool_calls, "messages[1].tool_calls");
|
||||
if (!tool_calls->stats.used) {
|
||||
result.supports_tool_calls = false;
|
||||
return;
|
||||
}
|
||||
|
||||
auto & tool_arg = tool_calls->at(0)->at("function")->at("arguments")->at("arg");
|
||||
caps_print_stats(tool_arg, "messages[1].tool_calls[0].function.arguments.arg");
|
||||
if (tool_arg->stats.used) {
|
||||
result.supports_object_arguments = true;
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
if (!result.supports_object_arguments) {
|
||||
JJ_DEBUG("%s\n", ">>> Running capability check: single tool with string arguments support");
|
||||
|
||||
// case: tools support: single call with string arguments
|
||||
caps_try_execute(
|
||||
prog,
|
||||
[&]() {
|
||||
// messages
|
||||
return json::array({
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "User message"},
|
||||
},
|
||||
{
|
||||
{"role", "assistant"},
|
||||
{"content", ""}, // Some templates expect content to be empty with tool calls
|
||||
{"tool_calls", json::array({
|
||||
{
|
||||
{"id", "call00001"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "tool1"},
|
||||
{"arguments", R"({"arg": "value"})"}
|
||||
}}
|
||||
}
|
||||
})}
|
||||
},
|
||||
{
|
||||
{"role", "tool"},
|
||||
{"content", "Tool response"},
|
||||
{"tool_call_id", "call00001"}
|
||||
},
|
||||
{
|
||||
{"role", "assistant"},
|
||||
{"content", "The tool response was 'tool response'"}
|
||||
},
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "User message"},
|
||||
},
|
||||
});
|
||||
},
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array({
|
||||
{
|
||||
{"name", "tool"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "tool1"},
|
||||
{"description", "Tool description"},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"arg", {
|
||||
{"type", "string"},
|
||||
{"description", "Arg description"},
|
||||
}},
|
||||
}},
|
||||
{"required", json::array({ "arg" })},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
});
|
||||
},
|
||||
[&](bool success, value & messages, value & tools) {
|
||||
if (!success) {
|
||||
result.supports_tool_calls = false;
|
||||
result.supports_tools = false;
|
||||
return;
|
||||
}
|
||||
|
||||
auto & tool_name = tools->at(0)->at("function")->at("name");
|
||||
caps_print_stats(tool_name, "tools[0].function.name");
|
||||
caps_print_stats(tools, "tools");
|
||||
if (!tool_name->stats.used) {
|
||||
result.supports_tools = false;
|
||||
}
|
||||
|
||||
auto & tool_calls = messages->at(1)->at("tool_calls");
|
||||
caps_print_stats(tool_calls, "messages[1].tool_calls");
|
||||
if (!tool_calls->stats.used) {
|
||||
result.supports_tool_calls = false;
|
||||
return;
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
JJ_DEBUG("%s\n", ">>> Running capability check: parallel tool support");
|
||||
|
||||
// case: tools support: parallel calls
|
||||
caps_try_execute(
|
||||
prog,
|
||||
[&]() {
|
||||
json args = json(R"({"arg": "value"})");
|
||||
if (result.supports_object_arguments) {
|
||||
args = json{{"arg", "value"}};
|
||||
}
|
||||
|
||||
// messages
|
||||
return json::array({
|
||||
{
|
||||
|
|
@ -267,9 +367,7 @@ caps caps_get(jinja::program & prog) {
|
|||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "tool1"},
|
||||
{"arguments", {
|
||||
{"arg", "value"}
|
||||
}}
|
||||
{"arguments", args}
|
||||
}}
|
||||
},
|
||||
{
|
||||
|
|
@ -277,9 +375,7 @@ caps caps_get(jinja::program & prog) {
|
|||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "tool1"},
|
||||
{"arguments", {
|
||||
{"arg", "value"}
|
||||
}}
|
||||
{"arguments", args}
|
||||
}}
|
||||
}
|
||||
})}
|
||||
|
|
@ -328,7 +424,7 @@ caps caps_get(jinja::program & prog) {
|
|||
return;
|
||||
}
|
||||
|
||||
auto & tool_calls = messages->at(1)->at("tool_calls");;
|
||||
auto & tool_calls = messages->at(1)->at("tool_calls");
|
||||
caps_print_stats(tool_calls, "messages[1].tool_calls");
|
||||
|
||||
// check for second tool call usage
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ struct caps {
|
|||
bool supports_string_content = true;
|
||||
bool supports_typed_content = false;
|
||||
|
||||
bool supports_object_arguments = false;
|
||||
|
||||
// for reporting on server
|
||||
std::map<std::string, bool> to_map() const;
|
||||
|
||||
|
|
|
|||
|
|
@ -451,7 +451,7 @@ struct value_array_t : public value_t {
|
|||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), value_equivalence());
|
||||
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), other.val_arr.end(), value_equivalence());
|
||||
}
|
||||
};
|
||||
using value_array = std::shared_ptr<value_array_t>;
|
||||
|
|
@ -587,7 +587,7 @@ struct value_object_t : public value_t {
|
|||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), value_equivalence());
|
||||
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), other.val_obj.end(), value_equivalence());
|
||||
}
|
||||
};
|
||||
using value_object = std::shared_ptr<value_object_t>;
|
||||
|
|
|
|||
|
|
@ -163,9 +163,15 @@ static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
|
|||
ctx->force_pos = 0;
|
||||
}
|
||||
|
||||
// forward declaration for use in clone
|
||||
static struct llama_sampler * common_reasoning_budget_init_state(
|
||||
const struct llama_vocab * vocab, const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens, const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget, common_reasoning_budget_state initial_state);
|
||||
|
||||
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
|
||||
return common_reasoning_budget_init(
|
||||
return common_reasoning_budget_init_state(
|
||||
ctx->vocab,
|
||||
ctx->start_matcher.tokens,
|
||||
ctx->end_matcher.tokens,
|
||||
|
|
@ -191,13 +197,13 @@ static struct llama_sampler_i common_reasoning_budget_i = {
|
|||
/* .backend_set_input = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
common_reasoning_budget_state initial_state) {
|
||||
static struct llama_sampler * common_reasoning_budget_init_state(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
common_reasoning_budget_state initial_state) {
|
||||
// promote COUNTING with budget <= 0 to FORCING
|
||||
if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
|
||||
initial_state = REASONING_BUDGET_FORCING;
|
||||
|
|
@ -217,3 +223,41 @@ struct llama_sampler * common_reasoning_budget_init(
|
|||
}
|
||||
);
|
||||
}
|
||||
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
const std::vector<llama_token> & prefill_tokens) {
|
||||
// Determine initial state from prefill: COUNTING if the prefill begins with
|
||||
// the start sequence but does not also contain the end sequence after it.
|
||||
common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE;
|
||||
if (!prefill_tokens.empty() && !start_tokens.empty() &&
|
||||
prefill_tokens.size() >= start_tokens.size() &&
|
||||
std::equal(start_tokens.begin(), start_tokens.end(), prefill_tokens.begin())) {
|
||||
initial_state = REASONING_BUDGET_COUNTING;
|
||||
// If the end sequence also follows the start in the prefill, reasoning
|
||||
// was opened and immediately closed — stay IDLE.
|
||||
if (!end_tokens.empty() &&
|
||||
prefill_tokens.size() >= start_tokens.size() + end_tokens.size()) {
|
||||
auto end_start = prefill_tokens.end() - (ptrdiff_t) end_tokens.size();
|
||||
if (end_start >= prefill_tokens.begin() + (ptrdiff_t) start_tokens.size() &&
|
||||
std::equal(end_tokens.begin(), end_tokens.end(), end_start)) {
|
||||
initial_state = REASONING_BUDGET_IDLE;
|
||||
}
|
||||
}
|
||||
}
|
||||
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
|
||||
}
|
||||
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
common_reasoning_budget_state initial_state) {
|
||||
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,14 +24,26 @@ enum common_reasoning_budget_state {
|
|||
// DONE: passthrough forever
|
||||
//
|
||||
// Parameters:
|
||||
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
|
||||
// start_tokens - token sequence that activates counting
|
||||
// end_tokens - token sequence for natural deactivation
|
||||
// forced_tokens - token sequence forced when budget expires
|
||||
// budget - max tokens allowed in the reasoning block
|
||||
// initial_state - initial state of the sampler (e.g. IDLE or COUNTING)
|
||||
// note: COUNTING with budget <= 0 is promoted to FORCING
|
||||
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
|
||||
// start_tokens - token sequence that activates counting
|
||||
// end_tokens - token sequence for natural deactivation
|
||||
// forced_tokens - token sequence forced when budget expires
|
||||
// budget - max tokens allowed in the reasoning block
|
||||
// prefill_tokens - tokens already present in the prompt (generation prompt);
|
||||
// used to determine the initial state: COUNTING if they begin
|
||||
// with start_tokens (but don't also end with end_tokens),
|
||||
// IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING.
|
||||
//
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
const std::vector<llama_token> & prefill_tokens = {});
|
||||
|
||||
// Variant that takes an explicit initial state (used by tests and clone).
|
||||
// COUNTING with budget <= 0 is promoted to FORCING.
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) {
|
|||
auto is_star = *it == '*';
|
||||
++it;
|
||||
if (is_star) {
|
||||
if (*it == '?') {
|
||||
if (it != end && *it == '?') {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,16 @@
|
|||
#include "sampling.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "ggml.h"
|
||||
#include "log.h"
|
||||
#include "reasoning-budget.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
||||
// TODO: deduplicate with llama-impl.h
|
||||
|
|
@ -189,9 +192,10 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
|||
|
||||
std::vector<llama_sampler *> samplers;
|
||||
|
||||
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
|
||||
const std::string & grammar_str = common_grammar_value(params.grammar);
|
||||
if (grammar_str.compare(0, 11, "%llguidance") == 0) {
|
||||
#ifdef LLAMA_USE_LLGUIDANCE
|
||||
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
|
||||
grmr = llama_sampler_init_llg(vocab, "lark", grammar_str.c_str());
|
||||
#else
|
||||
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
|
|
@ -240,17 +244,46 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
|||
trigger_patterns_c.push_back(regex.c_str());
|
||||
}
|
||||
|
||||
if (!params.grammar.empty()) {
|
||||
if (!grammar_str.empty()) {
|
||||
if (params.grammar_lazy) {
|
||||
grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
|
||||
grmr = llama_sampler_init_grammar_lazy_patterns(vocab, grammar_str.c_str(), "root",
|
||||
trigger_patterns_c.data(), trigger_patterns_c.size(),
|
||||
trigger_tokens.data(), trigger_tokens.size());
|
||||
} else {
|
||||
grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
|
||||
grmr = llama_sampler_init_grammar(vocab, grammar_str.c_str(), "root");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Feed generation prompt tokens to the grammar sampler so it advances past
|
||||
// tokens the template already placed in the prompt.
|
||||
// Only applies to output-format and tool-call grammars; user-supplied grammars must not be prefilled.
|
||||
std::vector<llama_token> prefill_tokens;
|
||||
if (!params.generation_prompt.empty() && common_grammar_needs_prefill(params.grammar)) {
|
||||
GGML_ASSERT(vocab != nullptr);
|
||||
prefill_tokens = common_tokenize(vocab, params.generation_prompt, false, true);
|
||||
if (!prefill_tokens.empty()) {
|
||||
std::string first_token = common_token_to_piece(vocab, prefill_tokens[0], true);
|
||||
if (std::isspace(first_token[0]) && !std::isspace(params.generation_prompt[0])) {
|
||||
// Some tokenizers will add a space before the first special token, need to remove
|
||||
prefill_tokens = std::vector<llama_token>(prefill_tokens.begin() + 1, prefill_tokens.end());
|
||||
}
|
||||
}
|
||||
|
||||
if (grmr) {
|
||||
try {
|
||||
for (const auto & token : prefill_tokens) {
|
||||
llama_sampler_accept(grmr, token);
|
||||
LOG_DBG("%s: accepted prefill token (%d)\n", __func__, token);
|
||||
}
|
||||
} catch (std::exception &e) {
|
||||
LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__,
|
||||
common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str());
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reasoning budget sampler — added first so it can force tokens before other samplers
|
||||
if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
|
||||
samplers.push_back(common_reasoning_budget_init(
|
||||
|
|
@ -259,7 +292,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
|||
params.reasoning_budget_end,
|
||||
params.reasoning_budget_forced,
|
||||
params.reasoning_budget_tokens,
|
||||
params.reasoning_budget_activate_immediately ? REASONING_BUDGET_COUNTING : REASONING_BUDGET_IDLE));
|
||||
prefill_tokens));
|
||||
}
|
||||
|
||||
if (params.has_logit_bias()) {
|
||||
|
|
|
|||
|
|
@ -31,10 +31,10 @@ import gguf
|
|||
from gguf.vocab import MistralTokenizerType, MistralVocab
|
||||
|
||||
try:
|
||||
from mistral_common.tokens.tokenizers.base import TokenizerVersion # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.base import TokenizerVersion # type: ignore[import-not-found]
|
||||
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # type: ignore[import-not-found]
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # type: ignore[import-not-found]
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import ( # type: ignore[import-not-found]
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
|
||||
|
|
@ -45,9 +45,9 @@ except ImportError:
|
|||
_MISTRAL_COMMON_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
||||
|
||||
_mistral_common_installed = False
|
||||
TokenizerVersion = None
|
||||
Tekkenizer = None
|
||||
SentencePieceTokenizer = None
|
||||
TokenizerVersion: Any = None
|
||||
Tekkenizer: Any = None
|
||||
SentencePieceTokenizer: Any = None
|
||||
_mistral_import_error_msg = (
|
||||
"Mistral format requires `mistral-common` to be installed. Please run "
|
||||
"`pip install mistral-common[image,audio]` to install it."
|
||||
|
|
@ -145,6 +145,7 @@ class ModelBase:
|
|||
self.model_name = model_name
|
||||
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
|
||||
self._is_nvfp4 = False
|
||||
self._is_mxfp4 = False
|
||||
|
||||
# Apply heuristics to figure out typical tensor encoding based on first tensor's dtype
|
||||
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
|
||||
|
|
@ -220,7 +221,7 @@ class ModelBase:
|
|||
if weight_map is None or not isinstance(weight_map, dict):
|
||||
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
|
||||
tensor_names_from_index.update(weight_map.keys())
|
||||
part_dict: dict[str, None] = dict.fromkeys(weight_map.values(), None)
|
||||
part_dict: dict[str, None] = dict.fromkeys(weight_map.values(), None) # ty: ignore[invalid-assignment]
|
||||
part_names = sorted(part_dict.keys())
|
||||
else:
|
||||
weight_map = {}
|
||||
|
|
@ -272,8 +273,9 @@ class ModelBase:
|
|||
return tensors
|
||||
|
||||
def dequant_model(self):
|
||||
if self._is_nvfp4:
|
||||
return # NVFP4 weights are repacked in _generate_nvfp4_tensors
|
||||
# If all quantized tensors were already handled (e.g. pure NVFP4), skip
|
||||
if self._is_nvfp4 and not any(k.endswith((".weight_scale", ".weight_scale_inv")) for k in self.model_tensors):
|
||||
return
|
||||
|
||||
tensors_to_remove: list[str] = []
|
||||
new_tensors: dict[str, Callable[[], Tensor]] = {}
|
||||
|
|
@ -297,11 +299,16 @@ class ModelBase:
|
|||
scale = scale.float()
|
||||
|
||||
if block_size is not None:
|
||||
dim_offset = scale.ndim - len(block_size)
|
||||
for i, size in enumerate(block_size):
|
||||
scale = scale.repeat_interleave(size, i)
|
||||
scale = scale.repeat_interleave(size, dim_offset + i)
|
||||
# unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
|
||||
scale = scale[tuple(slice(0, size) for size in weight.shape)]
|
||||
|
||||
# align scale dims to weight for correct broadcasting (e.g. [128] -> [128, 1, 1])
|
||||
while scale.ndim < weight.ndim:
|
||||
scale = scale.unsqueeze(-1)
|
||||
|
||||
return weight.float() * scale
|
||||
|
||||
# ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476
|
||||
|
|
@ -392,7 +399,7 @@ class ModelBase:
|
|||
elif quant_method == "fp8":
|
||||
block_size = quant_config.get("weight_block_size")
|
||||
for name in self.model_tensors.keys():
|
||||
if name.endswith(".weight_scale_inv"):
|
||||
if name.endswith("_scale_inv"):
|
||||
weight_name = name.removesuffix("_scale_inv")
|
||||
w = self.model_tensors[weight_name]
|
||||
s = self.model_tensors[name]
|
||||
|
|
@ -400,6 +407,8 @@ class ModelBase:
|
|||
tensors_to_remove.append(name)
|
||||
if name.endswith(".activation_scale"): # unused
|
||||
tensors_to_remove.append(name)
|
||||
if name.endswith("_activation_scale"): # Mistral-Small-4-119B-2602, unused
|
||||
tensors_to_remove.append(name)
|
||||
# mistral format
|
||||
if name.endswith(".qscale_weight"):
|
||||
weight_name = name.removesuffix("qscale_weight") + "weight"
|
||||
|
|
@ -474,7 +483,20 @@ class ModelBase:
|
|||
tensors_to_remove.append(base_name + "_zero_point")
|
||||
else:
|
||||
raise NotImplementedError(f"Quant format {quant_format!r} for method {quant_method!r} is not yet supported")
|
||||
else:
|
||||
elif quant_method == "modelopt":
|
||||
# Mixed-precision ModelOpt models: NVFP4 tensors are handled by
|
||||
# _generate_nvfp4_tensors; FP8 tensors have 1D weight_scale and
|
||||
# are dequantized here. input_scale tensors are unused.
|
||||
for name in self.model_tensors.keys():
|
||||
if name.endswith(".weight_scale"):
|
||||
weight_name = name.removesuffix("_scale")
|
||||
w = self.model_tensors[weight_name]
|
||||
s = self.model_tensors[name]
|
||||
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s(), None)
|
||||
tensors_to_remove.append(name)
|
||||
if name.endswith((".input_scale", ".k_scale", ".v_scale")):
|
||||
tensors_to_remove.append(name)
|
||||
elif quant_method is not None:
|
||||
raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
|
||||
|
||||
for name in tensors_to_remove:
|
||||
|
|
@ -520,12 +542,6 @@ class ModelBase:
|
|||
raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses")
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# skip NVFP4 auxiliary tensors (handled in _generate_nvfp4_tensors)
|
||||
if self._is_nvfp4:
|
||||
if name.endswith((".weight_scale", ".weight_scale_2", ".input_scale", ".k_scale", ".v_scale")):
|
||||
return []
|
||||
if name.endswith(".weight") and name.replace(".weight", ".weight_scale") in self.model_tensors:
|
||||
return []
|
||||
|
||||
new_name = self.map_tensor_name(name)
|
||||
|
||||
|
|
@ -609,6 +625,7 @@ class ModelBase:
|
|||
expert_scales: dict[tuple[int, str], list[tuple[int, float]]] = {}
|
||||
expert_shapes: dict[tuple[int, str], list[int]] = {}
|
||||
n_experts = self.find_hparam(["num_local_experts", "num_experts"], optional=True) or 0
|
||||
consumed: list[str] = []
|
||||
|
||||
for name in list(self.model_tensors.keys()):
|
||||
if not name.endswith(".weight"):
|
||||
|
|
@ -620,8 +637,18 @@ class ModelBase:
|
|||
# Force eager materialization of lazy tensors
|
||||
weight = LazyTorchTensor.to_eager(self.model_tensors[name]())
|
||||
scale = LazyTorchTensor.to_eager(self.model_tensors[scale_name]())
|
||||
|
||||
# Skip non-NVFP4 tensors (e.g. FP8 with per-channel 1D scales)
|
||||
if scale.ndim < 2:
|
||||
continue
|
||||
|
||||
scale2 = LazyTorchTensor.to_eager(self.model_tensors.get(scale2_name, lambda: torch.tensor(1.0))())
|
||||
|
||||
# Mark tensors for removal from model_tensors (already written to gguf)
|
||||
consumed.extend([name, scale_name])
|
||||
if scale2_name in self.model_tensors:
|
||||
consumed.append(scale2_name)
|
||||
|
||||
# Check if this is a per-expert tensor
|
||||
m = re.search(r'\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight$', name)
|
||||
if m:
|
||||
|
|
@ -652,6 +679,15 @@ class ModelBase:
|
|||
for (bid, proj_type) in list(expert_blocks.keys()):
|
||||
self._flush_nvfp4_experts((bid, proj_type), expert_blocks, expert_scales, expert_shapes, bid, proj_type)
|
||||
|
||||
# Remove consumed tensors so get_tensors/modify_tensors won't see them
|
||||
for name in consumed:
|
||||
self.model_tensors.pop(name, None)
|
||||
|
||||
# Remove unused auxiliary tensors (input_scale, k_scale, v_scale)
|
||||
for name in list(self.model_tensors.keys()):
|
||||
if name.endswith((".input_scale", ".k_scale", ".v_scale")):
|
||||
del self.model_tensors[name]
|
||||
|
||||
def _flush_nvfp4_experts(self, key, expert_blocks, expert_scales, expert_shapes, bid, proj_type):
|
||||
experts = expert_blocks.pop(key)
|
||||
scales = expert_scales.pop(key)
|
||||
|
|
@ -677,20 +713,33 @@ class ModelBase:
|
|||
def prepare_tensors(self):
|
||||
# detect NVFP4 quantization (ModelOpt format)
|
||||
quant_algo = (self.hparams.get("quantization_config") or {}).get("quant_algo")
|
||||
quant_method = (self.hparams.get("quantization_config") or {}).get("quant_method")
|
||||
quant_layers = (self.hparams.get("quantization_config") or {}).get("quantized_layers") or {}
|
||||
quant_config_file = self.dir_model / "hf_quant_config.json"
|
||||
|
||||
if not quant_algo and quant_config_file.is_file():
|
||||
if (not quant_algo or not quant_layers) and quant_config_file.is_file():
|
||||
with open(quant_config_file, "r", encoding="utf-8") as f:
|
||||
quant_algo = (json.load(f).get("quantization") or {}).get("quant_algo")
|
||||
quant_config = json.load(f).get("quantization") or {}
|
||||
quant_algo = quant_config.get("quant_algo", quant_algo)
|
||||
quant_layers = quant_config.get("quantized_layers", quant_layers) or {}
|
||||
|
||||
# Some models use per-tensor quant_algo (e.g. "MIXED_PRECISION" with
|
||||
# per-layer NVFP4/FP8) instead of a single global "NVFP4" value.
|
||||
if quant_algo != "NVFP4":
|
||||
if any(v.get("quant_algo") == "NVFP4" for v in quant_layers.values() if isinstance(v, dict)):
|
||||
quant_algo = "NVFP4"
|
||||
|
||||
self._is_nvfp4 = quant_algo == "NVFP4"
|
||||
self._is_mxfp4 = quant_method == "mxfp4"
|
||||
|
||||
self.dequant_model()
|
||||
|
||||
# NVFP4 weights are repacked and written directly to gguf_writer
|
||||
# NVFP4 weights are repacked and written directly to gguf_writer.
|
||||
# This must run before dequant_model so NVFP4 tensors are removed
|
||||
# from model_tensors, leaving only non-NVFP4 (e.g. FP8) for dequant.
|
||||
if self._is_nvfp4:
|
||||
self._generate_nvfp4_tensors()
|
||||
|
||||
self.dequant_model()
|
||||
|
||||
# Handle empty tensor_map for models with block_count=0 (like MobileNetV5)
|
||||
if self.tensor_map.mapping:
|
||||
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
|
||||
|
|
@ -830,6 +879,12 @@ class ModelBase:
|
|||
if self.metadata.name is None:
|
||||
self.metadata.name = self.dir_model.name
|
||||
|
||||
if self.ftype in (gguf.LlamaFileType.ALL_F32, gguf.LlamaFileType.MOSTLY_F16, gguf.LlamaFileType.MOSTLY_BF16):
|
||||
if self._is_nvfp4:
|
||||
self.ftype = gguf.LlamaFileType.MOSTLY_NVFP4
|
||||
elif self._is_mxfp4:
|
||||
self.ftype = gguf.LlamaFileType.MOSTLY_MXFP4_MOE
|
||||
|
||||
# Generate parameter weight class (useful for leader boards) if not yet determined
|
||||
if self.metadata.size_label is None and total_params > 0:
|
||||
self.metadata.size_label = gguf.size_label(total_params, shared_params, expert_params, expert_count)
|
||||
|
|
@ -1016,6 +1071,10 @@ class TextModel(ModelBase):
|
|||
self.gguf_writer.add_head_count_kv(n_head_kv)
|
||||
logger.info(f"gguf: key-value head count = {n_head_kv}")
|
||||
|
||||
if self.hparams.get("is_causal") is False:
|
||||
self.gguf_writer.add_causal_attention(False)
|
||||
logger.info("gguf: causal attention = False")
|
||||
|
||||
# TODO: Handle "sliding_attention" similarly when models start implementing it
|
||||
rope_params = self.rope_parameters.get("full_attention", self.rope_parameters)
|
||||
if (rope_type := rope_params.get("rope_type")) is not None:
|
||||
|
|
@ -2992,10 +3051,16 @@ class LlavaVisionModel(MmprojModel):
|
|||
def get_token_id(self, token: str) -> int:
|
||||
tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
|
||||
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
|
||||
added_tokens_decoder = json.load(f)['added_tokens_decoder']
|
||||
added_tokens_decoder = json.load(f).get('added_tokens_decoder') or {}
|
||||
for id_, token_data in added_tokens_decoder.items():
|
||||
if token_data["content"] == token:
|
||||
if token_data.get("content") == token:
|
||||
return int(id_)
|
||||
# fallthrough to tokenizer.json
|
||||
with open(self.dir_model / "tokenizer.json", "r", encoding="utf-8") as f:
|
||||
tokenizer_json = json.load(f)
|
||||
for token_data in tokenizer_json["added_tokens"]:
|
||||
if token_data["content"] == token:
|
||||
return int(token_data["id"])
|
||||
raise ValueError(f"Token '{token}' not found in tokenizer config.")
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
|
|
@ -3159,40 +3224,6 @@ class Llama4VisionModel(MmprojModel):
|
|||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register(
|
||||
"Mistral3ForConditionalGeneration",
|
||||
"Ministral3ForCausalLM",
|
||||
)
|
||||
class Mistral3Model(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.MISTRAL3
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# for compatibility, we use LLAMA arch for older models
|
||||
# TODO: remove this once everyone has migrated to newer version of llama.cpp
|
||||
if self.hparams.get("model_type") != "ministral3":
|
||||
self.model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
|
||||
self.gguf_writer.add_architecture()
|
||||
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
rope_params = self.rope_parameters
|
||||
if self.hparams.get("model_type") == "ministral3":
|
||||
assert rope_params, "ministral3 must have 'rope_parameters' config"
|
||||
assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'"
|
||||
self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"])
|
||||
self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"])
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
name = name.replace("language_model.", "")
|
||||
if "multi_modal_projector" in name or "vision_tower" in name:
|
||||
return
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("DeciLMForCausalLM")
|
||||
class DeciModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.DECI
|
||||
|
|
@ -5860,7 +5891,7 @@ class InternLM2Model(TextModel):
|
|||
logger.error(f'Error: Missing {tokenizer_path}')
|
||||
sys.exit(1)
|
||||
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
|
||||
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
|
||||
add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
|
||||
|
||||
|
|
@ -6181,7 +6212,7 @@ class BertModel(TextModel):
|
|||
|
||||
vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size)
|
||||
else:
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
|
||||
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
|
||||
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
|
||||
|
||||
|
|
@ -8232,6 +8263,8 @@ class DeepseekV2Model(TextModel):
|
|||
# TODO @ngxson : remove this when we support MTP for deepseek models
|
||||
skip_mtp = True
|
||||
|
||||
merge_expert = True
|
||||
|
||||
def set_vocab(self):
|
||||
try:
|
||||
self._set_vocab_gpt2()
|
||||
|
|
@ -8370,7 +8403,7 @@ class DeepseekV2Model(TextModel):
|
|||
return
|
||||
|
||||
# process the experts separately
|
||||
if name.find("mlp.experts") != -1:
|
||||
if self.merge_expert and name.find("mlp.experts") != -1:
|
||||
n_experts = self.hparams["n_routed_experts"]
|
||||
assert bid is not None
|
||||
|
||||
|
|
@ -8429,6 +8462,69 @@ class DeepseekV2Model(TextModel):
|
|||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register(
|
||||
"Mistral3ForConditionalGeneration",
|
||||
"Ministral3ForCausalLM",
|
||||
)
|
||||
class Mistral3Model(TextModel):
|
||||
class Ministral3Model(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.MISTRAL3
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
rope_params = self.rope_parameters
|
||||
if self.hparams.get("model_type") == "ministral3":
|
||||
assert rope_params, "ministral3 must have 'rope_parameters' config"
|
||||
assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'"
|
||||
self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"])
|
||||
self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"])
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
name = name.replace("language_model.", "")
|
||||
if "multi_modal_projector" in name or "vision_tower" in name:
|
||||
return
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
class Mistral4Model(DeepseekV2Model):
|
||||
model_arch = gguf.MODEL_ARCH.MISTRAL4
|
||||
skip_mtp = False # model contains no MTP layers, so no need to skip
|
||||
merge_expert = False # experts are already stacked as 3D
|
||||
|
||||
def modify_tensors(self, data_torch, name, bid):
|
||||
if name.endswith(".down_proj") or name.endswith(".gate_up_proj"):
|
||||
name = name + ".weight"
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
model_arch = gguf.MODEL_ARCH.MISTRAL3 # unused
|
||||
impl: TextModel
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.hparams.get("model_type") == "mistral4":
|
||||
self.impl = Mistral3Model.Mistral4Model(*args, **kwargs)
|
||||
else:
|
||||
self.impl = Mistral3Model.Ministral3Model(*args, **kwargs)
|
||||
|
||||
def set_vocab(self):
|
||||
self.impl.set_vocab()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.impl.set_gguf_parameters()
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
yield from self.impl.modify_tensors(data_torch, name, bid)
|
||||
|
||||
def prepare_tensors(self):
|
||||
self.impl.prepare_tensors()
|
||||
|
||||
def write_vocab(self):
|
||||
self.impl.write_vocab()
|
||||
|
||||
def write(self):
|
||||
self.impl.write()
|
||||
|
||||
|
||||
@ModelBase.register("MiniMaxM2ForCausalLM")
|
||||
class MiniMaxM2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.MINIMAXM2
|
||||
|
|
@ -8793,7 +8889,7 @@ class T5Model(TextModel):
|
|||
if not tokenizer_path.is_file():
|
||||
raise FileNotFoundError(f"File not found: {tokenizer_path}")
|
||||
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
|
||||
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
|
||||
|
||||
# some models like Pile-T5 family use BPE tokenizer instead of Unigram
|
||||
|
|
@ -8930,7 +9026,7 @@ class T5EncoderModel(TextModel):
|
|||
if not tokenizer_path.is_file():
|
||||
raise FileNotFoundError(f"File not found: {tokenizer_path}")
|
||||
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
|
||||
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
|
||||
|
||||
# some models like Pile-T5 family use BPE tokenizer instead of Unigram
|
||||
|
|
@ -11038,8 +11134,7 @@ class GptOssModel(TextModel):
|
|||
|
||||
# TODO: remove once MXFP4 is supported more generally
|
||||
def dequant_model(self):
|
||||
quant_config = self.hparams.get("quantization_config")
|
||||
if quant_config is not None and quant_config.get("quant_method") == "mxfp4":
|
||||
if self._is_mxfp4:
|
||||
return
|
||||
return super().dequant_model()
|
||||
|
||||
|
|
@ -12192,6 +12287,7 @@ class LazyTorchTensor(gguf.LazyBase):
|
|||
kwargs = {}
|
||||
|
||||
if func is torch.Tensor.numpy:
|
||||
assert len(args)
|
||||
return args[0].numpy()
|
||||
|
||||
return cls._wrap_fn(func)(*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -112,11 +112,11 @@ class Tensor:
|
|||
(n_dims, name_len, dtype) = struct.unpack('<3I', data[offset:offset + 12])
|
||||
assert n_dims >= 0 and n_dims <= 4, f'Invalid tensor dimensions {n_dims}'
|
||||
assert name_len < 4096, 'Absurd tensor name length'
|
||||
quant = gguf.GGML_QUANT_SIZES.get(dtype)
|
||||
self.dtype = gguf.GGMLQuantizationType(dtype)
|
||||
quant = gguf.GGML_QUANT_SIZES.get(self.dtype)
|
||||
assert quant is not None, 'Unknown tensor type'
|
||||
(blksize, tysize) = quant
|
||||
offset += 12
|
||||
self.dtype= gguf.GGMLQuantizationType(dtype)
|
||||
self.dims = struct.unpack(f'<{n_dims}I', data[offset:offset + (4 * n_dims)])
|
||||
offset += 4 * n_dims
|
||||
self.name = bytes(data[offset:offset + name_len])
|
||||
|
|
|
|||
|
|
@ -199,10 +199,13 @@ class LoraTorchTensor:
|
|||
kwargs = {}
|
||||
|
||||
if func is torch.permute:
|
||||
assert len(args)
|
||||
return type(args[0]).permute(*args, **kwargs)
|
||||
elif func is torch.reshape:
|
||||
assert len(args)
|
||||
return type(args[0]).reshape(*args, **kwargs)
|
||||
elif func is torch.stack:
|
||||
assert len(args)
|
||||
assert isinstance(args[0], Sequence)
|
||||
dim = kwargs.get("dim", 0)
|
||||
assert dim == 0
|
||||
|
|
@ -211,6 +214,7 @@ class LoraTorchTensor:
|
|||
torch.stack([b._lora_B for b in args[0]], dim),
|
||||
)
|
||||
elif func is torch.cat:
|
||||
assert len(args)
|
||||
assert isinstance(args[0], Sequence)
|
||||
dim = kwargs.get("dim", 0)
|
||||
assert dim == 0
|
||||
|
|
@ -362,7 +366,7 @@ if __name__ == '__main__':
|
|||
logger.error(f"Model {hparams['architectures'][0]} is not supported")
|
||||
sys.exit(1)
|
||||
|
||||
class LoraModel(model_class):
|
||||
class LoraModel(model_class): # ty: ignore[unsupported-base]
|
||||
model_arch = model_class.model_arch
|
||||
|
||||
lora_alpha: float
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ The unified auto-parser uses a pure differential, compositional approach (inspir
|
|||
**Analysis + Parser Building in Two Steps**:
|
||||
|
||||
1. `autoparser::autoparser tmpl_analysis(tmpl)` — runs all differential comparisons and populates the analysis structs
|
||||
2. `autoparser::peg_generator::generate_parser(tmpl, params, tmpl_analysis)` — uses the analysis to build a PEG parser and optional GBNF grammar
|
||||
2. `autoparser::peg_generator::generate_parser(tmpl, generation_params, tmpl_analysis)` — uses the analysis to build a PEG parser and optional GBNF grammar
|
||||
|
||||
## Data Structures
|
||||
|
||||
|
|
@ -34,7 +34,7 @@ All structs are defined in [common/chat-auto-parser.h](common/chat-auto-parser.h
|
|||
|
||||
### `analyze_tools` and its sub-structs
|
||||
|
||||
- [common/chat-auto-parser.h:176-194](common/chat-auto-parser.h#L176-L194) — `tool_format_analysis`: `mode` enum, `section_start/end`, `per_call_start/end`, JSON field names (`function_field`, `name_field`, `args_field`, `id_field`, `gen_id_field`), and format flags (`fun_name_is_key`, `tools_array_wrapped`, `uses_python_dicts`)
|
||||
- [common/chat-auto-parser.h:176-194](common/chat-auto-parser.h#L176-L194) — `tool_format_analysis`: `mode` enum, `section_start/end`, `per_call_start/end`, JSON field names (`function_field`, `name_field`, `args_field`, `id_field`, `gen_id_field`), and format flags (`fun_name_is_key`, `tools_array_wrapped`)
|
||||
- [common/chat-auto-parser.h:196-200](common/chat-auto-parser.h#L196-L200) — `tool_function_analysis`: `name_prefix`, `name_suffix`, `close` markers around function names
|
||||
- [common/chat-auto-parser.h:202-210](common/chat-auto-parser.h#L202-L210) — `tool_arguments_analysis`: `start/end` container markers, `name_prefix/suffix`, `value_prefix/suffix`, `separator`
|
||||
- [common/chat-auto-parser.h:212-217](common/chat-auto-parser.h#L212-L217) — `tool_id_analysis`: `pos` enum, `prefix`/`suffix` markers around call ID values
|
||||
|
|
@ -47,12 +47,21 @@ All structs are defined in [common/chat-auto-parser.h](common/chat-auto-parser.h
|
|||
| Value | Description |
|
||||
|-----------------|-----------------------------------------------------------------------------------|
|
||||
| `NONE` | No reasoning markers detected |
|
||||
| `TAG_BASED` | Standard tag-based: `<think>...</think>` |
|
||||
| `DELIMITER` | Delimiter-based: reasoning ends at a delimiter (e.g., `[BEGIN FINAL RESPONSE]`) |
|
||||
| `FORCED_OPEN` | Template ends with open reasoning tag when `enable_thinking=true` |
|
||||
| `FORCED_CLOSED` | `enable_thinking=false` emits both tags; `enable_thinking=true` emits only start |
|
||||
| `TAG_BASED` | Tag-based: `<think>...</think>` (start can be empty for delimiter-style formats) |
|
||||
| `TOOLS_ONLY` | Reasoning only appears in tool call responses, not plain content |
|
||||
|
||||
**Generation Prompt & Reasoning Prefill**: Computed in `common_chat_templates_apply_jinja` before invoking either the specialized handlers or the auto-parser, by rendering the template twice — once with `add_generation_prompt=false` and once with `add_generation_prompt=true` — and storing the diff suffix as `generation_params::generation_prompt`. This string is propagated into `common_chat_params::generation_prompt` and `common_chat_parser_params::generation_prompt`.
|
||||
|
||||
The generation prompt is prepended to model output before PEG parsing via `wrap_for_generation_prompt()`. The portion *before* the reasoning start marker (if any) is prepended as a literal to ensure any boilerplate added by the template is consumed. The full string is also fed to the grammar sampler via `llama_sampler_accept` (stored in `common_params_sampling::grammar_prefill`), advancing the grammar past tokens already in the prompt. It is used to determine the reasoning budget sampler's initial state — COUNTING if the prefill tokens begin with the reasoning start sequence (but don't also contain the end sequence), IDLE otherwise.
|
||||
|
||||
**`grammar_prefill`** (`common_params_sampling`): The generation prompt string tokenized and accepted by the grammar sampler at init time. Only applied when `grammar_external` is false (i.e., the grammar was not set explicitly by the user).
|
||||
|
||||
Three outcomes for reasoning-prefill handling (in `generate_parser()`):
|
||||
|
||||
1. **Start+end in generation prompt** (e.g. `<think></think>\n`): the parser sees reasoning as opened and immediately closed; whitespace-only reasoning content is discarded.
|
||||
2. **Only start in generation prompt** (e.g. `<think>\n`): the parser sees reasoning as already open.
|
||||
3. **Start marker present but not at the end** (e.g. Apriel's `<|begin_assistant|>` followed by boilerplate): the marker is a template artifact; the start literal is cleared so reasoning uses delimiter-style (end-only). For templates that ignore `add_generation_prompt` (empty diff), the rendered `data.prompt` is used as fallback — but only for non-TOOLS_ONLY modes, since in TOOLS_ONLY the start tag is model-generated and may appear in prior conversation turns.
|
||||
|
||||
**`content_mode`**: How the template wraps assistant content.
|
||||
|
||||
| Value | Description |
|
||||
|
|
@ -261,16 +270,16 @@ Text is segmentized into markers and non-marker fragments using `segmentize_mark
|
|||
|
||||
- Searches `diff.right` (output with reasoning) for the reasoning content needle
|
||||
- Uses PEG parsers to find surrounding markers:
|
||||
- If both pre/post markers found in `diff.right` → `TAG_BASED` (both tags visible in diff = no forced close)
|
||||
- If both found but post marker only in the full output B → `FORCED_CLOSED`
|
||||
- If only post marker found → `DELIMITER`
|
||||
- If both pre/post markers found in `diff.right` → `TAG_BASED`
|
||||
- If both found but post marker only in the full output B → `TAG_BASED` (template forces markers; handled via prefill)
|
||||
- If only post marker found → `TAG_BASED` (delimiter-style, empty start)
|
||||
- Sets `reasoning.start` and `reasoning.end`
|
||||
|
||||
**R2 — `compare_thinking_enabled()`**: Compares `enable_thinking=false` vs `true` with a generation prompt.
|
||||
|
||||
- Detects `FORCED_OPEN`: `enable_thinking=true` adds a non-empty marker at the end of the prompt (where model will start generating) — sets `reasoning.start`, mode = `FORCED_OPEN`
|
||||
- Detects `FORCED_CLOSED`: `enable_thinking=false` produces both start+end markers; `enable_thinking=true` produces only start marker
|
||||
- Handles the reverse case: if both start and end are still empty, looks for a single-segment diff on each side to extract both markers
|
||||
- Detects template-added reasoning markers: `enable_thinking=true` appends a non-empty marker → sets `reasoning.start`, mode = `TAG_BASED`
|
||||
- Handles the reverse case (`enable_thinking=false` appends the marker instead): extracts both start (from the preceding segment) and end markers; mode = `TAG_BASED`
|
||||
- The reasoning prefill (markers added by the template) is later extracted in `common_chat_templates_apply_jinja` and prepended to model output before parsing
|
||||
|
||||
**R3 — `compare_reasoning_scope()`**: Compares assistant message with reasoning+text-content vs reasoning+tool-calls.
|
||||
|
||||
|
|
@ -343,7 +352,7 @@ Classification logic:
|
|||
|
||||
A workaround array in `common/chat-diff-analyzer.cpp` applies post-hoc patches after analysis. Each workaround is a lambda that inspects the template source and overrides analysis results. Current workarounds:
|
||||
|
||||
1. **Old Qwen/DeepSeek thinking templates** — source contains `content.split('</think>')`: sets `reasoning.mode = FORCED_OPEN` with `<think>`/`</think>` markers if no reasoning was detected
|
||||
1. **Old Qwen/DeepSeek thinking templates** — source contains `content.split('</think>')` but not `<SPECIAL_12>`: sets `reasoning.mode = TAG_BASED` with `<think>`/`</think>` markers if no reasoning was detected
|
||||
2. **Granite 3.3** — source contains specific "Write your thoughts" text: forces `TAG_BASED` reasoning with `<think>`/`</think>` and `WRAPPED_WITH_REASONING` content with `<response>`/`</response>`
|
||||
3. **Cohere Command R+** — source contains `<|CHATBOT_TOKEN|>`: sets `ALWAYS_WRAPPED` content mode if no content start is already set
|
||||
4. **Functionary 3.1** — source contains `set has_code_interpreter`: forces `PLAIN` content, specific `per_call_start/end`, clears preserved tokens to only keep Functionary-specific markers
|
||||
|
|
@ -355,12 +364,13 @@ Each analyzer struct (`analyze_reasoning`, `analyze_content`, `analyze_tools`) i
|
|||
|
||||
#### Reasoning Parser (`analyze_reasoning::build_parser`)
|
||||
|
||||
| Mode | Parser |
|
||||
|-----------------------------------|---------------------------------------------------------------------|
|
||||
| Not extracting reasoning | `eps()` |
|
||||
| `FORCED_OPEN` or `FORCED_CLOSED` | `reasoning(until(end)) + end` — opening tag was in the prompt |
|
||||
| `TAG_BASED` or `TOOLS_ONLY` | `optional(start + reasoning(until(end)) + end)` |
|
||||
| `DELIMITER` | `optional(reasoning(until(end)) + end)` — no start marker |
|
||||
| Mode | Parser |
|
||||
|-----------------------------------------------|---------------------------------------------------------------------------|
|
||||
| Not extracting reasoning | `eps()` |
|
||||
| `TAG_BASED` or `TOOLS_ONLY` (non-empty start) | `optional(start + reasoning(until(end)) + end + space())` |
|
||||
| `TAG_BASED` or `TOOLS_ONLY` (empty start) | `optional(reasoning(until(end)) + end + space())` — delimiter-style |
|
||||
|
||||
Note: The start marker may be empty either because the analyzer detected delimiter-style reasoning, or because `generate_parser()` cleared a template artifact start marker (see Generation Prompt & Reasoning Prefill above). Whitespace-only reasoning content (e.g. from a `<think></think>` prefill) is discarded by the mapper.
|
||||
|
||||
#### Content Parser (`analyze_content::build_parser`)
|
||||
|
||||
|
|
@ -410,9 +420,7 @@ All three tool parsers return:
|
|||
reasoning + optional(content(until(trigger_marker))) + tool_calls + end()
|
||||
```
|
||||
|
||||
### Python Dict Format
|
||||
|
||||
When `format.uses_python_dicts` is true (detected when single-quoted strings appear in JSON argument context), `build_parser()` pre-registers a `json-string` rule that accepts both single-quoted and double-quoted strings. This is done before any `p.json()` call so all JSON parsing inherits the flexible rule.
|
||||
Each returned parser is wrapped by `wrap_for_generation_prompt()`, which prepends a literal for any boilerplate prefix of the generation prompt (the portion before the reasoning start marker).
|
||||
|
||||
## Mapper
|
||||
|
||||
|
|
@ -421,22 +429,22 @@ When `format.uses_python_dicts` is true (detected when single-quoted strings app
|
|||
- **Buffered arguments**: Before `tool_name` is known, argument text goes to `args_buffer`; once the name is set, the buffer is flushed to `current_tool->arguments`
|
||||
- **`args_target()`**: Returns a reference to whichever destination is currently active (buffer or tool args), eliminating branching
|
||||
- **`closing_quote_pending`**: Tracks whether a closing `"` needs to be appended when a string argument value is finalized (for schema-declared string types in tagged format)
|
||||
- **Quote normalization**: Python-style quotes (`'key': 'value'`) are converted to JSON (`"key": "value"`)
|
||||
- **Whitespace-only reasoning**: Reasoning content that consists entirely of whitespace (e.g. from a `<think></think>` prefill) is cleared so the message shows no reasoning
|
||||
- **Brace auto-closing**: At tool close, unclosed `{` braces are closed automatically
|
||||
|
||||
## Files
|
||||
|
||||
| File | Purpose |
|
||||
|-------------------------------------------|----------------------------------------------------------------------|
|
||||
| `common/chat-auto-parser.h` | All analysis structs, enums, `autoparser`, `peg_generator`, `templates_params` |
|
||||
| `common/chat-auto-parser-generator.cpp` | Parser generator: `generate_parser()` and `build_parser()` methods |
|
||||
| `common/chat-diff-analyzer.cpp` | Differential analysis implementation and workarounds |
|
||||
| `common/chat-auto-parser-helpers.h/cpp` | `calculate_diff_split()`, `segmentize_markers()`, |
|
||||
| | `compare_variants()`, string helpers |
|
||||
| `common/chat-peg-parser.h/cpp` | `common_chat_peg_builder`, `common_chat_peg_mapper`, and helpers |
|
||||
| `common/chat.cpp` | Entry point: `common_chat_templates_apply_jinja()` |
|
||||
| `tools/parser/debug-template-parser.cpp` | Debug tool for template analysis |
|
||||
| `tools/parser/template-analysis.cpp` | Template analysis tool |
|
||||
| File | Purpose |
|
||||
|-------------------------------------------|---------------------------------------------------------------------------------|
|
||||
| `common/chat-auto-parser.h` | All analysis structs, enums, `autoparser`, `peg_generator`, `generation_params` |
|
||||
| `common/chat-auto-parser-generator.cpp` | Parser generator: `generate_parser()` and `build_parser()` methods |
|
||||
| `common/chat-diff-analyzer.cpp` | Differential analysis implementation and workarounds |
|
||||
| `common/chat-auto-parser-helpers.h/cpp` | `calculate_diff_split()`, `segmentize_markers()`, `compare_variants()`, |
|
||||
| | `wrap_for_generation_prompt()`, string helpers |
|
||||
| `common/chat-peg-parser.h/cpp` | `common_chat_peg_builder`, `common_chat_peg_mapper`, and helpers |
|
||||
| `common/chat.cpp` | Entry point: `common_chat_templates_apply_jinja()` |
|
||||
| `tools/parser/debug-template-parser.cpp` | Debug tool for template analysis |
|
||||
| `tools/parser/template-analysis.cpp` | Template analysis tool |
|
||||
|
||||
## Testing & Debugging
|
||||
|
||||
|
|
@ -516,10 +524,10 @@ To support a new template format:
|
|||
|
||||
## Edge Cases and Quirks
|
||||
|
||||
1. **Forced Thinking**: When `enable_thinking=true` and the model prompt ends with an open reasoning tag (e.g., `<think>`), the parser enters forced thinking mode and immediately expects reasoning content without waiting for a start marker.
|
||||
1. **Generation Prompt & Reasoning Prefill**: The generation prompt is extracted by diffing `add_generation_prompt=false` vs `true` in `common_chat_templates_apply_jinja`, so it contains exactly what the template appends — avoiding false positives from prior conversation turns.
|
||||
2. **Per-Call vs Per-Section Markers**: Some templates wrap each tool call individually (`per_call_start/end`); others wrap the entire section (`section_start/end`). T2 (`check_per_call_markers()`) disambiguates by checking if the second call in a two-call output starts with the section marker.
|
||||
3. **Python Dict Format**: The Seed template family uses single-quoted JSON (`'key': 'value'`). The `uses_python_dicts` flag causes the PEG builder to register a flexible `json-string` rule accepting both quote styles before any JSON rules are built.
|
||||
4. **Tag Boundary Fixing**: `calculate_diff_split()` iteratively adjusts prefix/suffix boundaries to avoid splitting `<tag>` or `[marker]` tokens, ensuring clean extraction.
|
||||
5. **Call ID Side Effects**: When a call ID is detected, `per_call_end` may have been incorrectly set to include the call ID suffix. T7 clears `per_call_end` in this case.
|
||||
6. **Tool Analysis Gating**: `analyze_tools` is only constructed (and all tool analysis phases run) when `jinja_caps.supports_tool_calls` is true. Within tool analysis, `check_per_call_markers()` (T2) only runs if `jinja_caps.supports_parallel_tool_calls`.
|
||||
7. **`analyze_arguments()` Gating**: Within tool analysis, A1 and A2 (argument name/value marker extraction) only run for `TAG_WITH_TAGGED` format. `extract_argument_separator()` and `extract_args_markers()` run for all non-`JSON_NATIVE` formats.
|
||||
3. **Tag Boundary Fixing**: `calculate_diff_split()` iteratively adjusts prefix/suffix boundaries to avoid splitting `<tag>` or `[marker]` tokens, ensuring clean extraction.
|
||||
4. **Call ID Side Effects**: When a call ID is detected, `per_call_end` may have been incorrectly set to include the call ID suffix. T7 clears `per_call_end` in this case.
|
||||
5. **Tool Analysis Gating**: `analyze_tools` is only constructed (and all tool analysis phases run) when `jinja_caps.supports_tool_calls` is true. Within tool analysis, `check_per_call_markers()` (T2) only runs if `jinja_caps.supports_parallel_tool_calls`.
|
||||
6. **`analyze_arguments()` Gating**: Within tool analysis, A1 and A2 (argument name/value marker extraction) only run for `TAG_WITH_TAGGED` format. `extract_argument_separator()` and `extract_args_markers()` run for all non-`JSON_NATIVE` formats.
|
||||
7. **Undetected Tool Format**: If `analyze_tools` concludes tool calling is supported but cannot determine the format, `build_parser()` logs an error and returns `eps()` (graceful degradation) rather than aborting.
|
||||
|
|
|
|||
|
|
@ -28,6 +28,9 @@ Additionally, there the following images, similar to the above:
|
|||
- `ghcr.io/ggml-org/llama.cpp:full-vulkan`: Same as `full` but compiled with Vulkan support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:light-vulkan`: Same as `light` but compiled with Vulkan support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:server-vulkan`: Same as `server` but compiled with Vulkan support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:full-openvino`: Same as `full` but compiled with OpenVino support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:light-openvino`: Same as `light` but compiled with OpenVino support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:server-openvino`: Same as `server` but compiled with OpenVino support. (platforms: `linux/amd64`)
|
||||
|
||||
The GPU enabled images are not currently tested by CI beyond being built. They are not built with any variation from the ones in the Dockerfiles defined in [.devops/](../.devops/) and the GitHub Action defined in [.github/workflows/docker.yml](../.github/workflows/docker.yml). If you need different settings (for example, a different CUDA, ROCm or MUSA library, you'll need to build the images locally for now).
|
||||
|
||||
|
|
|
|||
74
docs/ops.md
74
docs/ops.md
|
|
@ -12,9 +12,9 @@ Legend:
|
|||
- 🟡 Partially supported by this backend
|
||||
- ❌ Not supported by this backend
|
||||
|
||||
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | WebGPU | ZenDNN | zDNN |
|
||||
| Operation | BLAS | CANN | CPU | CUDA | MTL | OpenCL | SYCL | Vulkan | WebGPU | ZenDNN | zDNN |
|
||||
|-----------|------|------|------|------|------|------|------|------|------|------|------|
|
||||
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| ABS | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
|
|
@ -23,63 +23,63 @@ Legend:
|
|||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
|
||||
| CONT | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
|
||||
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| DIAG | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| DIAG | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ELU | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
|
||||
| GELU | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
|
||||
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| OUT_PROD | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | 🟡 |
|
||||
| PAD | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| POOL_1D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| POOL_1D | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| RELU | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
|
|
@ -91,31 +91,31 @@ Legend:
|
|||
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| SET_ROWS | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SGN | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SUM | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
|
|
|
|||
32655
docs/ops/Metal.csv
32655
docs/ops/Metal.csv
File diff suppressed because it is too large
Load Diff
|
|
@ -5937,6 +5937,20 @@
|
|||
"SYCL0","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.100000","support","1","yes","SYCL"
|
||||
"SYCL0","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.100000,v=0","support","1","yes","SYCL"
|
||||
"SYCL0","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.100000,v=1","support","1","yes","SYCL"
|
||||
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=10.000000","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=10.000000,inplace=0","support","1","yes","SYCL"
|
||||
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=10.000000","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=10.000000,inplace=0","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=10.000000","support","1","yes","SYCL"
|
||||
"SYCL0","L2_NORM","type=f32,ne=[64,5,4,3],eps=10.000000,v=0","support","1","yes","SYCL"
|
||||
"SYCL0","L2_NORM","type=f32,ne=[64,5,4,3],eps=10.000000,v=1","support","1","yes","SYCL"
|
||||
"SYCL0","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=10.000000","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=10.000000,inplace=0","support","1","yes","SYCL"
|
||||
"SYCL0","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=10.000000","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=10.000000,inplace=0","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=10.000000","support","1","yes","SYCL"
|
||||
"SYCL0","L2_NORM","type=f32,ne=[1025,5,4,3],eps=10.000000,v=0","support","1","yes","SYCL"
|
||||
"SYCL0","L2_NORM","type=f32,ne=[1025,5,4,3],eps=10.000000,v=1","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=1","support","1","yes","SYCL"
|
||||
"SYCL0","SSM_CONV","type=f32,ne_a=[3,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","SYCL"
|
||||
"SYCL0","SSM_CONV","type=f32,ne_a=[6,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","SYCL"
|
||||
|
|
@ -10209,24 +10223,24 @@
|
|||
"SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=1","support","1","yes","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=nearest","support","1","yes","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=nearest","support","1","yes","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bilinear,transpose=0","support","0","no","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bilinear,transpose=1","support","0","no","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear","support","0","no","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bilinear","support","0","no","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bicubic,transpose=0","support","0","no","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bicubic,transpose=1","support","0","no","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bicubic","support","0","no","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bicubic","support","0","no","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bilinear|antialias,transpose=0","support","0","no","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bilinear|antialias,transpose=1","support","0","no","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear|antialias","support","0","no","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bilinear|antialias","support","0","no","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear|align_corners","support","0","no","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[1,4,3,2],ne_tgt=[2,8,3,2],mode=bilinear|align_corners","support","0","no","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[4,1,3,2],ne_tgt=[1,1,3,2],mode=bilinear|align_corners","support","0","no","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bicubic|align_corners","support","0","no","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[1,4,3,2],ne_tgt=[2,8,3,2],mode=bicubic|align_corners","support","0","no","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[4,1,3,2],ne_tgt=[1,1,3,2],mode=bicubic|align_corners","support","0","no","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bilinear,transpose=0","support","1","yes","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bilinear,transpose=1","support","1","yes","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear","support","1","yes","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bilinear","support","1","yes","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bicubic,transpose=0","support","1","yes","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bicubic,transpose=1","support","1","yes","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bicubic","support","1","yes","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bicubic","support","1","yes","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bilinear|antialias,transpose=0","support","1","yes","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bilinear|antialias,transpose=1","support","1","yes","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear|antialias","support","1","yes","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bilinear|antialias","support","1","yes","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear|align_corners","support","1","yes","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[1,4,3,2],ne_tgt=[2,8,3,2],mode=bilinear|align_corners","support","1","yes","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[4,1,3,2],ne_tgt=[1,1,3,2],mode=bilinear|align_corners","support","1","yes","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bicubic|align_corners","support","1","yes","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[1,4,3,2],ne_tgt=[2,8,3,2],mode=bicubic|align_corners","support","1","yes","SYCL"
|
||||
"SYCL0","UPSCALE","type=f32,ne=[4,1,3,2],ne_tgt=[1,1,3,2],mode=bicubic|align_corners","support","1","yes","SYCL"
|
||||
"SYCL0","SUM","type=f32,ne=[10,5,4,3]","support","1","yes","SYCL"
|
||||
"SYCL0","SUM","type=f32,ne=[11,5,6,3],permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","SUM","type=f32,ne=[11,5,6,3],permute=[0,3,2,1]","support","0","no","SYCL"
|
||||
|
|
@ -13325,6 +13339,262 @@
|
|||
"SYCL0","FLASH_ATTN_EXT","hsk=256,hsv=256,nh=4,nr23=[4,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","1","yes","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=256,hsv=256,nh=4,nr23=[4,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","1","yes","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=256,hsv=256,nh=4,nr23=[4,1],kv=512,nb=75,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","1","yes","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=75,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=75,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=75,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=75,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=75,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=75,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=75,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=75,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=75,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=75,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=75,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=75,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=75,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=75,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=75,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=75,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=75,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=75,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=75,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=75,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=75,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=75,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=75,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=75,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=75,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=75,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=75,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=75,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=75,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=75,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=75,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=75,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=75,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=75,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=75,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=75,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=75,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=75,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=75,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=75,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=75,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=75,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=75,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=75,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=75,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=75,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=75,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=75,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=75,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=75,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=75,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=75,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=75,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=75,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=75,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=113,nb=75,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=512,nb=75,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[1,1],kv=1024,nb=75,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[4,1],kv=512,nb=75,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=75,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=113,nb=75,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=512,nb=75,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[1,1],kv=1024,nb=75,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=320,hsv=256,nh=4,nr23=[4,1],kv=512,nb=75,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=576,hsv=512,nh=1,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=576,hsv=512,nh=1,nr23=[1,1],kv=113,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
"SYCL0","FLASH_ATTN_EXT","hsk=576,hsv=512,nh=1,nr23=[1,1],kv=113,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL"
|
||||
|
|
|
|||
|
Can't render this file because it is too large.
|
8756
docs/ops/WebGPU.csv
8756
docs/ops/WebGPU.csv
File diff suppressed because it is too large
Load Diff
|
|
@ -28,9 +28,6 @@ def _build_repetition(item_rule, min_items, max_items, separator_rule=None):
|
|||
return f'({result})?' if min_items == 0 else result
|
||||
|
||||
def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True):
|
||||
has_min = min_value != None
|
||||
has_max = max_value != None
|
||||
|
||||
def digit_range(from_char: str, to_char: str):
|
||||
out.append("[")
|
||||
if from_char == to_char:
|
||||
|
|
@ -106,7 +103,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
|
|||
out.append(to_str[i])
|
||||
out.append("]")
|
||||
|
||||
if has_min and has_max:
|
||||
if min_value is not None and max_value is not None:
|
||||
if min_value < 0 and max_value < 0:
|
||||
out.append("\"-\" (")
|
||||
_generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True)
|
||||
|
|
@ -133,7 +130,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
|
|||
|
||||
less_decimals = max(decimals_left - 1, 1)
|
||||
|
||||
if has_min:
|
||||
if min_value is not None:
|
||||
if min_value < 0:
|
||||
out.append("\"-\" (")
|
||||
_generate_min_max_int(None, -min_value, out, decimals_left, top_level=False)
|
||||
|
|
@ -177,7 +174,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
|
|||
more_digits(length - 1, less_decimals)
|
||||
return
|
||||
|
||||
if has_max:
|
||||
if max_value is not None:
|
||||
if max_value >= 0:
|
||||
if top_level:
|
||||
out.append("\"-\" [1-9] ")
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ def load_model_and_tokenizer(model_path, use_sentence_transformers=False, device
|
|||
print("Using SentenceTransformer to apply all numbered layers")
|
||||
model = SentenceTransformer(model_path)
|
||||
tokenizer = model.tokenizer
|
||||
config = model[0].auto_model.config # type: ignore
|
||||
config = model[0].auto_model.config
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
|
@ -108,8 +108,8 @@ def load_model_and_tokenizer(model_path, use_sentence_transformers=False, device
|
|||
print(f"Model file: {type(model).__module__}")
|
||||
|
||||
# Verify the model is using the correct sliding window
|
||||
if hasattr(model.config, 'sliding_window'): # type: ignore
|
||||
print(f"Model's sliding_window: {model.config.sliding_window}") # type: ignore
|
||||
if hasattr(model.config, 'sliding_window'):
|
||||
print(f"Model's sliding_window: {model.config.sliding_window}")
|
||||
else:
|
||||
print("Model config does not have sliding_window attribute")
|
||||
|
||||
|
|
@ -152,7 +152,7 @@ def main():
|
|||
device = next(model.parameters()).device
|
||||
else:
|
||||
# For SentenceTransformer, get device from the underlying model
|
||||
device = next(model[0].auto_model.parameters()).device # type: ignore
|
||||
device = next(model[0].auto_model.parameters()).device
|
||||
|
||||
model_name = os.path.basename(model_path)
|
||||
|
||||
|
|
@ -177,7 +177,7 @@ def main():
|
|||
print(f"{token_id:6d} -> '{token_str}'")
|
||||
|
||||
print(f"Embeddings shape (after all SentenceTransformer layers): {all_embeddings.shape}")
|
||||
print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}") # type: ignore
|
||||
print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}")
|
||||
else:
|
||||
# Standard approach: use base model output only
|
||||
encoded = tokenizer(
|
||||
|
|
@ -205,12 +205,12 @@ def main():
|
|||
print(f"Embedding dimension: {all_embeddings.shape[1]}")
|
||||
|
||||
if len(all_embeddings.shape) == 1:
|
||||
n_embd = all_embeddings.shape[0] # type: ignore
|
||||
n_embd = all_embeddings.shape[0]
|
||||
n_embd_count = 1
|
||||
all_embeddings = all_embeddings.reshape(1, -1)
|
||||
else:
|
||||
n_embd = all_embeddings.shape[1] # type: ignore
|
||||
n_embd_count = all_embeddings.shape[0] # type: ignore
|
||||
n_embd = all_embeddings.shape[1]
|
||||
n_embd_count = all_embeddings.shape[0]
|
||||
|
||||
print()
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
import argparse
|
||||
import sys
|
||||
from common import compare_tokens # type: ignore
|
||||
from common import compare_tokens # type: ignore[import-not-found]
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import re
|
|||
from copy import copy
|
||||
from enum import Enum
|
||||
from inspect import getdoc, isclass
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
from docstring_parser import parse
|
||||
from pydantic import BaseModel, create_model
|
||||
|
|
@ -1158,7 +1158,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
|
|||
|
||||
# Assert that the parameter has a type annotation
|
||||
if param.annotation == inspect.Parameter.empty:
|
||||
raise TypeError(f"Parameter '{param.name}' in function '{func.__name__}' lacks a type annotation")
|
||||
raise TypeError(f"""Parameter '{param.name}' in function '{getattr(func, "__name__", "")}' lacks a type annotation""")
|
||||
|
||||
# Find the parameter's description in the docstring
|
||||
param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
|
||||
|
|
@ -1166,7 +1166,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
|
|||
# Assert that the parameter has a description
|
||||
if not param_doc or not param_doc.description:
|
||||
raise ValueError(
|
||||
f"Parameter '{param.name}' in function '{func.__name__}' lacks a description in the docstring")
|
||||
f"""Parameter '{param.name}' in function '{getattr(func, "__name__", "")}' lacks a description in the docstring""")
|
||||
|
||||
# Add parameter details to the schema
|
||||
param_docs.append((param.name, param_doc))
|
||||
|
|
@ -1177,7 +1177,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
|
|||
dynamic_fields[param.name] = (
|
||||
param.annotation if param.annotation != inspect.Parameter.empty else str, default_value)
|
||||
# Creating the dynamic model
|
||||
dynamic_model = create_model(f"{func.__name__}", **dynamic_fields)
|
||||
dynamic_model = create_model(f"{getattr(func, '__name__')}", **dynamic_fields)
|
||||
|
||||
for name, param_doc in param_docs:
|
||||
dynamic_model.model_fields[name].description = param_doc.description
|
||||
|
|
@ -1285,7 +1285,7 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name:
|
|||
if items != {}:
|
||||
array = {"properties": items}
|
||||
array_type = convert_dictionary_to_pydantic_model(array, f"{model_name}_{field_name}_items")
|
||||
fields[field_name] = (List[array_type], ...)
|
||||
fields[field_name] = (list[array_type], ...) # ty: ignore[invalid-type-form]
|
||||
else:
|
||||
fields[field_name] = (list, ...)
|
||||
elif field_type == "object":
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
|
|||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 9)
|
||||
set(GGML_VERSION_PATCH 7)
|
||||
set(GGML_VERSION_PATCH 8)
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
|
|
|||
|
|
@ -733,6 +733,10 @@ extern "C" {
|
|||
GGML_API size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block
|
||||
GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
|
||||
|
||||
GGML_DEPRECATED(
|
||||
GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
|
||||
"use ggml_row_size() instead");
|
||||
|
||||
GGML_API const char * ggml_type_name(enum ggml_type type);
|
||||
GGML_API const char * ggml_op_name (enum ggml_op op);
|
||||
GGML_API const char * ggml_op_symbol(enum ggml_op op);
|
||||
|
|
|
|||
|
|
@ -121,6 +121,8 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
|
|||
bli_thread_set_num_threads(ctx->n_threads);
|
||||
#elif defined(GGML_BLAS_USE_NVPL)
|
||||
nvpl_blas_set_num_threads(ctx->n_threads);
|
||||
#elif defined(GGML_BLAS_USE_MKL)
|
||||
mkl_set_num_threads(ctx->n_threads);
|
||||
#endif
|
||||
|
||||
for (int64_t i13 = 0; i13 < ne13; i13++) {
|
||||
|
|
|
|||
|
|
@ -1544,8 +1544,8 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx,
|
|||
end = 2 * ((n_head - 1) - n_head_log2) + 1;
|
||||
step = 2;
|
||||
count = n_head - n_head_log2;
|
||||
aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), m1, count, start, end + 1, step,
|
||||
dtype);
|
||||
aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * ggml_type_size(dtype), m1, count, start, end + 1,
|
||||
step, dtype);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1788,9 +1788,11 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
|||
ggml_tensor * src0 = dst->src[0]; // src
|
||||
ggml_tensor * src1 = dst->src[1]; // index
|
||||
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16
|
||||
|| dst->type == GGML_TYPE_BF16);
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_BF16:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
if (src0->type == dst->type) {
|
||||
|
|
@ -1881,6 +1883,7 @@ void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
|||
break;
|
||||
}
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_BF16:
|
||||
{
|
||||
acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
|
||||
ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t));
|
||||
|
|
@ -1891,7 +1894,7 @@ void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
|||
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
|
||||
}
|
||||
acl_tensor_ptr src_trans_tensor = ggml_cann_create_tensor(
|
||||
src_trans_buffer, ACL_FLOAT16, ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
|
||||
src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
|
||||
aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type));
|
||||
aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1,
|
||||
dst->type);
|
||||
|
|
@ -1965,7 +1968,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor *
|
|||
|
||||
// Only check env once.
|
||||
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
|
||||
if (weight_to_nz && is_matmul_weight(weight)) {
|
||||
if (weight_to_nz && weight->type != GGML_TYPE_BF16 && is_matmul_weight(weight)) {
|
||||
acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
|
||||
} else {
|
||||
acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
|
||||
|
|
@ -2146,6 +2149,9 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
|||
switch (type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
#ifndef ASCEND_310P
|
||||
case GGML_TYPE_BF16:
|
||||
#endif
|
||||
ggml_cann_mat_mul_fp(ctx, dst);
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
|
|
@ -2943,6 +2949,27 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
|||
// Rotate full tensor (no tail), using trans tensors
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(), acl_cos_reshape_tensor.get(),
|
||||
acl_sin_reshape_tensor.get(), acl_mode, acl_dst_trans_tensor.get());
|
||||
} else if (src0->data == dst->data && !ggml_is_contiguous(src0)) {
|
||||
// In-place on non-contiguous tensor: RotaryPositionEmbedding cannot safely
|
||||
// read and write the same non-contiguous buffer. Use contiguous temporaries.
|
||||
size_t contiguous_nb[GGML_MAX_DIMS];
|
||||
contiguous_nb[0] = sizeof(float);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
contiguous_nb[i] = contiguous_nb[i - 1] * src0->ne[i - 1];
|
||||
}
|
||||
int64_t total_elements = ggml_nelements(src0);
|
||||
ggml_cann_pool_alloc inplace_src_alloc(ctx.pool(), total_elements * sizeof(float));
|
||||
ggml_cann_pool_alloc inplace_dst_alloc(ctx.pool(), total_elements * sizeof(float));
|
||||
|
||||
acl_tensor_ptr acl_src_contig = ggml_cann_create_tensor(inplace_src_alloc.get(), ACL_FLOAT, sizeof(float),
|
||||
src0->ne, contiguous_nb, GGML_MAX_DIMS);
|
||||
acl_tensor_ptr acl_dst_contig = ggml_cann_create_tensor(inplace_dst_alloc.get(), ACL_FLOAT, sizeof(float),
|
||||
dst->ne, contiguous_nb, GGML_MAX_DIMS);
|
||||
|
||||
cann_copy(ctx, acl_src.get(), acl_src_contig.get());
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_contig.get(), acl_cos_reshape_tensor.get(),
|
||||
acl_sin_reshape_tensor.get(), acl_mode, acl_dst_contig.get());
|
||||
cann_copy(ctx, acl_dst_contig.get(), acl_dst.get());
|
||||
} else {
|
||||
// Rotate full tensor (no tail), using original tensors
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(),
|
||||
|
|
@ -3599,6 +3626,44 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
|
|||
acl_k_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne, src1_bsnd_nb, GGML_MAX_DIMS);
|
||||
acl_v_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne, src2_bsnd_nb, GGML_MAX_DIMS);
|
||||
|
||||
// Step 2.5: Pad Q, K, V along head dimension if D is not a multiple of 16
|
||||
// (required by FusedInferAttentionScoreV2)
|
||||
const int64_t D = src0->ne[0];
|
||||
const int64_t D_padded = GGML_PAD(D, 16);
|
||||
const bool needs_padding = (D != D_padded);
|
||||
|
||||
ggml_cann_pool_alloc q_pad_allocator(ctx.pool());
|
||||
ggml_cann_pool_alloc k_pad_allocator(ctx.pool());
|
||||
ggml_cann_pool_alloc v_pad_allocator(ctx.pool());
|
||||
|
||||
if (needs_padding) {
|
||||
int64_t paddings[] = { 0, D_padded - D, 0, 0, 0, 0, 0, 0 };
|
||||
|
||||
auto pad_fa_tensor = [&](acl_tensor_ptr & tensor, const int64_t * bsnd_ne,
|
||||
ggml_cann_pool_alloc & allocator) {
|
||||
int64_t pad_ne[GGML_MAX_DIMS] = { D_padded, bsnd_ne[1], bsnd_ne[2], bsnd_ne[3] };
|
||||
size_t pad_nb[GGML_MAX_DIMS];
|
||||
pad_nb[0] = faElemSize;
|
||||
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
|
||||
pad_nb[i] = pad_nb[i - 1] * pad_ne[i - 1];
|
||||
}
|
||||
int64_t nelements = pad_ne[0] * pad_ne[1] * pad_ne[2] * pad_ne[3];
|
||||
void * buffer = allocator.alloc(nelements * faElemSize);
|
||||
acl_tensor_ptr padded =
|
||||
ggml_cann_create_tensor(buffer, faDataType, faElemSize, pad_ne, pad_nb, GGML_MAX_DIMS);
|
||||
aclnn_pad(ctx, tensor.get(), padded.get(), paddings);
|
||||
tensor = std::move(padded);
|
||||
};
|
||||
|
||||
pad_fa_tensor(acl_q_tensor, src0_bsnd_ne, q_pad_allocator);
|
||||
pad_fa_tensor(acl_k_tensor, src1_bsnd_ne, k_pad_allocator);
|
||||
pad_fa_tensor(acl_v_tensor, src2_bsnd_ne, v_pad_allocator);
|
||||
|
||||
src0_bsnd_ne[0] = D_padded;
|
||||
src1_bsnd_ne[0] = D_padded;
|
||||
src2_bsnd_ne[0] = D_padded;
|
||||
}
|
||||
|
||||
// Step 3: create the PSEShift tensor if needed
|
||||
// this tensor is considered as mask (f16) in the llama.cpp
|
||||
acl_tensor_ptr bcast_pse_tensor;
|
||||
|
|
@ -3688,17 +3753,16 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
|
|||
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
acl_tensor_ptr fa_dst_tensor;
|
||||
acl_tensor_ptr acl_dst_tensor;
|
||||
ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
void * out_f16_buffer = out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
|
||||
|
||||
if (dst->type == GGML_TYPE_F32 || needs_padding) {
|
||||
int64_t * out_f16_ne = src0_bsnd_ne;
|
||||
size_t out_f16_nb[GGML_MAX_DIMS];
|
||||
out_f16_nb[0] = faElemSize;
|
||||
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
|
||||
out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
|
||||
}
|
||||
int64_t out_nelements = out_f16_ne[0] * out_f16_ne[1] * out_f16_ne[2] * out_f16_ne[3];
|
||||
void * out_f16_buffer = out_f16_allocator.alloc(out_nelements * faElemSize);
|
||||
|
||||
fa_dst_tensor =
|
||||
ggml_cann_create_tensor(out_f16_buffer, faDataType, faElemSize, out_f16_ne, out_f16_nb, GGML_MAX_DIMS);
|
||||
|
|
@ -3730,8 +3794,33 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
|
|||
nullptr // softmaxLse
|
||||
);
|
||||
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
// Step 6: post-processing, permute and cast to f32
|
||||
// Step 6: post-processing — slice padded output and/or cast to f32
|
||||
if (needs_padding) {
|
||||
ggml_cann_pool_alloc sliced_f16_allocator(ctx.pool());
|
||||
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
int64_t sliced_ne[GGML_MAX_DIMS] = { D, src0_bsnd_ne[1], src0_bsnd_ne[2], src0_bsnd_ne[3] };
|
||||
size_t sliced_nb[GGML_MAX_DIMS];
|
||||
sliced_nb[0] = faElemSize;
|
||||
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
|
||||
sliced_nb[i] = sliced_nb[i - 1] * sliced_ne[i - 1];
|
||||
}
|
||||
int64_t sliced_nelements = sliced_ne[0] * sliced_ne[1] * sliced_ne[2] * sliced_ne[3];
|
||||
void * sliced_buffer = sliced_f16_allocator.alloc(sliced_nelements * faElemSize);
|
||||
acl_tensor_ptr sliced_f16_tensor = ggml_cann_create_tensor(sliced_buffer, faDataType, faElemSize,
|
||||
sliced_ne, sliced_nb, GGML_MAX_DIMS);
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(),
|
||||
(int64_t) -1, (int64_t) 0, D, (int64_t) 1, sliced_f16_tensor.get());
|
||||
|
||||
acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
|
||||
aclnn_cast(ctx, sliced_f16_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type));
|
||||
} else {
|
||||
acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(),
|
||||
(int64_t) -1, (int64_t) 0, D, (int64_t) 1, acl_dst_tensor.get());
|
||||
}
|
||||
} else if (dst->type == GGML_TYPE_F32) {
|
||||
acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
|
||||
aclnn_cast(ctx, fa_dst_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1234,7 +1234,8 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
|||
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
|
||||
if (!need_transform(tensor->type)) {
|
||||
ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
|
||||
if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
|
||||
if (weight_to_nz && tensor->type != GGML_TYPE_BF16
|
||||
&& is_matmul_weight((const ggml_tensor *) tensor)) {
|
||||
GGML_ASSERT(tensor->ne[2] == 1);
|
||||
GGML_ASSERT(tensor->ne[3] == 1);
|
||||
weight_format_to_nz(tensor, offset, ctx->device);
|
||||
|
|
@ -1443,7 +1444,8 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_t
|
|||
if (ne0 % MATRIX_ROW_PADDING != 0) {
|
||||
size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
|
||||
}
|
||||
} else if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
|
||||
} else if (weight_to_nz && tensor->type != GGML_TYPE_BF16
|
||||
&& is_matmul_weight((const ggml_tensor *) tensor)) {
|
||||
// NZ format weight are not support quantized yet.
|
||||
// If ND tensor transform to NZ, size may changed.
|
||||
int64_t shape[] = { tensor->ne[1], tensor->ne[0] };
|
||||
|
|
@ -2283,6 +2285,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
|||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
switch (op->src[0]->type) {
|
||||
#ifndef ASCEND_310P
|
||||
case GGML_TYPE_BF16:
|
||||
#endif
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
return true;
|
||||
|
|
@ -2320,6 +2325,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
|||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
#ifndef ASCEND_310P
|
||||
case GGML_TYPE_BF16:
|
||||
#endif
|
||||
case GGML_TYPE_Q8_0:
|
||||
return true;
|
||||
default:
|
||||
|
|
@ -2332,6 +2340,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
|||
switch (op->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
#ifndef ASCEND_310P
|
||||
case GGML_TYPE_BF16:
|
||||
#endif
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
|
@ -2341,20 +2352,30 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
|||
case GGML_OP_CPY:
|
||||
{
|
||||
ggml_tensor * src = op->src[0];
|
||||
#ifdef ASCEND_310P
|
||||
if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
|
||||
(src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16)) {
|
||||
// only support F32 and F16.
|
||||
// only support F32 and F16 on 310P.
|
||||
return false;
|
||||
}
|
||||
#else
|
||||
if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_BF16) ||
|
||||
(src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16 && src->type != GGML_TYPE_BF16)) {
|
||||
// only support F32, F16 and BF16.
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_CONT:
|
||||
{
|
||||
// TODO: support GGML_TYPE_BF16
|
||||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
#ifndef ASCEND_310P
|
||||
case GGML_TYPE_BF16:
|
||||
#endif
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
|
@ -2503,10 +2524,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
|||
// different head sizes of K and V are not supported yet
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] % 16 != 0) {
|
||||
// TODO: padding to support
|
||||
return false;
|
||||
}
|
||||
float logitSoftcap = 0.0f;
|
||||
memcpy(&logitSoftcap, (const float *) (op->op_params) + 2, sizeof(float));
|
||||
if (logitSoftcap != 0.0f) {
|
||||
|
|
|
|||
|
|
@ -570,24 +570,36 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
|||
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
|
||||
set(KLEIDIAI_ARCHIVE_MD5 "54049037570ab0ee0a0d126b2ba5ece1")
|
||||
|
||||
if (POLICY CMP0135)
|
||||
cmake_policy(SET CMP0135 NEW)
|
||||
set(KLEIDIAI_FETCH_ARGS
|
||||
URL ${KLEIDIAI_DOWNLOAD_URL}
|
||||
URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5}
|
||||
)
|
||||
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
|
||||
list(APPEND KLEIDIAI_FETCH_ARGS DOWNLOAD_EXTRACT_TIMESTAMP NEW)
|
||||
endif()
|
||||
|
||||
# TODO: Use FetchContent_MakeAvailable with EXCLUDE_FROM_ALL after bumping minimum CMake version to 3.28+
|
||||
# Using FetchContent_Populate instead to avoid EXCLUDE_FROM_ALL which requires CMake 3.28
|
||||
FetchContent_Declare(KleidiAI_Download
|
||||
URL ${KLEIDIAI_DOWNLOAD_URL}
|
||||
DOWNLOAD_EXTRACT_TIMESTAMP NEW
|
||||
URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})
|
||||
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.28")
|
||||
FetchContent_Declare(KleidiAI_Download
|
||||
${KLEIDIAI_FETCH_ARGS}
|
||||
EXCLUDE_FROM_ALL
|
||||
)
|
||||
|
||||
FetchContent_GetProperties(KleidiAI_Download
|
||||
SOURCE_DIR KLEIDIAI_SRC
|
||||
POPULATED KLEIDIAI_POPULATED)
|
||||
|
||||
if (NOT KLEIDIAI_POPULATED)
|
||||
FetchContent_Populate(KleidiAI_Download)
|
||||
FetchContent_MakeAvailable(KleidiAI_Download)
|
||||
FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC)
|
||||
else()
|
||||
FetchContent_Declare(KleidiAI_Download
|
||||
${KLEIDIAI_FETCH_ARGS}
|
||||
)
|
||||
|
||||
FetchContent_GetProperties(KleidiAI_Download
|
||||
SOURCE_DIR KLEIDIAI_SRC
|
||||
POPULATED KLEIDIAI_POPULATED
|
||||
)
|
||||
|
||||
if (NOT KLEIDIAI_POPULATED)
|
||||
FetchContent_Populate(KleidiAI_Download)
|
||||
FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_compile_definitions(GGML_USE_CPU_KLEIDIAI)
|
||||
|
|
|
|||
|
|
@ -666,7 +666,7 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
|
||||
float sumf = 0;
|
||||
|
||||
#if defined __ARM_NEON
|
||||
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
|
||||
const int8x16_t values = vld1q_s8(kvalues_mxfp4);
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
float32x4_t acc = vdupq_n_f32(0.0f);
|
||||
|
|
|
|||
|
|
@ -115,10 +115,10 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
|
|||
|
||||
void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
|
||||
assert(k % QK_K == 0);
|
||||
block_q8_K * y_blocks = (block_q8_K *)y;
|
||||
size_t nb = k / QK_K;
|
||||
|
||||
#if defined(__riscv_v_intrinsic)
|
||||
block_q8_K * y_blocks = (block_q8_K *)y;
|
||||
const size_t vlmax_f32m8 = __riscv_vsetvlmax_e32m8();
|
||||
|
||||
for (size_t i = 0; i < nb; i++) {
|
||||
|
|
@ -2052,6 +2052,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
|||
#endif
|
||||
}
|
||||
|
||||
#if defined __riscv_v_intrinsic
|
||||
static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
assert(nrc == 1);
|
||||
|
|
@ -2147,6 +2148,7 @@ static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t
|
|||
|
||||
*s = sumf;
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
|
|
@ -2163,6 +2165,7 @@ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
#endif
|
||||
}
|
||||
|
||||
#if defined __riscv_v_intrinsic
|
||||
static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
assert(nrc == 1);
|
||||
|
|
@ -2269,6 +2272,7 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t
|
|||
|
||||
*s = sumf;
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
|
|
@ -2285,6 +2289,7 @@ void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
#endif
|
||||
}
|
||||
|
||||
#if defined __riscv_v_intrinsic
|
||||
static const uint8_t sign_gather_indices_arr[64] = {
|
||||
0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3,
|
||||
4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7
|
||||
|
|
@ -2488,6 +2493,7 @@ static void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t
|
|||
}
|
||||
*s = 0.125f * sumf;
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
|
|
@ -2507,7 +2513,7 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
#endif
|
||||
}
|
||||
|
||||
#if defined(__riscv_v)
|
||||
#if defined(__riscv_v_intrinsic)
|
||||
static const int8_t keven_signs_q2xs[1024] = {
|
||||
1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
|
||||
|
|
@ -2542,7 +2548,6 @@ static const int8_t keven_signs_q2xs[1024] = {
|
|||
1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
|
||||
1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
|
||||
};
|
||||
#endif
|
||||
|
||||
static void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
|
|
@ -2618,6 +2623,7 @@ static void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_
|
|||
}
|
||||
*s = 0.125f * sumf;
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
|
|
@ -2634,6 +2640,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
#endif
|
||||
}
|
||||
|
||||
#if defined __riscv_v_intrinsic
|
||||
static void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
assert(nrc == 1);
|
||||
|
|
@ -2818,6 +2825,7 @@ static void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size
|
|||
}
|
||||
*s = 0.125f * sumf;
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
|
|
@ -2830,10 +2838,11 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
|
|||
break;
|
||||
}
|
||||
#else
|
||||
ggml_vec_dot_iq2_xxs_q8_K(n, s, bs, vx, bx, vy, by, nrc);
|
||||
ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined __riscv_v_intrinsic
|
||||
static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
UNUSED(nrc);
|
||||
|
|
@ -2928,6 +2937,7 @@ static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t
|
|||
}
|
||||
*s = sumf;
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
|
|
@ -2944,6 +2954,7 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
#endif
|
||||
}
|
||||
|
||||
#if defined __riscv_v_intrinsic
|
||||
static void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
assert(nrc == 1);
|
||||
|
|
@ -3036,6 +3047,7 @@ static void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size
|
|||
}
|
||||
*s = 0.25f * sumf;
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
|
|
@ -3052,6 +3064,7 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
|
|||
#endif
|
||||
}
|
||||
|
||||
#if defined __riscv_v_intrinsic
|
||||
static void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
|
|
@ -3161,6 +3174,7 @@ static void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT s, size_
|
|||
|
||||
*s = sumf;
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
|
|
@ -3177,6 +3191,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
#endif
|
||||
}
|
||||
|
||||
#if defined __riscv_v_intrinsic
|
||||
static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
|
|
@ -3190,7 +3205,6 @@ static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_
|
|||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
#if defined __riscv_v_intrinsic
|
||||
const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16);
|
||||
float sumf = 0;
|
||||
int acc[4];
|
||||
|
|
@ -3252,14 +3266,8 @@ static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_
|
|||
}
|
||||
|
||||
*s = sumf;
|
||||
|
||||
#else
|
||||
UNUSED(x);
|
||||
UNUSED(y);
|
||||
UNUSED(nb);
|
||||
ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
|
|
@ -3276,6 +3284,7 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
#endif
|
||||
}
|
||||
|
||||
#if defined __riscv_v_intrinsic
|
||||
static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
|
|
@ -3381,6 +3390,7 @@ static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t
|
|||
|
||||
*s = sumf;
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
|
|
@ -3397,6 +3407,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
#endif
|
||||
}
|
||||
|
||||
#if defined __riscv_v_intrinsic
|
||||
static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
assert(nrc == 1);
|
||||
|
|
@ -3467,6 +3478,7 @@ static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t
|
|||
|
||||
*s = sumf;
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
|
|
@ -3483,6 +3495,7 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
#endif
|
||||
}
|
||||
|
||||
#if defined __riscv_v_intrinsic
|
||||
static void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
|
|
@ -3592,6 +3605,7 @@ static void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t
|
|||
|
||||
*s = sumf;
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
|
|
@ -3604,6 +3618,6 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
break;
|
||||
}
|
||||
#else
|
||||
return ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
|
|
|||
|
|
@ -107,8 +107,7 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
|
|||
}
|
||||
#else
|
||||
UNUSED(nb);
|
||||
UNUSED(y);
|
||||
ggml_quantize_mat_q8_0_4x4_generic(x, vy, k);
|
||||
ggml_quantize_mat_q8_0_4x8_generic(x, vy, k);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
@ -203,6 +202,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
#if defined __riscv_zvfh
|
||||
void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
|
@ -222,7 +222,6 @@ void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined __riscv_v_intrinsic
|
||||
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb);
|
||||
|
|
@ -256,9 +255,6 @@ void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
|
||||
__riscv_vse32_v_f32m2(s + x * 16, sumf, 16);
|
||||
}
|
||||
return;
|
||||
#endif
|
||||
ggml_gemv_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
|
|
@ -280,7 +276,6 @@ void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined __riscv_v_intrinsic
|
||||
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
|
|
@ -392,9 +387,6 @@ void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
|
||||
__riscv_vse32_v_f32m2(s + x * 16, sumf, 16);
|
||||
}
|
||||
return;
|
||||
#endif
|
||||
ggml_gemv_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
|
|
@ -416,7 +408,6 @@ void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
|||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined __riscv_v_intrinsic
|
||||
const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16);
|
||||
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
|
|
@ -451,9 +442,6 @@ void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
|||
|
||||
__riscv_vse32_v_f32m2(s + x * 16, sumf, 16);
|
||||
}
|
||||
return;
|
||||
#endif
|
||||
ggml_gemv_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
|
|
@ -476,7 +464,6 @@ void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
UNUSED(blocklen);
|
||||
UNUSED(bs);
|
||||
|
||||
#if defined __riscv_v_intrinsic
|
||||
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb);
|
||||
|
|
@ -505,9 +492,6 @@ void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
|
||||
__riscv_vse32_v_f32m2(s + x * 16, sumf, 16);
|
||||
}
|
||||
return;
|
||||
#endif
|
||||
ggml_gemv_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
|
|
@ -679,9 +663,9 @@ void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
|
||||
} // End K-Block
|
||||
__riscv_vse32_v_f32m2(s + col_tile, v_sumf, vl);
|
||||
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
|
|
@ -909,6 +893,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
#if defined __riscv_zvfh
|
||||
void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
|
@ -929,7 +914,6 @@ void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined __riscv_v_intrinsic
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
|
|
@ -994,9 +978,6 @@ void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
__riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);
|
||||
}
|
||||
}
|
||||
return;
|
||||
#endif
|
||||
ggml_gemm_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
|
|
@ -1019,7 +1000,6 @@ void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined __riscv_v_intrinsic
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
|
|
@ -1267,9 +1247,6 @@ void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
__riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);
|
||||
}
|
||||
}
|
||||
return;
|
||||
#endif
|
||||
ggml_gemm_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
|
|
@ -1292,7 +1269,6 @@ void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
|||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined __riscv_v_intrinsic
|
||||
const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16);
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
||||
|
|
@ -1355,9 +1331,6 @@ void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
|||
__riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);
|
||||
}
|
||||
}
|
||||
return;
|
||||
#endif
|
||||
ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
|
|
@ -1380,7 +1353,6 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined __riscv_v_intrinsic
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
|
|
@ -1429,9 +1401,6 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
__riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);
|
||||
}
|
||||
}
|
||||
return;
|
||||
#endif
|
||||
ggml_gemm_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
|
|
@ -1731,3 +1700,4 @@ void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -531,7 +531,6 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
|||
|
||||
UNUSED(bs);
|
||||
|
||||
__m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
|
||||
__m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
|
||||
|
||||
// Permute mask used for easier vector processing at later stages
|
||||
|
|
@ -580,6 +579,7 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
|||
if constexpr (
|
||||
std::is_same_v<block_tx8, block_q4_0x8> ||
|
||||
std::is_same_v<block_tx8, block_iq4_nlx8>) {
|
||||
const __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
|
||||
col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask);
|
||||
} else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {
|
||||
// Load 8 E8M0 exponents and convert to float via LUT
|
||||
|
|
|
|||
|
|
@ -1461,7 +1461,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
|||
return false;
|
||||
}
|
||||
if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
|
||||
ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
|
||||
ggml_ne(op->src[1], 3) == 1) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
@ -1473,10 +1473,12 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
|||
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
|
||||
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
||||
} else {
|
||||
if (op->src[0]->type != GGML_TYPE_F16) {
|
||||
return nullptr;
|
||||
}
|
||||
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
|
||||
const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
|
||||
const bool has_kernel = slot_total > 0;
|
||||
if (has_kernel && op->src[1]->ne[1] > 1) {
|
||||
if (slot_total > 0 && op->src[1]->ne[1] > 1) {
|
||||
if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
|
||||
(op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
|
||||
return nullptr;
|
||||
|
|
|
|||
|
|
@ -3194,6 +3194,7 @@ class tinyBLAS_PPC {
|
|||
|
||||
private:
|
||||
|
||||
__attribute__((always_inline))
|
||||
inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
|
||||
vec_t vec_C[4];
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC);
|
||||
|
|
@ -3204,6 +3205,7 @@ class tinyBLAS_PPC {
|
|||
}
|
||||
}
|
||||
|
||||
__attribute__((always_inline))
|
||||
inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
|
||||
vec_t vec_C[4];
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC);
|
||||
|
|
|
|||
|
|
@ -6205,7 +6205,7 @@ static void ggml_compute_forward_im2col_f16(
|
|||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS;
|
||||
|
|
@ -6236,7 +6236,7 @@ static void ggml_compute_forward_im2col_f16(
|
|||
int ofs1 = is_2D ? nb12 : nb11;
|
||||
|
||||
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
||||
GGML_ASSERT(nb10 == sizeof(float));
|
||||
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
|
||||
|
||||
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
|
||||
{
|
||||
|
|
@ -6249,7 +6249,12 @@ static void ggml_compute_forward_im2col_f16(
|
|||
|
||||
// micro kernel
|
||||
ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
|
||||
const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
|
||||
const float * const src_data_f32 = src1->type == GGML_TYPE_F32
|
||||
? (const float *)((const char *) src1->data + in*ofs0 + iic*ofs1)
|
||||
: nullptr; // [IH, IW]
|
||||
const ggml_fp16_t * const src_data_f16 = src1->type == GGML_TYPE_F16
|
||||
? (const ggml_fp16_t *)((const char *) src1->data + in*ofs0 + iic*ofs1)
|
||||
: nullptr; // [IH, IW]
|
||||
|
||||
for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
|
||||
for (int64_t ikw = 0; ikw < KW; ikw++) {
|
||||
|
|
@ -6259,7 +6264,11 @@ static void ggml_compute_forward_im2col_f16(
|
|||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||
dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
|
||||
} else {
|
||||
dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]);
|
||||
if (src_data_f32 != nullptr) {
|
||||
dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data_f32[iih*IW + iiw]);
|
||||
} else {
|
||||
dst_data[iic*(KH*KW) + ikh*KW + ikw] = src_data_f16[iih*IW + iiw];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1365,6 +1365,7 @@ void ggml_gemv_q8_0_4x8_q8_0_generic(int n,
|
|||
}
|
||||
}
|
||||
|
||||
// Only enable these for RISC-V.
|
||||
#if defined __riscv_zvfh
|
||||
void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
|
|
@ -1568,6 +1569,7 @@ void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||
assert(nc % 16 == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
const int nb = n / QK_K;
|
||||
const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx;
|
||||
|
|
@ -2381,6 +2383,7 @@ void ggml_gemm_q8_0_4x8_q8_0_generic(int n,
|
|||
}
|
||||
}
|
||||
|
||||
// Only enable these for RISC-V.
|
||||
#if defined __riscv_zvfh
|
||||
void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
|
|||
if (err == hipSuccess) {
|
||||
// hipMemAdviseSetCoarseGrain is an optional performance hint;
|
||||
// ignore errors (e.g. hipErrorInvalidValue on some APU/iGPU configs).
|
||||
cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device);
|
||||
(void)cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device);
|
||||
(void)hipGetLastError(); // clear any error
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ static int opt_verbose = 0;
|
|||
static int opt_profile = 0;
|
||||
static int opt_hostbuf = 1; // hostbuf ON by default
|
||||
static int opt_experimental = 0;
|
||||
static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only
|
||||
|
||||
// Enable all stages by default
|
||||
static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE;
|
||||
|
|
@ -1693,7 +1694,7 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
|
|||
// Start the DSP-side service. We need to pass the queue ID to the
|
||||
// DSP in a FastRPC call; the DSP side will import the queue and start
|
||||
// listening for packets in a callback.
|
||||
err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx);
|
||||
err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx);
|
||||
if (err != 0) {
|
||||
GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err);
|
||||
throw std::runtime_error("ggml-hex: iface start failed (see log for details)");
|
||||
|
|
@ -2362,6 +2363,27 @@ static inline size_t init_cpy_req(htp_general_req * req, dspqueue_buffer * bufs,
|
|||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_cont_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
// CONT is just a contiguous copy — reuse CPY op
|
||||
req->op = HTP_OP_CPY;
|
||||
|
||||
size_t n_bufs = 0;
|
||||
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
|
||||
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_repeat_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
req->op = HTP_OP_REPEAT;
|
||||
|
||||
size_t n_bufs = 0;
|
||||
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
|
||||
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
req->op = HTP_OP_GET_ROWS;
|
||||
|
||||
|
|
@ -2449,12 +2471,33 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
|
|||
break;
|
||||
|
||||
case GGML_OP_UNARY:
|
||||
if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
|
||||
switch (ggml_get_unary_op(t)) {
|
||||
case GGML_UNARY_OP_SILU:
|
||||
req->op = HTP_OP_UNARY_SILU;
|
||||
supported = true;
|
||||
} else if (ggml_get_unary_op(t) == GGML_UNARY_OP_GELU) {
|
||||
break;
|
||||
case GGML_UNARY_OP_GELU:
|
||||
req->op = HTP_OP_UNARY_GELU;
|
||||
supported = true;
|
||||
break;
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
req->op = HTP_OP_UNARY_SIGMOID;
|
||||
supported = true;
|
||||
break;
|
||||
case GGML_UNARY_OP_NEG:
|
||||
req->op = HTP_OP_UNARY_NEG;
|
||||
supported = true;
|
||||
break;
|
||||
case GGML_UNARY_OP_EXP:
|
||||
req->op = HTP_OP_UNARY_EXP;
|
||||
supported = true;
|
||||
break;
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
req->op = HTP_OP_UNARY_SOFTPLUS;
|
||||
supported = true;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
|
||||
|
|
@ -2640,16 +2683,28 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
|||
ggml_hexagon_dispatch_op<init_sum_rows_req>(sess, node, flags);
|
||||
break;
|
||||
case GGML_OP_UNARY:
|
||||
if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) ||
|
||||
(ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) {
|
||||
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
|
||||
switch (ggml_get_unary_op(node)) {
|
||||
case GGML_UNARY_OP_NEG:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_GELU:
|
||||
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_GLU:
|
||||
if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) ||
|
||||
(ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) ||
|
||||
(ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) {
|
||||
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
|
||||
switch (ggml_get_glu_op(node)) {
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
case GGML_GLU_OP_SWIGLU_OAI:
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
|
|
@ -2676,6 +2731,14 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
|||
ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
case GGML_OP_CONT:
|
||||
ggml_hexagon_dispatch_op<init_cont_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
case GGML_OP_REPEAT:
|
||||
ggml_hexagon_dispatch_op<init_repeat_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
case GGML_OP_ARGSORT:
|
||||
ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
|
||||
break;
|
||||
|
|
@ -3006,6 +3069,39 @@ static bool ggml_hexagon_supported_cpy(const struct ggml_hexagon_session * sess,
|
|||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_cont(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
GGML_UNUSED(sess);
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
|
||||
// CONT is same-type only, supports f32 and f16
|
||||
if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
GGML_UNUSED(sess);
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
// Support f32 and f16
|
||||
if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false;
|
||||
|
||||
// src and dst must be the same type
|
||||
if (src0->type != dst->type) return false;
|
||||
|
||||
// dst dims must be multiples of src dims
|
||||
if (dst->ne[0] % src0->ne[0] != 0) return false;
|
||||
if (dst->ne[1] % src0->ne[1] != 0) return false;
|
||||
if (dst->ne[2] % src0->ne[2] != 0) return false;
|
||||
if (dst->ne[3] % src0->ne[3] != 0) return false;
|
||||
|
||||
// require contiguous tensors (no transposition)
|
||||
if (ggml_is_transposed(src0) || ggml_is_transposed(dst)) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
|
||||
auto sess = static_cast<ggml_hexagon_session *>(dev->context);
|
||||
|
||||
|
|
@ -3063,21 +3159,32 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
|||
break;
|
||||
|
||||
case GGML_OP_UNARY:
|
||||
{
|
||||
const auto unary_op = ggml_get_unary_op(op);
|
||||
if (unary_op == GGML_UNARY_OP_SILU || unary_op == GGML_UNARY_OP_GELU) {
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
case GGML_UNARY_OP_NEG:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
supp = ggml_hexagon_supported_unary(sess, op);
|
||||
break;
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_GELU:
|
||||
supp = ggml_hexagon_supported_activations(sess, op);
|
||||
}
|
||||
break;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_GLU:
|
||||
{
|
||||
const auto glu_op = ggml_get_glu_op(op);
|
||||
if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) {
|
||||
switch (ggml_get_glu_op(op)) {
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
case GGML_GLU_OP_SWIGLU_OAI:
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
supp = ggml_hexagon_supported_activations(sess, op);
|
||||
}
|
||||
break;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_ROPE:
|
||||
supp = ggml_hexagon_supported_rope(sess, op);
|
||||
break;
|
||||
|
|
@ -3098,6 +3205,14 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
|||
supp = ggml_hexagon_supported_cpy(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_CONT:
|
||||
supp = ggml_hexagon_supported_cont(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_REPEAT:
|
||||
supp = ggml_hexagon_supported_repeat(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_ARGSORT:
|
||||
supp = ggml_hexagon_supported_argsort(sess, op);
|
||||
break;
|
||||
|
|
@ -3258,6 +3373,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
|||
const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
|
||||
const char * str_etm = getenv("GGML_HEXAGON_ETM");
|
||||
const char * str_nhvx = getenv("GGML_HEXAGON_NHVX");
|
||||
const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX");
|
||||
const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
|
||||
const char * str_arch = getenv("GGML_HEXAGON_ARCH");
|
||||
|
||||
|
|
@ -3267,8 +3383,9 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
|||
opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask;
|
||||
opt_opsync = str_opsync ? atoi(str_opsync) : 0;
|
||||
opt_profile = str_profile ? atoi(str_profile) : 0;
|
||||
opt_etm = str_etm ? atoi(str_etm) : 0;
|
||||
opt_etm = str_etm ? atoi(str_etm) : 0;
|
||||
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
|
||||
opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx;
|
||||
opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev;
|
||||
|
||||
if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) {
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ add_library(${HTP_LIB} SHARED
|
|||
set-rows-ops.c
|
||||
get-rows-ops.c
|
||||
cpy-ops.c
|
||||
repeat-ops.c
|
||||
argsort-ops.c
|
||||
ssm-conv.c
|
||||
)
|
||||
|
|
@ -39,6 +40,24 @@ target_compile_definitions(${HTP_LIB} PRIVATE
|
|||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
|
||||
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
|
||||
|
||||
# HMX acceleration: available on v73+ architectures
|
||||
set(HTP_HMX_VERSIONS v73 v75 v79 v81)
|
||||
list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
|
||||
|
||||
if (_hmx_idx GREATER_EQUAL 0)
|
||||
target_sources(${HTP_LIB} PRIVATE
|
||||
hmx-matmul-ops.c
|
||||
)
|
||||
|
||||
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
|
||||
set_source_files_properties(
|
||||
hmx-matmul-ops.c
|
||||
PROPERTIES COMPILE_OPTIONS "-mhmx"
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1)
|
||||
endif()
|
||||
|
||||
build_idl(htp_iface.idl ${HTP_LIB})
|
||||
|
||||
set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON)
|
||||
|
|
|
|||
|
|
@ -175,6 +175,86 @@ static inline uint32_t dma_queue_capacity(dma_queue * q) {
|
|||
return q->capacity;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Overflow-safe DMA push: all UDMA type1 descriptor fields (roiwidth,
|
||||
// roiheight, srcstride, dststride) are 16-bit, max 65535. This helper
|
||||
// transparently handles values that exceed the 16-bit limit and submits
|
||||
// chained DMA transtions.
|
||||
//
|
||||
// Case 1 (fast path): all params fit in 16 bits -> direct dma_queue_push.
|
||||
// Case 2 (contiguous block): width == srcstride == dststride. Reshape the
|
||||
// flat transfer into a 2D descriptor with sub_width <= 65535. Produces a
|
||||
// single descriptor, preserving async DMA behavior.
|
||||
// Case 3 (stride overflow): srcstride or dststride > 65535. Issue rows
|
||||
// one at a time. The first N-1 rows are pushed+popped synchronously;
|
||||
// the last row is left async so the caller can pop it.
|
||||
// ---------------------------------------------------------------------------
|
||||
#define UDMA_MAX_FIELD_VAL 65535u
|
||||
|
||||
static inline bool dma_queue_push_chained(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t width, size_t nrows) {
|
||||
// Fast path: everything fits in 16 bits.
|
||||
if (__builtin_expect(
|
||||
width <= UDMA_MAX_FIELD_VAL &&
|
||||
nrows <= UDMA_MAX_FIELD_VAL &&
|
||||
src_stride <= UDMA_MAX_FIELD_VAL &&
|
||||
dst_stride <= UDMA_MAX_FIELD_VAL, 1)) {
|
||||
return dma_queue_push(q, dptr, dst_stride, src_stride, width, nrows);
|
||||
}
|
||||
|
||||
// Case 2: contiguous block (width == src_stride == dst_stride).
|
||||
// Reshape total bytes into sub_width * sub_nrows where sub_width <= 65535.
|
||||
if (width == src_stride && width == dst_stride) {
|
||||
size_t total = width * nrows;
|
||||
|
||||
// Pick the largest 128-byte-aligned sub_width that divides total evenly.
|
||||
size_t sub_width = UDMA_MAX_FIELD_VAL & ~(size_t)127; // 65408
|
||||
while (sub_width > 0 && total % sub_width != 0) {
|
||||
sub_width -= 128;
|
||||
}
|
||||
if (sub_width == 0) {
|
||||
// Fallback: use original width (must fit) with adjusted nrows.
|
||||
// This shouldn't happen for 128-aligned DMA sizes.
|
||||
sub_width = width;
|
||||
}
|
||||
size_t sub_nrows = total / sub_width;
|
||||
|
||||
// Handle sub_nrows > 65535 by issuing chunked descriptors.
|
||||
const uint8_t *src = (const uint8_t *)dptr.src;
|
||||
uint8_t *dst = (uint8_t *)dptr.dst;
|
||||
size_t rows_done = 0;
|
||||
while (rows_done < sub_nrows) {
|
||||
size_t chunk = sub_nrows - rows_done;
|
||||
if (chunk > UDMA_MAX_FIELD_VAL) chunk = UDMA_MAX_FIELD_VAL;
|
||||
|
||||
dma_ptr p = dma_make_ptr(dst + rows_done * sub_width, src + rows_done * sub_width);
|
||||
if (!dma_queue_push(q, p, sub_width, sub_width, sub_width, chunk))
|
||||
return false;
|
||||
|
||||
rows_done += chunk;
|
||||
// Complete all chunks without waiting except the last one, so the
|
||||
// caller's single dma_queue_pop drains the final descriptor.
|
||||
if (rows_done < sub_nrows)
|
||||
dma_queue_pop_nowait(q);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Case 3: stride overflow — fall back to row-by-row.
|
||||
{
|
||||
const uint8_t *src = (const uint8_t *)dptr.src;
|
||||
uint8_t *dst = (uint8_t *)dptr.dst;
|
||||
for (size_t r = 0; r < nrows; ++r) {
|
||||
dma_ptr p = dma_make_ptr(dst + r * dst_stride,
|
||||
src + r * src_stride);
|
||||
if (!dma_queue_push(q, p, 0, 0, width, 1))
|
||||
return false;
|
||||
if (r + 1 < nrows)
|
||||
dma_queue_pop_nowait(q);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -29,10 +29,22 @@ static inline uint64_t hex_get_pktcnt() {
|
|||
return pktcnt;
|
||||
}
|
||||
|
||||
static inline int32_t hex_is_aligned(void * addr, uint32_t align) {
|
||||
static inline size_t hmx_ceil_div(size_t num, size_t den) {
|
||||
return (num + den - 1) / den;
|
||||
}
|
||||
|
||||
static inline int32_t hex_is_aligned(const void * addr, uint32_t align) {
|
||||
return ((size_t) addr & (align - 1)) == 0;
|
||||
}
|
||||
|
||||
static inline size_t hex_align_up(size_t v, size_t align) {
|
||||
return hmx_ceil_div(v, align) * align;
|
||||
}
|
||||
|
||||
static inline size_t hex_align_down(size_t v, size_t align) {
|
||||
return (v / align) * align;
|
||||
}
|
||||
|
||||
static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
|
||||
uint32_t left_off = (size_t) addr & (chunk_size - 1);
|
||||
uint32_t right_off = left_off + n;
|
||||
|
|
@ -43,6 +55,14 @@ static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
|
|||
return m * ((n + m - 1) / m);
|
||||
}
|
||||
|
||||
static inline size_t hex_smin(size_t a, size_t b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
|
||||
static inline size_t hex_smax(size_t a, size_t b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
|
||||
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
|
||||
Q6_l2fetch_AP((void *) p, control);
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,72 @@
|
|||
// HMX operation entry-point declarations.
|
||||
// Ported from htp-ops-lib/include/dsp/ops.h (renamed, benchmark kernels removed). (https://github.com/haozixu/htp-ops-lib)
|
||||
|
||||
#ifndef HMX_OPS_H
|
||||
#define HMX_OPS_H
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#ifndef restrict
|
||||
# define restrict __restrict
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct htp_context; // forward declaration
|
||||
|
||||
typedef struct {
|
||||
float *dst;
|
||||
const float *activation;
|
||||
const __fp16 *permuted_weight;
|
||||
int m;
|
||||
int k;
|
||||
int n;
|
||||
int act_stride;
|
||||
int weight_stride;
|
||||
int dst_stride;
|
||||
int ne02;
|
||||
int ne03;
|
||||
int ne12;
|
||||
int ne13;
|
||||
size_t src0_nb2;
|
||||
size_t src0_nb3;
|
||||
size_t src1_nb2;
|
||||
size_t src1_nb3;
|
||||
size_t dst_nb2;
|
||||
size_t dst_nb3;
|
||||
} hmx_matmul_w16a32_batched_params_t;
|
||||
|
||||
// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output
|
||||
// act_stride: activation row stride in elements (= k for contiguous, or
|
||||
// nb[1]/sizeof(float) for permuted tensors like attention Q).
|
||||
// weight_stride: weight row stride in elements (= k for compact weights, or
|
||||
// nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK).
|
||||
int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx,
|
||||
float *restrict dst,
|
||||
const float *activation,
|
||||
const __fp16 *permuted_weight,
|
||||
int m, int k, int n,
|
||||
int act_stride,
|
||||
int weight_stride);
|
||||
|
||||
// Batched F16 wrapper over hmx_mat_mul_permuted_w16a32.
|
||||
// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3.
|
||||
int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx,
|
||||
const hmx_matmul_w16a32_batched_params_t *params);
|
||||
|
||||
// HMX matrix multiplication — tile-permuted quantised weights (Q4_0/Q8_0/IQ4_NL)
|
||||
int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx,
|
||||
float *restrict dst,
|
||||
const float *activation,
|
||||
const uint8_t *permuted_weight,
|
||||
int m, int k, int n,
|
||||
int weight_type);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // HMX_OPS_H
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
// Conditional fine-grained profiling macros for HMX operations.
|
||||
//
|
||||
// Define ENABLE_PROFILE_TIMERS (via compiler flag or before including this
|
||||
// header) to instrument sub-operation latencies with HAP qtimer. When the
|
||||
// macro is not defined the TIMER_* helpers expand to nothing so there is zero
|
||||
// overhead.
|
||||
//
|
||||
// Usage:
|
||||
// TIMER_DEFINE(my_phase); // declare accumulator variable
|
||||
// TIMER_START(my_phase); // snapshot start time
|
||||
// ... work ...
|
||||
// TIMER_STOP(my_phase); // accumulate elapsed ticks
|
||||
// FARF(ALWAYS, "my_phase: %lld us", TIMER_US(my_phase));
|
||||
|
||||
#ifndef HMX_PROFILE_H
|
||||
#define HMX_PROFILE_H
|
||||
|
||||
#include <HAP_perf.h>
|
||||
|
||||
// #define ENABLE_PROFILE_TIMERS
|
||||
|
||||
#if defined(ENABLE_PROFILE_TIMERS)
|
||||
# define TIMER_DEFINE(name) int64_t name##_ticks = 0
|
||||
# define TIMER_START(name) int64_t name##_t0 = HAP_perf_get_qtimer_count()
|
||||
# define TIMER_STOP(name) name##_ticks += HAP_perf_get_qtimer_count() - name##_t0
|
||||
# define TIMER_US(name) HAP_perf_qtimer_count_to_us(name##_ticks)
|
||||
#else
|
||||
# define TIMER_DEFINE(name)
|
||||
# define TIMER_START(name)
|
||||
# define TIMER_STOP(name)
|
||||
# define TIMER_US(name) 0LL
|
||||
#endif
|
||||
|
||||
#endif // HMX_PROFILE_H
|
||||
|
|
@ -0,0 +1,88 @@
|
|||
// HMX tile-level inline helpers (FP16 32x32 tile operations).
|
||||
// Ported from htp-ops-lib/include/dsp/hmx_utils.h. (https://github.com/haozixu/htp-ops-lib)
|
||||
|
||||
#ifndef HMX_UTILS_H
|
||||
#define HMX_UTILS_H
|
||||
|
||||
#include <hexagon_types.h>
|
||||
#include <stddef.h>
|
||||
|
||||
#define HMX_FP16_TILE_N_ROWS 32
|
||||
#define HMX_FP16_TILE_N_COLS 32
|
||||
#define HMX_FP16_TILE_N_ELMS 1024
|
||||
#define HMX_FP16_TILE_SIZE 2048
|
||||
|
||||
#define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline))
|
||||
|
||||
static HMX_INLINE_ALWAYS void hmx_set_output_scales(const void *scales) {
|
||||
asm volatile("bias = mxmem2(%0)" :: "r"(scales));
|
||||
}
|
||||
|
||||
// Initialise aligned 256-byte area with scale vector + zero padding.
|
||||
static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) {
|
||||
HVX_Vector *pv = (HVX_Vector *)out_scales;
|
||||
*pv++ = v_scale;
|
||||
*pv = Q6_V_vzero();
|
||||
}
|
||||
|
||||
// Load multiple contiguous tiles with :deep streaming.
|
||||
// Rt = total region size - 1; the hardware streams through [Rs, Rs + Rt].
|
||||
// IMPORTANT: the tile region [Rs, Rs + Rt] must NOT cross a VTCM 4 MB bank
|
||||
// boundary, otherwise the mxmem instruction will raise a precise bus error.
|
||||
// Callers must ensure their VTCM layout satisfies this constraint.
|
||||
static HMX_INLINE_ALWAYS void hmx_load_tiles_fp16(const __fp16 *row_tiles,
|
||||
const __fp16 *col_tiles,
|
||||
size_t n_tiles) {
|
||||
size_t limit = n_tiles * HMX_FP16_TILE_SIZE - 1;
|
||||
asm volatile(
|
||||
"{ activation.hf = mxmem(%0, %1):deep\n"
|
||||
"weight.hf = mxmem(%2, %3) }\n"
|
||||
:: "r"(row_tiles), "r"(limit), "r"(col_tiles), "r"(limit)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Load a single activation+weight tile pair (no :deep streaming).
|
||||
// Rt defines the accessible region [Rs, Rs+Rt]. Following the reference formula
|
||||
// (limit = n_tiles * HMX_FP16_TILE_SIZE - 1), for a single tile Rt = 2047.
|
||||
// The original code used Rt=0x7FFF (32 KB region); when dynamic VTCM allocation
|
||||
// places a tile near a 4 MB bank boundary, the oversized region crosses it and
|
||||
// triggers a precise bus error (0x2601). Rt=2047 confines accesses to exactly
|
||||
// one 2048-byte tile while covering all 16 HVX vectors (offsets 0..2047).
|
||||
static HMX_INLINE_ALWAYS void hmx_load_tile_pair_fp16(const __fp16 *act_tile,
|
||||
const __fp16 *wt_tile) {
|
||||
asm volatile(
|
||||
"{ activation.hf = mxmem(%0, %1)\n"
|
||||
"weight.hf = mxmem(%2, %3) }\n"
|
||||
:: "r"(act_tile), "r"(2047),
|
||||
"r"(wt_tile), "r"(2047)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
static HMX_INLINE_ALWAYS void hmx_consume_accumulator_fp16(__fp16 *out) {
|
||||
// Use the combined convert-and-store instruction (matches the reference
|
||||
// Q6_mxmem_AR_after_hf intrinsic). The previous two-instruction sequence
|
||||
// "cvt.hf = acc(2); mxmem = cvt" used an undocumented Rs=2 parameter.
|
||||
asm volatile(
|
||||
"mxmem(%0, %1):after.hf = acc\n"
|
||||
:: "r"(out), "r"(0)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Compute inner product of two vectors of tiles and store result.
|
||||
static HMX_INLINE_ALWAYS void hmx_dot_fp16(__fp16 *out,
|
||||
const __fp16 *row_tiles,
|
||||
const __fp16 *col_tiles,
|
||||
size_t n_tiles) {
|
||||
hmx_load_tiles_fp16(row_tiles, col_tiles, n_tiles);
|
||||
hmx_consume_accumulator_fp16(out);
|
||||
}
|
||||
|
||||
// --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) ---
|
||||
|
||||
static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {
|
||||
uint8_t *p = *vtcm_ptr;
|
||||
*vtcm_ptr += size;
|
||||
return p;
|
||||
}
|
||||
|
||||
#endif // HMX_UTILS_H
|
||||
|
|
@ -30,6 +30,12 @@ struct htp_context {
|
|||
atomic_bool vtcm_needs_release;
|
||||
|
||||
uint32_t opmask;
|
||||
|
||||
// HMX acceleration fields (v73+, enabled by compile-time HTP_HAS_HMX)
|
||||
#ifdef HTP_HAS_HMX
|
||||
int hmx_enabled; // Runtime flag: HMX initialisation succeeded
|
||||
size_t vtcm_scratch_size; // Usable dynamic scratch (vtcm_size minus tail reservation)
|
||||
#endif
|
||||
};
|
||||
|
||||
#endif /* HTP_CTX_H */
|
||||
|
|
|
|||
|
|
@ -32,13 +32,14 @@ enum htp_status {
|
|||
// Duplicated here because we can't include full ggml.h in the htp build.
|
||||
// We have some static_asserts in the cpp code to ensure things are in sync.
|
||||
enum htp_data_type {
|
||||
HTP_TYPE_F32 = 0,
|
||||
HTP_TYPE_F16 = 1,
|
||||
HTP_TYPE_Q4_0 = 2,
|
||||
HTP_TYPE_Q8_0 = 8,
|
||||
HTP_TYPE_I32 = 26,
|
||||
HTP_TYPE_I64 = 27,
|
||||
HTP_TYPE_MXFP4 = 39,
|
||||
HTP_TYPE_F32 = 0,
|
||||
HTP_TYPE_F16 = 1,
|
||||
HTP_TYPE_Q4_0 = 2,
|
||||
HTP_TYPE_Q8_0 = 8,
|
||||
HTP_TYPE_IQ4_NL = 20,
|
||||
HTP_TYPE_I32 = 26,
|
||||
HTP_TYPE_I64 = 27,
|
||||
HTP_TYPE_MXFP4 = 39,
|
||||
HTP_TYPE_COUNT
|
||||
};
|
||||
|
||||
|
|
@ -53,6 +54,10 @@ enum htp_op {
|
|||
HTP_OP_RMS_NORM,
|
||||
HTP_OP_UNARY_SILU,
|
||||
HTP_OP_UNARY_GELU,
|
||||
HTP_OP_UNARY_SIGMOID,
|
||||
HTP_OP_UNARY_EXP,
|
||||
HTP_OP_UNARY_NEG,
|
||||
HTP_OP_UNARY_SOFTPLUS,
|
||||
HTP_OP_GLU_SWIGLU,
|
||||
HTP_OP_GLU_SWIGLU_OAI,
|
||||
HTP_OP_GLU_GEGLU,
|
||||
|
|
@ -69,6 +74,7 @@ enum htp_op {
|
|||
HTP_OP_SQRT,
|
||||
HTP_OP_SUM_ROWS,
|
||||
HTP_OP_SSM_CONV,
|
||||
HTP_OP_REPEAT,
|
||||
INVALID
|
||||
};
|
||||
|
||||
|
|
@ -82,6 +88,8 @@ static inline size_t htp_t_block_size(uint32_t t) {
|
|||
return QK4_0;
|
||||
case HTP_TYPE_Q8_0:
|
||||
return QK8_0;
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
return QK4_NL;
|
||||
case HTP_TYPE_MXFP4:
|
||||
return QK_MXFP4;
|
||||
default:
|
||||
|
|
@ -100,6 +108,8 @@ static inline size_t htp_type_nbytes(uint32_t t) {
|
|||
return sizeof(block_q4_0);
|
||||
case HTP_TYPE_Q8_0:
|
||||
return sizeof(block_q8_0);
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
return sizeof(block_iq4_nl);
|
||||
case HTP_TYPE_MXFP4:
|
||||
return sizeof(block_mxfp4);
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -57,6 +57,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx);
|
|||
int op_set_rows(struct htp_ops_context * octx);
|
||||
int op_get_rows(struct htp_ops_context * octx);
|
||||
int op_cpy(struct htp_ops_context * octx);
|
||||
int op_repeat(struct htp_ops_context * octx);
|
||||
int op_argsort(struct htp_ops_context * octx);
|
||||
int op_ssm_conv(struct htp_ops_context * octx);
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
#include "remote.idl"
|
||||
|
||||
interface htp_iface : remote_handle64 {
|
||||
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx);
|
||||
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx);
|
||||
AEEResult stop();
|
||||
AEEResult enable_etm();
|
||||
AEEResult disable_etm();
|
||||
|
|
|
|||
|
|
@ -3,10 +3,15 @@
|
|||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <math.h>
|
||||
#include <assert.h>
|
||||
|
||||
#include "hex-utils.h"
|
||||
#include "hvx-types.h"
|
||||
|
||||
#define hvx_vmem(A) *((HVX_Vector *)(A))
|
||||
#define hvx_vmemu(A) *((HVX_UVector *)(A))
|
||||
|
||||
static inline void hvx_vec_store_u(void * restrict dst, uint32_t n, HVX_Vector v) {
|
||||
// Rotate as needed.
|
||||
v = Q6_V_vlalign_VVR(v, v, (size_t) dst);
|
||||
|
|
@ -110,11 +115,15 @@ static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) {
|
|||
return Q6_Q_and_QQ(p_exp, p_frac);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1) {
|
||||
const HVX_Vector zero = Q6_V_vzero();
|
||||
HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero);
|
||||
HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero);
|
||||
HVX_Vector v = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0)));
|
||||
return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0));
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
|
||||
HVX_Vector v = Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1));
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
// replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0)
|
||||
|
|
@ -126,6 +135,30 @@ static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
|
|||
return v;
|
||||
}
|
||||
|
||||
#if __HVX_ARCH__ >= 79
|
||||
static inline HVX_VectorPair hvx_vec_f16_to_f32_shuff(HVX_Vector v) {
|
||||
const HVX_Vector one = hvx_vec_splat_f16(1.0);
|
||||
HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(v, one);
|
||||
return Q6_W_vcombine_VV(Q6_V_hi_W(p), Q6_V_lo_W(p));
|
||||
}
|
||||
static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
|
||||
const HVX_Vector one = hvx_vec_splat_f16(1.0);
|
||||
HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(v), one);
|
||||
return Q6_W_vcombine_VV(Q6_V_hi_W(p), Q6_V_lo_W(p));
|
||||
}
|
||||
#else
|
||||
static inline HVX_VectorPair hvx_vec_f16_to_f32_shuff(HVX_Vector v) {
|
||||
const HVX_Vector one = hvx_vec_splat_f16(1.0);
|
||||
HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(v, one);
|
||||
return Q6_W_vcombine_VV(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)));
|
||||
}
|
||||
static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
|
||||
const HVX_Vector one = hvx_vec_splat_f16(1.0);
|
||||
HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(v), one);
|
||||
return Q6_W_vcombine_VV(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)));
|
||||
}
|
||||
#endif
|
||||
|
||||
/* Q6_Vsf_equals_Vw is only available on v73+.*/
|
||||
#if __HVX_ARCH__ < 73
|
||||
static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <math.h>
|
||||
|
||||
#include "hvx-base.h"
|
||||
#include "hvx-floor.h"
|
||||
|
|
@ -16,8 +17,8 @@
|
|||
#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
|
||||
#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
|
||||
#define EXP_ONE (0x3f800000) // 1.0
|
||||
#define EXP_RANGE_R (0x41a00000) // 20.0
|
||||
#define EXP_RANGE_L (0xc1a00000) // -20.0
|
||||
#define EXP_RANGE_R (0x42B16666) // 88.7
|
||||
#define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN))
|
||||
|
||||
static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
|
||||
HVX_Vector z_qf32_v;
|
||||
|
|
@ -47,12 +48,12 @@ static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
|
|||
|
||||
HVX_Vector temp_v = in_vec;
|
||||
|
||||
// Clamp inputs to (-20.0, 20.0)
|
||||
// Clamp inputs to (-88.0, 88.0) to avoid overflow/underflow
|
||||
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R));
|
||||
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
|
||||
|
||||
in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v);
|
||||
in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), temp_v);
|
||||
in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
|
||||
|
||||
epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec);
|
||||
epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v);
|
||||
|
|
@ -69,12 +70,12 @@ static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
|
|||
// normalize before every QFloat's vmpy
|
||||
x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v);
|
||||
|
||||
x_v = Q6_Vsf_equals_Vqf32(x_qf32_v);
|
||||
|
||||
// z = x * x;
|
||||
z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v);
|
||||
z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v);
|
||||
|
||||
x_v = Q6_Vsf_equals_Vqf32(x_qf32_v);
|
||||
|
||||
// y = E4 + E5 * x;
|
||||
E_const = Q6_V_vsplat_R(EXP_COEFF_5);
|
||||
y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v);
|
||||
|
|
@ -145,7 +146,7 @@ static inline HVX_Vector hvx_vec_exp_f32_guard(HVX_Vector in_vec, HVX_Vector max
|
|||
return Q6_V_vmux_QVV(pred0, inf, out);
|
||||
}
|
||||
|
||||
static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) {
|
||||
static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems, bool negate) {
|
||||
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||
int num_elems_whole = num_elems - left_over;
|
||||
|
||||
|
|
@ -162,7 +163,7 @@ static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict
|
|||
HVX_Vector vec_out = Q6_V_vzero();
|
||||
|
||||
static const float kInf = INFINITY;
|
||||
static const float kMaxExp = 88.02f; // log(INF)
|
||||
static const float kMaxExp = 88.7f;
|
||||
|
||||
const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp);
|
||||
const HVX_Vector inf = hvx_vec_splat_f32(kInf);
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
#define HVX_SIGMOID_H
|
||||
|
||||
#include "hvx-base.h"
|
||||
#include "hvx-inverse.h"
|
||||
|
||||
#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022
|
||||
#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777
|
||||
|
|
|
|||
|
|
@ -25,6 +25,10 @@
|
|||
#include "htp-ops.h"
|
||||
#include "worker-pool.h"
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
#include "hmx-ops.h"
|
||||
#endif // HTP_HAS_HMX
|
||||
|
||||
AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
|
||||
struct htp_context * ctx;
|
||||
int err = 0;
|
||||
|
|
@ -163,6 +167,9 @@ static int vtcm_acquire(struct htp_context * ctx) {
|
|||
}
|
||||
|
||||
ctx->vtcm_inuse = true;
|
||||
|
||||
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
@ -246,7 +253,7 @@ static void vtcm_free(struct htp_context * ctx) {
|
|||
static void htp_packet_callback(dspqueue_t queue, int error, void * context);
|
||||
static void htp_error_callback(dspqueue_t queue, int error, void * context);
|
||||
|
||||
AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx) {
|
||||
AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx) {
|
||||
struct htp_context * ctx = (struct htp_context *) handle;
|
||||
|
||||
if (!ctx) {
|
||||
|
|
@ -280,6 +287,21 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
|
|||
return AEE_ENOMEMORY;
|
||||
}
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
if (use_hmx) {
|
||||
ctx->vtcm_scratch_size = ctx->vtcm_size;
|
||||
ctx->hmx_enabled = 1;
|
||||
|
||||
FARF(HIGH, "HMX enabled: vtcm-scratch %zu", ctx->vtcm_scratch_size);
|
||||
} else {
|
||||
// HMX disabled: skip HMX initialisation so the
|
||||
// dispatch loop falls through to the HVX compute paths.
|
||||
ctx->hmx_enabled = 0;
|
||||
ctx->vtcm_scratch_size = ctx->vtcm_size;
|
||||
FARF(HIGH, "HMX disabled (use_hmx=0): vtcm-scratch %zu", ctx->vtcm_scratch_size);
|
||||
}
|
||||
#endif
|
||||
|
||||
qurt_sysenv_max_hthreads_t hw_threads;
|
||||
qurt_sysenv_get_max_hw_threads(&hw_threads);
|
||||
uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF;
|
||||
|
|
@ -340,6 +362,12 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
|
|||
for (int i = 0; i < ctx->n_threads; i++) {
|
||||
dma_queue_delete(ctx->dma[i]);
|
||||
}
|
||||
#ifdef HTP_HAS_HMX
|
||||
if (ctx->hmx_enabled) {
|
||||
ctx->hmx_enabled = 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
vtcm_free(ctx);
|
||||
|
||||
|
|
@ -375,8 +403,9 @@ static int send_htp_rsp(struct htp_context * c,
|
|||
struct dspqueue_buffer * bufs,
|
||||
size_t n_bufs,
|
||||
struct profile_data * prof) {
|
||||
// Prep response struct
|
||||
// Prep response struct (zero-init to clear cmp/unused union)
|
||||
struct htp_general_rsp rsp;
|
||||
memset(&rsp, 0, sizeof(rsp));
|
||||
rsp.op = op;
|
||||
rsp.status = status;
|
||||
rsp.prof_usecs = prof->usecs;
|
||||
|
|
@ -516,6 +545,39 @@ static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req,
|
|||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_repeat_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[0].fd = bufs[1].fd;
|
||||
rsp_bufs[0].ptr = bufs[1].ptr;
|
||||
rsp_bufs[0].offset = bufs[1].offset;
|
||||
rsp_bufs[0].size = bufs[1].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[1].ptr;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = op_repeat(&octx);
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
|
|
@ -1004,6 +1066,210 @@ static void proc_flash_attn_ext_req(struct htp_context * ctx,
|
|||
send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof);
|
||||
}
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
// ---------------------------------------------------------------------------
|
||||
// HMX operation wrappers — self-contained, bypass htp_ops_context / htp_spad.
|
||||
// VTCM, DMA and thread dispatch are managed inside the HMX kernels.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
static void proc_hmx_matmul_req(struct htp_context * ctx,
|
||||
struct htp_general_req * req,
|
||||
struct dspqueue_buffer * bufs,
|
||||
size_t n_bufs) {
|
||||
// HMX weight tile requires N to be 32-aligned.
|
||||
if (req->src0.ne[1] % 32 != 0) {
|
||||
proc_matmul_req(ctx, req, bufs, n_bufs);
|
||||
return;
|
||||
}
|
||||
|
||||
const bool is_batched = (req->src0.ne[2] * req->src0.ne[3] > 1 ||
|
||||
req->src1.ne[2] * req->src1.ne[3] > 1);
|
||||
|
||||
// Quantised HMX kernels only handle flat 2D matmul (host already rejects
|
||||
// batched quantised, but guard here too). F16 batched matmul is handled
|
||||
// by the dedicated wrapper in hmx-matmul-ops.c.
|
||||
if (is_batched &&
|
||||
req->src0.type != HTP_TYPE_F16) {
|
||||
proc_matmul_req(ctx, req, bufs, n_bufs);
|
||||
return;
|
||||
}
|
||||
|
||||
// HMX assumes contiguous row-major layout. Fall back for permuted
|
||||
// tensors where strides are non-monotonic (e.g. transposed KV cache).
|
||||
if (req->src0.nb[0] > req->src0.nb[1] ||
|
||||
req->src1.nb[0] > req->src1.nb[1]) {
|
||||
proc_matmul_req(ctx, req, bufs, n_bufs);
|
||||
return;
|
||||
}
|
||||
|
||||
// M alignment: when M > 32 but not 32-aligned, we split into
|
||||
// HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows).
|
||||
// When M <= 32 and not 32-aligned, fall back entirely to HVX.
|
||||
const int m_total = (int) req->src1.ne[1];
|
||||
const int m_tail = m_total % 32;
|
||||
const int m_hmx = m_total - m_tail;
|
||||
|
||||
if (m_hmx == 0) {
|
||||
proc_matmul_req(ctx, req, bufs, n_bufs);
|
||||
return;
|
||||
}
|
||||
|
||||
// HMX only supports F16, Q4_0, Q8_0, IQ4_NL weights.
|
||||
// Other types (e.g. MXFP4) fall back to HVX.
|
||||
{
|
||||
uint32_t wtype = req->src0.type;
|
||||
if (wtype != HTP_TYPE_F16 &&
|
||||
wtype != HTP_TYPE_Q4_0 &&
|
||||
wtype != HTP_TYPE_Q8_0 &&
|
||||
wtype != HTP_TYPE_IQ4_NL) {
|
||||
proc_matmul_req(ctx, req, bufs, n_bufs);
|
||||
return;
|
||||
}
|
||||
// Quantised HMX path requires K aligned to 256 (x4x2 super-block).
|
||||
// F16 HMX path requires K aligned to 32 (tile width).
|
||||
if (wtype != HTP_TYPE_F16 && req->src0.ne[0] % 256 != 0) {
|
||||
proc_matmul_req(ctx, req, bufs, n_bufs);
|
||||
return;
|
||||
}
|
||||
if (wtype == HTP_TYPE_F16 && req->src0.ne[0] % 32 != 0) {
|
||||
proc_matmul_req(ctx, req, bufs, n_bufs);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
(void) n_bufs;
|
||||
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
rsp_bufs[0].fd = bufs[2].fd;
|
||||
rsp_bufs[0].ptr = bufs[2].ptr;
|
||||
rsp_bufs[0].size = bufs[2].size;
|
||||
rsp_bufs[0].offset = bufs[2].offset;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);
|
||||
|
||||
// src0 = weights, src1 = activation, dst = output
|
||||
void * wgt = (void *) bufs[0].ptr;
|
||||
float * act = (float *) bufs[1].ptr;
|
||||
float * dst = (float *) bufs[2].ptr;
|
||||
|
||||
int k = (int) req->src0.ne[0]; // inner dimension
|
||||
int n = (int) req->src0.ne[1]; // weight columns
|
||||
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
|
||||
// --- Phase 1: HMX on the first m_hmx (32-aligned) rows ---
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
int ret = -1;
|
||||
|
||||
const int ne02 = (int) req->src0.ne[2];
|
||||
const int ne03 = (int) req->src0.ne[3];
|
||||
const int ne12 = (int) req->src1.ne[2];
|
||||
const int ne13 = (int) req->src1.ne[3];
|
||||
// Row strides in elements. For compact tensors these equal k; for
|
||||
// permuted attention views they can be larger, so pass the real stride.
|
||||
const int act_stride = (int)(req->src1.nb[1] / sizeof(float));
|
||||
const int weight_stride = (int)(req->src0.nb[1] / sizeof(__fp16));
|
||||
|
||||
switch (req->src0.type) {
|
||||
case HTP_TYPE_F16:
|
||||
if (is_batched) {
|
||||
hmx_matmul_w16a32_batched_params_t batch_params = {
|
||||
.dst = dst,
|
||||
.activation = act,
|
||||
.permuted_weight = (const __fp16 *) wgt,
|
||||
.m = m_hmx,
|
||||
.k = k,
|
||||
.n = n,
|
||||
.act_stride = act_stride,
|
||||
.weight_stride = weight_stride,
|
||||
.dst_stride = (int)(req->dst.nb[1] / sizeof(float)),
|
||||
.ne02 = ne02,
|
||||
.ne03 = ne03,
|
||||
.ne12 = ne12,
|
||||
.ne13 = ne13,
|
||||
.src0_nb2 = req->src0.nb[2],
|
||||
.src0_nb3 = req->src0.nb[3],
|
||||
.src1_nb2 = req->src1.nb[2],
|
||||
.src1_nb3 = req->src1.nb[3],
|
||||
.dst_nb2 = req->dst.nb[2],
|
||||
.dst_nb3 = req->dst.nb[3],
|
||||
};
|
||||
ret = hmx_mat_mul_permuted_w16a32_batched(ctx, &batch_params);
|
||||
} else {
|
||||
ret = hmx_mat_mul_permuted_w16a32(ctx, dst, act,
|
||||
(const __fp16 *) wgt,
|
||||
m_hmx, k, n,
|
||||
act_stride,
|
||||
weight_stride);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
ret = hmx_mat_mul_permuted_qk_0_d16a32(ctx, dst, act,
|
||||
(const uint8_t *) wgt,
|
||||
m_hmx, k, n, (int) req->src0.type);
|
||||
break;
|
||||
}
|
||||
|
||||
if (ret == 0) {
|
||||
rsp_status = HTP_STATUS_OK;
|
||||
} else {
|
||||
FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret);
|
||||
vtcm_release(ctx);
|
||||
req->flags &= ~HTP_OPFLAGS_SKIP_QUANTIZE;
|
||||
proc_matmul_req(ctx, req, bufs, n_bufs);
|
||||
return;
|
||||
}
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
// --- Phase 2: HVX on the remaining m_tail rows ---
|
||||
if (m_tail > 0 && rsp_status == HTP_STATUS_OK) {
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0; // weights: unchanged
|
||||
octx.src1 = req->src1;
|
||||
octx.src1.ne[1] = m_tail; // only tail rows
|
||||
octx.dst = req->dst;
|
||||
octx.dst.ne[1] = m_tail; // only tail rows
|
||||
// Always re-quantize tail src1: HMX Phase 1 overwrites VTCM,
|
||||
// so any previously cached quantized data (SKIP_QUANTIZE pipeline)
|
||||
// is invalid.
|
||||
octx.flags = req->flags & ~HTP_OPFLAGS_SKIP_QUANTIZE;
|
||||
octx.op = req->op;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
// Offset activation and dst pointers past the HMX-processed rows.
|
||||
// Use nb[1] (row stride in bytes) to compute the byte offset.
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.src1.data = (uint32_t)((uint8_t *) bufs[1].ptr + (size_t) m_hmx * req->src1.nb[1]);
|
||||
octx.dst.data = (uint32_t)((uint8_t *) bufs[2].ptr + (size_t) m_hmx * req->dst.nb[1]);
|
||||
|
||||
FARF(HIGH, "proc_hmx_matmul: HVX tail m_tail=%d act=%p dst=%p",
|
||||
m_tail, (void *)(uintptr_t) octx.src1.data, (void *)(uintptr_t) octx.dst.data);
|
||||
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
uint32_t hvx_ret = op_matmul(&octx);
|
||||
vtcm_release(ctx);
|
||||
if (hvx_ret != HTP_STATUS_OK) {
|
||||
FARF(ERROR, "HVX tail matmul failed (ret=%u)", hvx_ret);
|
||||
rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
}
|
||||
} else {
|
||||
rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
}
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
#endif // HTP_HAS_HMX
|
||||
|
||||
static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
struct htp_context * ctx = (struct htp_context *) context;
|
||||
|
||||
|
|
@ -1056,7 +1322,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
|||
FARF(ERROR, "Bad matmul-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_matmul_req(ctx, &req, bufs, n_bufs);
|
||||
#ifdef HTP_HAS_HMX
|
||||
if (ctx->hmx_enabled) {
|
||||
proc_hmx_matmul_req(ctx, &req, bufs, n_bufs);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
proc_matmul_req(ctx, &req, bufs, n_bufs);
|
||||
}
|
||||
break;
|
||||
|
||||
case HTP_OP_MUL_MAT_ID:
|
||||
|
|
@ -1090,6 +1363,10 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
|||
|
||||
case HTP_OP_SQR:
|
||||
case HTP_OP_SQRT:
|
||||
case HTP_OP_UNARY_NEG:
|
||||
case HTP_OP_UNARY_EXP:
|
||||
case HTP_OP_UNARY_SIGMOID:
|
||||
case HTP_OP_UNARY_SOFTPLUS:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad unary-req buffer list");
|
||||
continue;
|
||||
|
|
@ -1175,6 +1452,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
|||
proc_cpy_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_REPEAT:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad repeat-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_repeat_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_ARGSORT:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad argsort-req buffer list");
|
||||
|
|
|
|||
|
|
@ -0,0 +1,148 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_perf.h>
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
struct htp_repeat_context {
|
||||
struct htp_ops_context * octx;
|
||||
|
||||
uint32_t nr0;
|
||||
uint32_t nr1;
|
||||
uint32_t nr2;
|
||||
uint32_t nr3;
|
||||
|
||||
uint32_t nrows_per_thread;
|
||||
uint32_t total_dst_rows; // ne1 * ne2 * ne3
|
||||
|
||||
size_t type_size;
|
||||
};
|
||||
|
||||
static void repeat_job_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
const struct htp_repeat_context * rctx = (const struct htp_repeat_context *) data;
|
||||
struct htp_ops_context * octx = rctx->octx;
|
||||
const struct htp_tensor * src = &octx->src0;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
const uint32_t ne00 = src->ne[0];
|
||||
const uint32_t ne01 = src->ne[1];
|
||||
const uint32_t ne02 = src->ne[2];
|
||||
const uint32_t ne03 = src->ne[3];
|
||||
|
||||
const uint32_t nb00 = src->nb[0];
|
||||
const uint32_t nb01 = src->nb[1];
|
||||
const uint32_t nb02 = src->nb[2];
|
||||
const uint32_t nb03 = src->nb[3];
|
||||
|
||||
const uint32_t ne0 = dst->ne[0];
|
||||
const uint32_t ne1 = dst->ne[1];
|
||||
const uint32_t ne2 = dst->ne[2];
|
||||
const uint32_t ne3 = dst->ne[3];
|
||||
|
||||
const uint32_t nb0 = dst->nb[0];
|
||||
const uint32_t nb1 = dst->nb[1];
|
||||
const uint32_t nb2 = dst->nb[2];
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
const uint32_t nr0 = rctx->nr0;
|
||||
const uint32_t nr1 = rctx->nr1;
|
||||
const uint32_t nr2 = rctx->nr2;
|
||||
const uint32_t nr3 = rctx->nr3;
|
||||
|
||||
const size_t row_bytes = ne00 * rctx->type_size;
|
||||
|
||||
const uint32_t row_start = rctx->nrows_per_thread * ith;
|
||||
const uint32_t row_end = MIN(row_start + rctx->nrows_per_thread, rctx->total_dst_rows);
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) {
|
||||
// Decompose flat dst row index into (i1, i2, i3)
|
||||
const uint32_t i1 = dst_row % ne1;
|
||||
const uint32_t i2 = (dst_row / ne1) % ne2;
|
||||
const uint32_t i3 = dst_row / (ne1 * ne2);
|
||||
|
||||
// Map to source indices (tiling)
|
||||
const uint32_t k1 = i1 % ne01;
|
||||
const uint32_t k2 = i2 % ne02;
|
||||
const uint32_t k3 = i3 % ne03;
|
||||
|
||||
const uint8_t * src_row = (const uint8_t *) src->data + k1 * nb01 + k2 * nb02 + k3 * nb03;
|
||||
uint8_t * dst_base = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3;
|
||||
|
||||
// Tile along dimension 0
|
||||
for (uint32_t i0 = 0; i0 < nr0; i0++) {
|
||||
uint8_t * dst_ptr = dst_base + i0 * ne00 * nb0;
|
||||
memcpy(dst_ptr, src_row, row_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "repeat %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n",
|
||||
ith, nth, src->ne[0], src->ne[1], src->ne[2], src->ne[3],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
row_start, row_end, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
int op_repeat(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
// Validate that dst dims are multiples of src dims
|
||||
if (dst->ne[0] % src0->ne[0] != 0 ||
|
||||
dst->ne[1] % src0->ne[1] != 0 ||
|
||||
dst->ne[2] % src0->ne[2] != 0 ||
|
||||
dst->ne[3] % src0->ne[3] != 0) {
|
||||
FARF(ERROR, "repeat: dst dims must be multiples of src dims\n");
|
||||
return HTP_STATUS_INVAL_PARAMS;
|
||||
}
|
||||
|
||||
size_t type_size;
|
||||
switch (src0->type) {
|
||||
case HTP_TYPE_F32: type_size = 4; break;
|
||||
case HTP_TYPE_F16: type_size = 2; break;
|
||||
default:
|
||||
FARF(ERROR, "repeat: unsupported type %u\n", src0->type);
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
const uint32_t total_dst_rows = dst->ne[1] * dst->ne[2] * dst->ne[3];
|
||||
const uint32_t n_threads = MIN(octx->n_threads, total_dst_rows);
|
||||
|
||||
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
struct htp_repeat_context rctx = {
|
||||
.octx = octx,
|
||||
.nr0 = dst->ne[0] / src0->ne[0],
|
||||
.nr1 = dst->ne[1] / src0->ne[1],
|
||||
.nr2 = dst->ne[2] / src0->ne[2],
|
||||
.nr3 = dst->ne[3] / src0->ne[3],
|
||||
.nrows_per_thread = (total_dst_rows + n_threads - 1) / n_threads,
|
||||
.total_dst_rows = total_dst_rows,
|
||||
.type_size = type_size,
|
||||
};
|
||||
|
||||
FARF(HIGH, "repeat: (%ux%ux%ux%u) -> (%ux%ux%ux%u) nr=(%u,%u,%u,%u)\n",
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
rctx.nr0, rctx.nr1, rctx.nr2, rctx.nr3);
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, repeat_job_per_thread, &rctx, n_threads);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
|
@ -195,7 +195,7 @@ static float hvx_softmax_f32(const uint8_t * restrict src,
|
|||
const float max) {
|
||||
hvx_sub_scalar_f32(spad, src, max, num_elems);
|
||||
|
||||
hvx_exp_f32(spad, dst, num_elems, false);
|
||||
hvx_exp_f32(dst, spad, num_elems, false);
|
||||
|
||||
float sum = hvx_reduce_sum_f32(dst, num_elems);
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@
|
|||
#include <string.h>
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-exp.h"
|
||||
#include "hvx-sigmoid.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
|
|
@ -166,6 +168,75 @@ static void sqrt_f32(const float * restrict src,
|
|||
}
|
||||
}
|
||||
|
||||
static void neg_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
||||
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
||||
|
||||
hvx_scale_f32_aa(dst_local, src_local, row_elems, -1.0f);
|
||||
}
|
||||
}
|
||||
|
||||
static void exp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
||||
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
||||
|
||||
hvx_exp_f32(dst_local, src_local, row_elems, false);
|
||||
}
|
||||
}
|
||||
|
||||
static void sigmoid_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
||||
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
||||
|
||||
hvx_sigmoid_f32_aa(dst_local, src_local, row_elems);
|
||||
}
|
||||
}
|
||||
|
||||
static void softplus_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
// softplus(x) = log(1 + exp(x))
|
||||
// Match CPU reference: ggml_compute_softplus_f32() in ggml-impl.h
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_f = (const float *)((const uint8_t *)src + (ir * row_size));
|
||||
float * restrict dst_f = (float *)((uint8_t *)dst + (ir * row_size));
|
||||
|
||||
for (uint32_t i = 0; i < row_elems; i++) {
|
||||
float x = src_f[i];
|
||||
// For x > 20: softplus(x) ≈ x (avoids exp overflow)
|
||||
dst_f[i] = (x > 20.0f) ? x : logf(1.0f + expf(x));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
|
||||
struct htp_ops_context * octx = uctx->octx;
|
||||
|
|
@ -247,6 +318,18 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
|||
case HTP_OP_SQRT:
|
||||
sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_UNARY_NEG:
|
||||
neg_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_UNARY_EXP:
|
||||
exp_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_UNARY_SIGMOID:
|
||||
sigmoid_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_UNARY_SOFTPLUS:
|
||||
softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
@ -295,6 +378,18 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
|||
case HTP_OP_SQRT:
|
||||
op_type = "sqrt-f32";
|
||||
break;
|
||||
case HTP_OP_UNARY_NEG:
|
||||
op_type = "neg-f32";
|
||||
break;
|
||||
case HTP_OP_UNARY_EXP:
|
||||
op_type = "exp-f32";
|
||||
break;
|
||||
case HTP_OP_UNARY_SIGMOID:
|
||||
op_type = "sigmoid-f32";
|
||||
break;
|
||||
case HTP_OP_UNARY_SOFTPLUS:
|
||||
op_type = "softplus-f32";
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
|
||||
|
|
|
|||
|
|
@ -53,9 +53,6 @@ endif()
|
|||
|
||||
message(STATUS "HIP and hipBLAS found")
|
||||
|
||||
# Workaround old compilers
|
||||
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} --gpu-max-threads-per-block=1024")
|
||||
|
||||
file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
|
||||
list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
|
||||
|
||||
|
|
@ -132,6 +129,11 @@ endif()
|
|||
|
||||
if (CXX_IS_HIPCC)
|
||||
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
|
||||
if (WIN32 AND CMAKE_BUILD_TYPE STREQUAL "Debug")
|
||||
# CMake on Windows doesn't support the HIP language yet.
|
||||
# Therefore we workaround debug build's failure on HIP backend this way.
|
||||
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES COMPILE_FLAGS "-O2 -g")
|
||||
endif()
|
||||
target_link_libraries(ggml-hip PRIVATE hip::device)
|
||||
else()
|
||||
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE HIP)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@
|
|||
#include "dmmv.hpp"
|
||||
#include "element_wise.hpp"
|
||||
#include "fattn.hpp"
|
||||
#include "gated_delta_net.hpp"
|
||||
#include "gla.hpp"
|
||||
#include "im2col.hpp"
|
||||
#include "mmq.hpp"
|
||||
|
|
@ -31,6 +32,7 @@
|
|||
#include "norm.hpp"
|
||||
#include "outprod.hpp"
|
||||
#include "pad.hpp"
|
||||
#include "pad_reflect_1d.hpp"
|
||||
#include "quantize.hpp"
|
||||
#include "quants.hpp"
|
||||
#include "roll.hpp"
|
||||
|
|
@ -39,8 +41,8 @@
|
|||
#include "ssm_conv.hpp"
|
||||
#include "softmax.hpp"
|
||||
#include "tsembd.hpp"
|
||||
#include "upscale.hpp"
|
||||
#include "wkv.hpp"
|
||||
#include "pad_reflect_1d.hpp"
|
||||
|
||||
|
||||
#endif // GGML_SYCL_BACKEND_HPP
|
||||
|
|
|
|||
|
|
@ -294,30 +294,6 @@ static void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl:
|
|||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void upscale(const T *x, T *dst, const int nb00, const int nb01,
|
||||
const int nb02, const int nb03, const int ne10, const int ne11,
|
||||
const int ne12, const int ne13, const float sf0, const float sf1,
|
||||
const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
|
||||
int index = item_ct1.get_local_id(0) +
|
||||
item_ct1.get_group(0) * item_ct1.get_local_range(0);
|
||||
if (index >= ne10 * ne11 * ne12 * ne13) {
|
||||
return;
|
||||
}
|
||||
// operation
|
||||
int i10 = index % ne10;
|
||||
int i11 = (index / ne10) % ne11;
|
||||
int i12 = (index / (ne10 * ne11)) % ne12;
|
||||
int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
|
||||
|
||||
int i00 = static_cast<int>(i10 / sf0);
|
||||
int i01 = static_cast<int>(i11 / sf1);
|
||||
int i02 = static_cast<int>(i12 / sf2);
|
||||
int i03 = static_cast<int>(i13 / sf3);
|
||||
|
||||
dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void clamp(const T * x, T * dst, const float min, const float max, const int k,
|
||||
const sycl::nd_item<1> &item_ct1) {
|
||||
|
|
@ -392,20 +368,6 @@ static void arange_kernel(T * dst, const int k, T start, T step,
|
|||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
|
||||
const int nb02, const int nb03, const int ne10, const int ne11,
|
||||
const int ne12, const int ne13, const float sf0, const float sf1,
|
||||
const float sf2, const float sf3, queue_ptr stream) {
|
||||
int dst_size = ne10 * ne11 * ne12 * ne13;
|
||||
int num_blocks = ceil_div(dst_size, SYCL_UPSCALE_BLOCK_SIZE);
|
||||
sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
|
|
@ -505,42 +467,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c
|
|||
}
|
||||
}
|
||||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
|
||||
const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0];
|
||||
const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1];
|
||||
const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];
|
||||
const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];
|
||||
switch (dst->type) {
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
auto data_pts = cast_data<sycl::half>(dst);
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],
|
||||
(int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,
|
||||
main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
auto data_pts = cast_data<float>(dst);
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],
|
||||
(int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,
|
||||
main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
}
|
||||
}
|
||||
|
||||
template<typename F>
|
||||
static inline void ggml_sycl_op_unary(
|
||||
ggml_backend_sycl_context & ctx, ggml_tensor * dst, F func) {
|
||||
|
|
@ -784,15 +710,6 @@ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor
|
|||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_upscale(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int nb00, int nb01, int nb02, int nb03,
|
||||
int ne10, int ne11, int ne12, int ne13, float sf0, float sf1, float sf2, float sf3,
|
||||
queue_ptr stream) {
|
||||
ggml_sycl_detail::upscale_sycl(src, dst_ptr, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, stream);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
float min_val;
|
||||
float max_val;
|
||||
|
|
@ -1131,12 +1048,6 @@ void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|||
ggml_sycl_op_sqr(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_upscale(ctx, dst);
|
||||
}
|
||||
|
||||
|
||||
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_clamp(ctx, dst);
|
||||
|
|
|
|||
|
|
@ -71,8 +71,6 @@ void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
|||
|
||||
void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
|
|
|||
|
|
@ -44,7 +44,6 @@
|
|||
#include "ggml-sycl/backend.hpp"
|
||||
#include "ggml-sycl/common.hpp"
|
||||
#include "ggml-sycl/element_wise.hpp"
|
||||
#include "ggml-sycl/gated_delta_net.hpp"
|
||||
#include "ggml-sycl/gemm.hpp"
|
||||
#include "ggml-sycl/getrows.hpp"
|
||||
#include "ggml-sycl/norm.hpp"
|
||||
|
|
@ -4863,9 +4862,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ROPE_BACK:
|
||||
case GGML_OP_IM2COL:
|
||||
return true;
|
||||
case GGML_OP_UPSCALE:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
|
||||
return true;
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,410 @@
|
|||
#include "upscale.hpp"
|
||||
|
||||
static void upscale_f32(const float * x, float * dst,
|
||||
const int nb00, const int nb01, const int nb02, const int nb03,
|
||||
const int ne10, const int ne11, const int ne12, const int ne13,
|
||||
const float sf0, const float sf1, const float sf2, const float sf3) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
int index = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
||||
if (index >= ne10 * ne11 * ne12 * ne13) {
|
||||
return;
|
||||
}
|
||||
|
||||
int i10 = index % ne10;
|
||||
int i11 = (index / ne10) % ne11;
|
||||
int i12 = (index / (ne10 * ne11)) % ne12;
|
||||
int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
|
||||
|
||||
int i00 = i10 / sf0;
|
||||
int i01 = i11 / sf1;
|
||||
int i02 = i12 / sf2;
|
||||
int i03 = i13 / sf3;
|
||||
|
||||
dst[index] = *((const float*)((const char*)x + i03 * nb03 + i02 * nb02 +
|
||||
i01 * nb01 + i00 * nb00));
|
||||
}
|
||||
|
||||
static void upscale_f32_bilinear(const float * x, float * dst,
|
||||
const int nb00, const int nb01, const int nb02, const int nb03,
|
||||
const int ne00_src, const int ne01_src,
|
||||
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
|
||||
const float sf0, const float sf1, const float sf2, const float sf3,
|
||||
const float pixel_offset) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int64_t index = item_ct1.get_local_id(2) +
|
||||
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
||||
const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
|
||||
|
||||
if (index >= dst_total_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i10_dst = index % ne10_dst;
|
||||
const int i11_dst = (index / ne10_dst) % ne11_dst;
|
||||
const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
|
||||
const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
|
||||
|
||||
const int i02_src = (int)(i12_dst / sf2);
|
||||
const int i03_src = (int)(i13_dst / sf3);
|
||||
|
||||
const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
|
||||
int y0_src = (int) sycl::floor((float) y_src_f);
|
||||
int y1_src = y0_src + 1;
|
||||
|
||||
y0_src = sycl::max(0, sycl::min(y0_src, ne01_src - 1));
|
||||
y1_src = sycl::max(0, sycl::min(y1_src, ne01_src - 1));
|
||||
|
||||
float dy = y_src_f - (float)y0_src;
|
||||
dy = sycl::max(0.0f, sycl::min(dy, 1.0f));
|
||||
|
||||
float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
|
||||
int x0_src = (int) sycl::floor(x_src_f);
|
||||
int x1_src = x0_src + 1;
|
||||
|
||||
x0_src = sycl::max(0, sycl::min(x0_src, ne00_src - 1));
|
||||
x1_src = sycl::max(0, sycl::min(x1_src, ne00_src - 1));
|
||||
|
||||
float dx = x_src_f - (float)x0_src;
|
||||
dx = sycl::max(0.0f, sycl::min(dx, 1.0f));
|
||||
|
||||
const float* p_a =
|
||||
(const float*)((const char*)x + (int64_t)x0_src * nb00 +
|
||||
(int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 +
|
||||
(int64_t)i03_src * nb03);
|
||||
const float* p_b =
|
||||
(const float*)((const char*)x + (int64_t)x1_src * nb00 +
|
||||
(int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 +
|
||||
(int64_t)i03_src * nb03);
|
||||
const float* p_c =
|
||||
(const float*)((const char*)x + (int64_t)x0_src * nb00 +
|
||||
(int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 +
|
||||
(int64_t)i03_src * nb03);
|
||||
const float* p_d =
|
||||
(const float*)((const char*)x + (int64_t)x1_src * nb00 +
|
||||
(int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 +
|
||||
(int64_t)i03_src * nb03);
|
||||
|
||||
const float val_a = *p_a;
|
||||
const float val_b = *p_b;
|
||||
const float val_c = *p_c;
|
||||
const float val_d = *p_d;
|
||||
|
||||
float result = val_a * (1.0f - dx) * (1.0f - dy) +
|
||||
val_b * dx * (1.0f - dy) +
|
||||
val_c * (1.0f - dx) * dy +
|
||||
val_d * dx * dy;
|
||||
|
||||
dst[index] = result;
|
||||
}
|
||||
|
||||
// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
|
||||
// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
|
||||
static void upscale_f32_bilinear_antialias(const float * src0,
|
||||
float * dst,
|
||||
const int nb00,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
const int ne00_src,
|
||||
const int ne01_src,
|
||||
const int ne10_dst,
|
||||
const int ne11_dst,
|
||||
const int ne12_dst,
|
||||
const int ne13_dst,
|
||||
const float sf0,
|
||||
const float sf1,
|
||||
const float sf2,
|
||||
const float sf3,
|
||||
const float pixel_offset) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int64_t index = item_ct1.get_local_id(2) +
|
||||
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
||||
const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
|
||||
|
||||
if (index >= dst_total_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i10_dst = index % ne10_dst;
|
||||
const int i11_dst = (index / ne10_dst) % ne11_dst;
|
||||
const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
|
||||
const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
|
||||
|
||||
const int i02_src = (int)(i12_dst / sf2);
|
||||
const int i03_src = (int)(i13_dst / sf3);
|
||||
|
||||
const float y = ((float)i11_dst + pixel_offset) / sf1;
|
||||
const float x = ((float)i10_dst + pixel_offset) / sf0;
|
||||
|
||||
// support and invscale, minimum 1 pixel for bilinear
|
||||
const float support1 = sycl::max(1.0f / sf1, 1.0f);
|
||||
const float invscale1 = 1.0f / support1;
|
||||
const float support0 = sycl::max(1.0f / sf0, 1.0f);
|
||||
const float invscale0 = 1.0f / support0;
|
||||
|
||||
// the range of source pixels that contribute
|
||||
const int64_t x_min = sycl::max(int64_t(0), int64_t(x - support0 + pixel_offset));
|
||||
const int64_t x_max = sycl::min(int64_t(ne00_src), int64_t(x + support0 + pixel_offset));
|
||||
const int64_t y_min = sycl::max(int64_t(0), int64_t(y - support1 + pixel_offset));
|
||||
const int64_t y_max = sycl::min(int64_t(ne01_src), int64_t(y + support1 + pixel_offset));
|
||||
|
||||
// bilinear filter with antialiasing
|
||||
float val = 0.0f;
|
||||
float total_weight = 0.0f;
|
||||
|
||||
auto triangle_filter = [](float x) -> float {
|
||||
return sycl::max(1.0f - sycl::fabs(x), 0.0f);
|
||||
};
|
||||
|
||||
for (int64_t sy = y_min; sy < y_max; sy++) {
|
||||
const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
|
||||
|
||||
for (int64_t sx = x_min; sx < x_max; sx++) {
|
||||
const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
|
||||
const float weight = weight_x * weight_y;
|
||||
|
||||
if (weight <= 0.0f) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const float pixel =
|
||||
*(const float*)((const char*)src0 + sx * nb00 + sy * nb01 +
|
||||
i02_src * nb02 + i03_src * nb03);
|
||||
val += pixel * weight;
|
||||
total_weight += weight;
|
||||
}
|
||||
}
|
||||
|
||||
if (total_weight > 0.0f) {
|
||||
val /= total_weight;
|
||||
}
|
||||
|
||||
dst[index] = val;
|
||||
}
|
||||
|
||||
namespace bicubic_interpolation {
|
||||
static float weight1(float x, const float &a) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
|
||||
static float weight2(float x, const float &a) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
|
||||
|
||||
static float bicubic(float p0, float p1, float p2, float p3, float x, float a) {
|
||||
const float w0 = weight2(x + 1, a);
|
||||
const float w1 = weight1(x + 0, a);
|
||||
const float w2 = weight1(1 - x, a);
|
||||
const float w3 = weight2(2 - x, a);
|
||||
return p0 * w0 + p1 * w1 + p2 * w2 + p3 * w3;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
static void upscale_f32_bicubic(const float * x, float * dst,
|
||||
const int nb00, const int nb01, const int nb02, const int nb03,
|
||||
const int ne00_src, const int ne01_src,
|
||||
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
|
||||
const float sf0, const float sf1, const float sf2, const float sf3,
|
||||
const float pixel_offset) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const float a = -0.75f;
|
||||
using bicubic_interpolation::bicubic;
|
||||
|
||||
const int64_t index = item_ct1.get_local_id(2) +
|
||||
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
||||
const int64_t dst_total_elements =
|
||||
ne10_dst * ne11_dst * ne12_dst * ne13_dst;
|
||||
|
||||
if (index >= dst_total_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i10_dst = index % ne10_dst;
|
||||
const int i11_dst = (index / ne10_dst) % ne11_dst;
|
||||
const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
|
||||
const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
|
||||
|
||||
const int i02_src = (int)(i12_dst / sf2);
|
||||
const int i03_src = (int)(i13_dst / sf3);
|
||||
|
||||
const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
|
||||
const int y0_src = (int) sycl::floor((float) y_src_f);
|
||||
const float dy = y_src_f - (float)y0_src;
|
||||
|
||||
const float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
|
||||
const int x0_src = (int) sycl::floor((float) x_src_f);
|
||||
const float dx = x_src_f - (float)x0_src;
|
||||
|
||||
const char * x_base = (const char *)x + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03;
|
||||
|
||||
auto load = [=](int x_off, int y_off) -> float {
|
||||
int i00_src = sycl::max(0, sycl::min(x0_src + x_off, ne00_src - 1));
|
||||
int i01_src = sycl::max(0, sycl::min(y0_src + y_off, ne01_src - 1));
|
||||
return *(const float *)(x_base + (int64_t)i00_src * nb00 + (int64_t)i01_src * nb01);
|
||||
};
|
||||
|
||||
const float result = bicubic(
|
||||
bicubic(load(-1, -1), load(0, -1), load(1, -1), load(2, -1), dx, a),
|
||||
bicubic(load(-1, 0), load(0, 0), load(1, 0), load(2, 0), dx, a),
|
||||
bicubic(load(-1, 1), load(0, 1), load(1, 1), load(2, 1), dx, a),
|
||||
bicubic(load(-1, 2), load(0, 2), load(1, 2), load(2, 2), dx, a),
|
||||
dy,
|
||||
a);
|
||||
|
||||
dst[index] = result;
|
||||
}
|
||||
|
||||
static void upscale_f32_sycl(const float * x,
|
||||
float * dst,
|
||||
const int nb00,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
const int ne10,
|
||||
const int ne11,
|
||||
const int ne12,
|
||||
const int ne13,
|
||||
const float sf0,
|
||||
const float sf1,
|
||||
const float sf2,
|
||||
const float sf3,
|
||||
dpct::queue_ptr stream) {
|
||||
const int64_t dst_size = ne10 * ne11 * ne12 * ne13;
|
||||
const int64_t num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
|
||||
});
|
||||
}
|
||||
|
||||
static void upscale_f32_bilinear_sycl(const float * x,
|
||||
float * dst,
|
||||
const int nb00,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
const int ne00_src,
|
||||
const int ne01_src,
|
||||
const int ne10_dst,
|
||||
const int ne11_dst,
|
||||
const int ne12_dst,
|
||||
const int ne13_dst,
|
||||
const float sf0,
|
||||
const float sf1,
|
||||
const float sf2,
|
||||
const float sf3,
|
||||
const float pixel_offset,
|
||||
bool antialias,
|
||||
dpct::queue_ptr stream) {
|
||||
const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
|
||||
const int64_t num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
|
||||
|
||||
if (antialias) {
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
upscale_f32_bilinear_antialias(
|
||||
x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst,
|
||||
ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
||||
});
|
||||
} else {
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
upscale_f32_bilinear(
|
||||
x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst,
|
||||
ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
static void upscale_f32_bicubic_sycl(const float * x,
|
||||
float * dst,
|
||||
const int nb00,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
const int ne00_src,
|
||||
const int ne01_src,
|
||||
const int ne10_dst,
|
||||
const int ne11_dst,
|
||||
const int ne12_dst,
|
||||
const int ne13_dst,
|
||||
const float sf0,
|
||||
const float sf1,
|
||||
const float sf2,
|
||||
const float sf3,
|
||||
const float pixel_offset,
|
||||
dpct::queue_ptr stream) {
|
||||
const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
|
||||
const int64_t num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
|
||||
|
||||
{
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
upscale_f32_bicubic(
|
||||
x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst,
|
||||
ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
float * dst_d = (float *)dst->data;
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int mode_flags = dst->op_params[0];
|
||||
const ggml_scale_mode mode = (ggml_scale_mode)(mode_flags & 0xFF);
|
||||
|
||||
float sf0 = (float)dst->ne[0]/src0->ne[0];
|
||||
float sf1 = (float)dst->ne[1]/src0->ne[1];
|
||||
float sf2 = (float)dst->ne[2]/src0->ne[2];
|
||||
const float sf3 = (float)dst->ne[3]/src0->ne[3];
|
||||
|
||||
float pixel_offset = 0.5f;
|
||||
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
||||
sf0 = dst->ne[0] > 1 && src0->ne[0] > 1
|
||||
? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1)
|
||||
: sf0;
|
||||
sf1 = dst->ne[1] > 1 && src0->ne[1] > 1
|
||||
? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1)
|
||||
: sf1;
|
||||
pixel_offset = 0.0f;
|
||||
}
|
||||
|
||||
if (mode == GGML_SCALE_MODE_NEAREST) {
|
||||
upscale_f32_sycl(
|
||||
src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
|
||||
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
||||
const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS);
|
||||
upscale_f32_bilinear_sycl(
|
||||
src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
sf0, sf1, sf2, sf3, pixel_offset, antialias, stream);
|
||||
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
|
||||
upscale_f32_bicubic_sycl(
|
||||
src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
sf0, sf1, sf2, sf3, pixel_offset, stream);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_upscale(ctx, dst);
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
#pragma once
|
||||
|
||||
#include <sycl/sycl.hpp>
|
||||
#include "dpct/helper.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
#define SYCL_UPSCALE_BLOCK_SIZE 256
|
||||
|
||||
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
|
@ -191,6 +191,7 @@ struct vk_queue;
|
|||
|
||||
struct vk_command_buffer {
|
||||
vk::CommandBuffer buf;
|
||||
uint64_t use_counter = 0;
|
||||
bool in_use = false;
|
||||
};
|
||||
|
||||
|
|
@ -938,21 +939,26 @@ struct vk_subbuffer {
|
|||
}
|
||||
};
|
||||
|
||||
// vk_event is used for the event-related backend interfaces. It uses 'event' for
|
||||
// event_wait and 'fence' for event_synchronize. Polling on an event for
|
||||
// event_synchronize wouldn't be sufficient to wait for command buffers to complete,
|
||||
// and would lead to validation errors.
|
||||
struct vk_event {
|
||||
vk::Event event;
|
||||
vk::Fence fence;
|
||||
vk_command_buffer* cmd_buffer = nullptr;
|
||||
};
|
||||
|
||||
struct vk_semaphore {
|
||||
vk::Semaphore s;
|
||||
uint64_t value;
|
||||
};
|
||||
|
||||
// vk_event is used for the event-related backend interfaces. It uses vk::Events for
|
||||
// event_wait and a timeline semaphore for event_synchronize. Polling on an event for
|
||||
// event_synchronize wouldn't be sufficient to wait for command buffers to complete,
|
||||
// and would lead to validation errors.
|
||||
struct vk_event {
|
||||
std::vector<vk::Event> events_free; // Events available for reuse
|
||||
std::vector<vk::Event> events_submitted; // Events that are fully submitted and can be reused on next synchronize
|
||||
vk::Event event;
|
||||
bool has_event;
|
||||
|
||||
vk_semaphore tl_semaphore;
|
||||
vk_command_buffer* cmd_buffer = nullptr;
|
||||
uint64_t cmd_buffer_use_counter = 0;
|
||||
};
|
||||
|
||||
struct vk_submission {
|
||||
vk_command_buffer* buffer = nullptr;
|
||||
std::vector<vk_semaphore> wait_semaphores;
|
||||
|
|
@ -2319,7 +2325,7 @@ static vk_command_buffer* ggml_vk_create_cmd_buffer(vk_device& device, vk_comman
|
|||
vk::CommandBufferLevel::ePrimary,
|
||||
1);
|
||||
const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);
|
||||
p.cmd_buffers.push_back({ cmd_buffers.front(), true });
|
||||
p.cmd_buffers.push_back({ cmd_buffers.front(), 0, true });
|
||||
return &p.cmd_buffers[p.cmd_buffers.size()-1];
|
||||
}
|
||||
|
||||
|
|
@ -2788,6 +2794,15 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct
|
|||
);
|
||||
}
|
||||
|
||||
static void ggml_vk_reset_event(vk_context& ctx, vk::Event& event) {
|
||||
VK_LOG_DEBUG("ggml_vk_set_event()");
|
||||
|
||||
ctx->s->buffer->buf.resetEvent(
|
||||
event,
|
||||
ctx->p->q->stage_flags
|
||||
);
|
||||
}
|
||||
|
||||
static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) {
|
||||
VK_LOG_DEBUG("ggml_vk_set_event()");
|
||||
|
||||
|
|
@ -4589,12 +4604,42 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
{"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"},
|
||||
{"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"},
|
||||
};
|
||||
const bool use_subgroup_reduce = device->subgroup_arithmetic;
|
||||
for (uint32_t si = 0; si < 3; si++) {
|
||||
const uint32_t S_V = gdn_sizes[si];
|
||||
GGML_ASSERT(is_pow2(S_V));
|
||||
|
||||
uint32_t lanes_per_column;
|
||||
if (S_V >= 128u && device->subgroup_clustered) {
|
||||
lanes_per_column = 8u;
|
||||
} else {
|
||||
// Use largest power-of-two that divides both S_V and subgroup_size so that
|
||||
// (1) S_V % lanes_per_column == 0 and (2) S_V % (subgroup_size / lanes_per_column) == 0.
|
||||
// This means we don't need extra bounds checking logic in the shader.
|
||||
lanes_per_column = std::min(S_V, device->subgroup_size);
|
||||
}
|
||||
|
||||
const bool need_clustered_shader = lanes_per_column != 1 && (lanes_per_column < device->subgroup_size);
|
||||
size_t gdn_len;
|
||||
const void * gdn_data;
|
||||
if (use_subgroup_reduce && need_clustered_shader) {
|
||||
gdn_len = gated_delta_net_f32_len;
|
||||
gdn_data = (const void *)gated_delta_net_f32_data;
|
||||
} else if (use_subgroup_reduce) {
|
||||
gdn_len = gated_delta_net_f32_nocluster_len;
|
||||
gdn_data = (const void *)gated_delta_net_f32_nocluster_data;
|
||||
} else {
|
||||
gdn_len = gated_delta_net_f32_shmem_len;
|
||||
gdn_data = (const void *)gated_delta_net_f32_shmem_data;
|
||||
}
|
||||
|
||||
const uint32_t cols_per_wg = device->subgroup_size / lanes_per_column;
|
||||
const std::array<uint32_t, 3> wg_denoms = {1u, 1u, cols_per_wg};
|
||||
|
||||
for (uint32_t kda = 0; kda < 2; kda++) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
|
||||
gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data,
|
||||
"main", 7, sizeof(vk_op_gated_delta_net_push_constants),
|
||||
{1, 1, 1}, {gdn_sizes[si], kda}, 1);
|
||||
gdn_names[si][kda], gdn_len, gdn_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
|
||||
wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_reduce, device->subgroup_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -4981,8 +5026,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
|
||||
|
||||
// Try to find a non-graphics compute queue and transfer-focused queues
|
||||
// On AMD, the graphics queue seems to be faster, so don't avoid it
|
||||
const vk::QueueFlagBits graphics_flag = device->vendor_id == VK_VENDOR_ID_AMD ? (vk::QueueFlagBits)0 : vk::QueueFlagBits::eGraphics;
|
||||
// Allow overriding avoiding the graphics queue because it can increase performance on RADV
|
||||
const bool allow_graphics_queue = (getenv("GGML_VK_ALLOW_GRAPHICS_QUEUE") != nullptr);
|
||||
const vk::QueueFlagBits graphics_flag = allow_graphics_queue ? (vk::QueueFlagBits)0 : vk::QueueFlagBits::eGraphics;
|
||||
const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, graphics_flag, -1, 1);
|
||||
const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | graphics_flag, compute_queue_family_index, 1);
|
||||
|
||||
|
|
@ -5443,11 +5489,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
|
||||
ggml_vk_load_shaders(device);
|
||||
|
||||
// Only use transfer queue on AMD non-GCN, when the graphics queue is not enabled
|
||||
const bool prefers_transfer_queue = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && !allow_graphics_queue;
|
||||
|
||||
if (!device->single_queue) {
|
||||
const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
|
||||
ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true);
|
||||
|
||||
device->async_use_transfer_queue = (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr);
|
||||
device->async_use_transfer_queue = prefers_transfer_queue || (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr);
|
||||
} else {
|
||||
// TODO: Use pointer or reference to avoid copy
|
||||
device->transfer_queue.copyFrom(device->compute_queue);
|
||||
|
|
@ -6392,6 +6441,7 @@ static vk_subbuffer ggml_vk_tensor_subbuffer(
|
|||
static vk_command_buffer* ggml_vk_get_or_create_cmd_buffer(vk_device& device, vk_command_pool& pool) {
|
||||
for (auto& cmd_buffer : pool.cmd_buffers) {
|
||||
if (!cmd_buffer.in_use) {
|
||||
cmd_buffer.use_counter++;
|
||||
cmd_buffer.in_use = true;
|
||||
return &cmd_buffer;
|
||||
}
|
||||
|
|
@ -6496,15 +6546,16 @@ static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {
|
|||
}
|
||||
|
||||
static vk_context ggml_vk_get_compute_ctx(ggml_backend_vk_context * ctx) {
|
||||
vk_context result;
|
||||
if (!ctx->compute_ctx.expired()) {
|
||||
return ctx->compute_ctx.lock();
|
||||
result = ctx->compute_ctx.lock();
|
||||
} else {
|
||||
result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
|
||||
ctx->compute_ctx = result;
|
||||
ggml_vk_ctx_begin(ctx->device, result);
|
||||
}
|
||||
|
||||
vk_context result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
|
||||
ctx->compute_ctx = result;
|
||||
ggml_vk_ctx_begin(ctx->device, result);
|
||||
|
||||
if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) {
|
||||
result->s->wait_semaphores.push_back(ctx->transfer_semaphore);
|
||||
ctx->transfer_semaphore_last_submitted = ctx->transfer_semaphore.value;
|
||||
|
|
@ -7625,20 +7676,14 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
|
|||
return true;
|
||||
}
|
||||
case VK_VENDOR_ID_INTEL:
|
||||
if (k < 2048) {
|
||||
if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) {
|
||||
// Intel Windows proprietary driver MMVQ performance is worse than fp16, see
|
||||
// https://github.com/ggml-org/llama.cpp/issues/17628
|
||||
return false;
|
||||
}
|
||||
|
||||
if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) {
|
||||
// Intel Windows proprietary driver tuning
|
||||
switch (src0_type) {
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
return false;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
if (k < 2048) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (src0_type) {
|
||||
|
|
@ -10423,7 +10468,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
|
|||
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
||||
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
|
||||
pc, { H, n_seqs, 1u });
|
||||
pc, { H, n_seqs, S_v });
|
||||
}
|
||||
|
||||
static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
|
||||
|
|
@ -13797,6 +13842,7 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {
|
|||
ctx->submit_pending = false;
|
||||
if (cmd_buf) {
|
||||
cmd_buf->in_use = false;
|
||||
cmd_buf->buf.reset();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -14858,18 +14904,31 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev
|
|||
vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
||||
auto* cmd_buf = compute_ctx->s->buffer; // retrieve pointer before it gets reset
|
||||
|
||||
// the backend interface doesn't have an explicit reset, so reset it here
|
||||
// before we record the command to set it
|
||||
ctx->device->device.resetEvent(vkev->event);
|
||||
ctx->device->device.resetFences({ vkev->fence });
|
||||
if (vkev->has_event) {
|
||||
// Move existing event into submitted
|
||||
vkev->events_submitted.push_back(vkev->event);
|
||||
}
|
||||
|
||||
// Grab the next event and record it, create one if necessary
|
||||
if (vkev->events_free.empty()) {
|
||||
vkev->event = ctx->device->device.createEvent({});
|
||||
} else {
|
||||
vkev->event = vkev->events_free.back();
|
||||
vkev->events_free.pop_back();
|
||||
}
|
||||
|
||||
vkev->has_event = true;
|
||||
|
||||
ggml_vk_set_event(compute_ctx, vkev->event);
|
||||
|
||||
vkev->tl_semaphore.value++;
|
||||
compute_ctx->s->signal_semaphores.push_back(vkev->tl_semaphore);
|
||||
ggml_vk_ctx_end(compute_ctx);
|
||||
|
||||
ggml_vk_submit(compute_ctx, {vkev->fence});
|
||||
ggml_vk_submit(compute_ctx, {});
|
||||
ctx->submit_pending = true;
|
||||
vkev->cmd_buffer = cmd_buf;
|
||||
vkev->cmd_buffer_use_counter = cmd_buf->use_counter;
|
||||
ctx->compute_ctx.reset();
|
||||
}
|
||||
|
||||
|
|
@ -14880,9 +14939,10 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even
|
|||
|
||||
vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
||||
|
||||
ggml_vk_wait_events(compute_ctx, {vkev->event});
|
||||
ggml_vk_ctx_end(compute_ctx);
|
||||
ctx->compute_ctx.reset();
|
||||
if (vkev->has_event) {
|
||||
// Wait for latest event
|
||||
ggml_vk_wait_events(compute_ctx, { vkev->event });
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: enable async and synchronize
|
||||
|
|
@ -15672,10 +15732,13 @@ static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// The event/fence is expected to initially be in the signaled state.
|
||||
vkev->event = device->device.createEvent({});
|
||||
vkev->fence = device->device.createFence({vk::FenceCreateFlagBits::eSignaled});
|
||||
device->device.setEvent(vkev->event);
|
||||
// No events initially, they get created on demand
|
||||
vkev->has_event = false;
|
||||
|
||||
vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
|
||||
vk::SemaphoreCreateInfo ci{};
|
||||
ci.setPNext(&tci);
|
||||
vkev->tl_semaphore = { device->device.createSemaphore(ci), 0 };
|
||||
|
||||
return new ggml_backend_event {
|
||||
/* .device = */ dev,
|
||||
|
|
@ -15689,8 +15752,16 @@ static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backe
|
|||
|
||||
vk_event *vkev = (vk_event *)event->context;
|
||||
|
||||
device->device.destroyFence(vkev->fence);
|
||||
device->device.destroyEvent(vkev->event);
|
||||
device->device.destroySemaphore(vkev->tl_semaphore.s);
|
||||
for (auto& event : vkev->events_free) {
|
||||
device->device.destroyEvent(event);
|
||||
}
|
||||
for (auto& event : vkev->events_submitted) {
|
||||
device->device.destroyEvent(event);
|
||||
}
|
||||
if (vkev->has_event) {
|
||||
device->device.destroyEvent(vkev->event);
|
||||
}
|
||||
delete vkev;
|
||||
delete event;
|
||||
}
|
||||
|
|
@ -15701,10 +15772,29 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm
|
|||
auto device = ggml_vk_get_device(ctx->device);
|
||||
vk_event *vkev = (vk_event *)event->context;
|
||||
|
||||
VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize");
|
||||
// Finished using current command buffer so we flag for reuse
|
||||
if (vkev->cmd_buffer) {
|
||||
vkev->cmd_buffer->in_use = false;
|
||||
// Only do something if the event has actually been used
|
||||
if (vkev->has_event) {
|
||||
vk::Semaphore sem = vkev->tl_semaphore.s;
|
||||
uint64_t val = vkev->tl_semaphore.value;
|
||||
vk::SemaphoreWaitInfo swi{vk::SemaphoreWaitFlags{}, sem, val};
|
||||
VK_CHECK(device->device.waitSemaphores(swi, UINT64_MAX), "event_synchronize");
|
||||
|
||||
// Reset and move submitted events
|
||||
for (auto& event : vkev->events_submitted) {
|
||||
device->device.resetEvent(event);
|
||||
}
|
||||
vkev->events_free.insert(vkev->events_free.end(), vkev->events_submitted.begin(), vkev->events_submitted.end());
|
||||
vkev->events_submitted.clear();
|
||||
|
||||
// Finished using current command buffer so we flag for reuse
|
||||
if (vkev->cmd_buffer) {
|
||||
// Only flag for reuse if it hasn't been reused already
|
||||
if (vkev->cmd_buffer_use_counter == vkev->cmd_buffer->use_counter) {
|
||||
vkev->cmd_buffer->in_use = false;
|
||||
vkev->cmd_buffer->buf.reset();
|
||||
}
|
||||
vkev->cmd_buffer = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -15958,6 +16048,7 @@ static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev)
|
|||
case 0xE20C: // B570
|
||||
return 18;
|
||||
case 0xE20B: // B580
|
||||
case 0xE211: // Pro B60
|
||||
return 20;
|
||||
default:
|
||||
return 0;
|
||||
|
|
|
|||
|
|
@ -245,7 +245,7 @@ void main() {
|
|||
#endif
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
|
||||
Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -270,7 +270,7 @@ void main() {
|
|||
#endif
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf));
|
||||
Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,11 +1,25 @@
|
|||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#if USE_SUBGROUP_CLUSTERED
|
||||
#extension GL_KHR_shader_subgroup_clustered : enable
|
||||
#endif
|
||||
#if USE_SUBGROUP_ADD
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#endif
|
||||
|
||||
// Caller guarantees valid spec constants: S_V % COLS_PER_WG == 0 and S_V % LANES_PER_COLUMN == 0,
|
||||
// so no bounds checking is needed.
|
||||
layout(constant_id = 0) const uint S_V = 128;
|
||||
layout(constant_id = 1) const uint KDA = 0;
|
||||
layout(constant_id = 2) const uint SUBGROUP_SIZE = 32;
|
||||
layout(constant_id = 3) const uint LANES_PER_COLUMN = 32;
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
const uint COLS_PER_WG = SUBGROUP_SIZE / LANES_PER_COLUMN;
|
||||
const uint ROWS_PER_LANE = S_V / LANES_PER_COLUMN;
|
||||
|
||||
layout(local_size_x_id = 2, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout(push_constant) uniform Parameters {
|
||||
uint H;
|
||||
|
|
@ -27,14 +41,61 @@ layout(binding = 4) readonly buffer BetaBuf { FLOAT_TYPE data_beta[]; };
|
|||
layout(binding = 5) readonly buffer StateBuf { FLOAT_TYPE data_state[]; };
|
||||
layout(binding = 6) buffer DstBuf { FLOAT_TYPE data_dst[]; };
|
||||
|
||||
shared FLOAT_TYPE s_k[S_V];
|
||||
shared FLOAT_TYPE s_q[S_V];
|
||||
shared FLOAT_TYPE s_g[S_V]; // KDA only: cached exp(g[i])
|
||||
#if !USE_SUBGROUP_ADD && !USE_SUBGROUP_CLUSTERED
|
||||
shared FLOAT_TYPE temp[SUBGROUP_SIZE];
|
||||
|
||||
// This does a reduction across groups of LANES_PER_COLUMN
|
||||
FLOAT_TYPE reduce_add_shmem(FLOAT_TYPE partial) {
|
||||
const uint lane = gl_SubgroupInvocationID;
|
||||
temp[lane] = partial;
|
||||
barrier();
|
||||
[[unroll]] for (uint s = LANES_PER_COLUMN / 2u; s > 0; s >>= 1u) {
|
||||
FLOAT_TYPE other = temp[lane ^ s];
|
||||
barrier();
|
||||
temp[lane] += other;
|
||||
barrier();
|
||||
}
|
||||
const FLOAT_TYPE result = temp[lane];
|
||||
barrier();
|
||||
return result;
|
||||
}
|
||||
#endif
|
||||
|
||||
// clusterSize for subgroupClusteredAdd must be a compile-time constant; branch on spec constant
|
||||
FLOAT_TYPE reduce_partial(FLOAT_TYPE partial) {
|
||||
switch (LANES_PER_COLUMN) {
|
||||
case 1u:
|
||||
return partial;
|
||||
#if USE_SUBGROUP_CLUSTERED
|
||||
// Workaround for GLSL requiring a literal constant for the cluster size.
|
||||
// The branches should all fold away.
|
||||
case 2u:
|
||||
return subgroupClusteredAdd(partial, 2u);
|
||||
case 4u:
|
||||
return subgroupClusteredAdd(partial, 4u);
|
||||
case 8u:
|
||||
return subgroupClusteredAdd(partial, 8u);
|
||||
case 16u:
|
||||
return subgroupClusteredAdd(partial, 16u);
|
||||
case 32u:
|
||||
return subgroupClusteredAdd(partial, 32u);
|
||||
case 64u:
|
||||
return subgroupClusteredAdd(partial, 64u);
|
||||
#endif
|
||||
default:
|
||||
#if USE_SUBGROUP_ADD
|
||||
return subgroupAdd(partial);
|
||||
#else
|
||||
return reduce_add_shmem(partial);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void main() {
|
||||
const uint head_id = gl_WorkGroupID.x;
|
||||
const uint seq_id = gl_WorkGroupID.y;
|
||||
const uint col = gl_LocalInvocationID.x;
|
||||
const uint seq_id = gl_WorkGroupID.y;
|
||||
const uint lane = gl_SubgroupInvocationID % LANES_PER_COLUMN;
|
||||
const uint col = gl_WorkGroupID.z * COLS_PER_WG + (gl_SubgroupInvocationID / LANES_PER_COLUMN);
|
||||
|
||||
const uint iq1 = head_id % neq1;
|
||||
const uint iq3 = seq_id / rq3;
|
||||
|
|
@ -42,9 +103,9 @@ void main() {
|
|||
const uint state_size = S_V * S_V;
|
||||
const uint state_base = (seq_id * H + head_id) * state_size;
|
||||
|
||||
FLOAT_TYPE state[S_V];
|
||||
[[unroll]] for (uint i = 0; i < S_V; i++) {
|
||||
state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]);
|
||||
FLOAT_TYPE s_shard[ROWS_PER_LANE];
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]);
|
||||
}
|
||||
|
||||
uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
|
||||
|
|
@ -53,76 +114,56 @@ void main() {
|
|||
const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;
|
||||
const uint k_off = q_off;
|
||||
const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;
|
||||
|
||||
s_q[col] = FLOAT_TYPE(data_q[q_off + col]);
|
||||
s_k[col] = FLOAT_TYPE(data_k[k_off + col]);
|
||||
|
||||
const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1;
|
||||
|
||||
if (KDA != 0) {
|
||||
const uint g_base = gb_off * S_V;
|
||||
s_g[col] = exp(FLOAT_TYPE(data_g[g_base + col]));
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
|
||||
const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]);
|
||||
|
||||
FLOAT_TYPE k_reg[ROWS_PER_LANE];
|
||||
FLOAT_TYPE q_reg[ROWS_PER_LANE];
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
const uint i = r * LANES_PER_COLUMN + lane;
|
||||
k_reg[r] = FLOAT_TYPE(data_k[k_off + i]);
|
||||
q_reg[r] = FLOAT_TYPE(data_q[q_off + i]);
|
||||
}
|
||||
|
||||
FLOAT_TYPE g_exp[ROWS_PER_LANE];
|
||||
if (KDA == 0) {
|
||||
const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off]));
|
||||
|
||||
FLOAT_TYPE kv_col = 0.0;
|
||||
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
|
||||
kv_col += dot(
|
||||
vec4(state[i], state[i+1], state[i+2], state[i+3]),
|
||||
vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3])
|
||||
);
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
g_exp[r] = g_val;
|
||||
}
|
||||
|
||||
FLOAT_TYPE delta_col = (v_val - g_val * kv_col) * beta_val;
|
||||
|
||||
FLOAT_TYPE attn_col = 0.0;
|
||||
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
|
||||
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
|
||||
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
|
||||
sv = g_val * sv + kv * delta_col;
|
||||
state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
|
||||
|
||||
attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
|
||||
}
|
||||
|
||||
data_dst[attn_off + col] = attn_col * scale;
|
||||
} else {
|
||||
FLOAT_TYPE kv_col = 0.0;
|
||||
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
|
||||
vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
|
||||
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
|
||||
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
|
||||
kv_col += dot(gv * sv, kv);
|
||||
const uint g_base = gb_off * S_V;
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
const uint i = r * LANES_PER_COLUMN + lane;
|
||||
g_exp[r] = exp(FLOAT_TYPE(data_g[g_base + i]));
|
||||
}
|
||||
}
|
||||
|
||||
FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
|
||||
const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
|
||||
|
||||
FLOAT_TYPE attn_col = 0.0;
|
||||
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
|
||||
vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
|
||||
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
|
||||
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
|
||||
sv = gv * sv + kv * delta_col;
|
||||
state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
|
||||
FLOAT_TYPE kv_shard = 0.0;
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
kv_shard += g_exp[r] * s_shard[r] * k_reg[r];
|
||||
}
|
||||
FLOAT_TYPE kv_col = reduce_partial(kv_shard);
|
||||
|
||||
attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
|
||||
}
|
||||
FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
|
||||
|
||||
FLOAT_TYPE attn_partial = 0.0;
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
s_shard[r] = g_exp[r] * s_shard[r] + k_reg[r] * delta_col;
|
||||
attn_partial += s_shard[r] * q_reg[r];
|
||||
}
|
||||
FLOAT_TYPE attn_col = reduce_partial(attn_partial);
|
||||
|
||||
if (lane == 0) {
|
||||
data_dst[attn_off + col] = attn_col * scale;
|
||||
}
|
||||
|
||||
attn_off += S_V * H;
|
||||
barrier();
|
||||
}
|
||||
|
||||
[[unroll]] for (uint i = 0; i < S_V; i++) {
|
||||
data_dst[s_off + state_base + col * S_V + i] = state[i];
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r];
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -444,19 +444,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
|||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
||||
const uint iq = 16 * ib32 + 2 * (idx % 8);
|
||||
const uint ib = idx / 64; // 4 values per idx
|
||||
const uint ib32 = (idx % 64) / 8; // 0..7
|
||||
const uint iq = 4 * ib32 + (idx % 4);
|
||||
|
||||
const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
|
||||
const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;
|
||||
const uint qshift = (idx & 8) >> 1;
|
||||
u8vec2 qs = unpack8((uint(data_a_packed16[ib].qs[iq/2]) >> qshift) & 0x0F0F).xy;
|
||||
const uint qshift = idx & 4;
|
||||
u8vec4 qs = unpack8((uint(data_a_packed32[ib].qs[iq]) >> qshift) & 0x0F0F0F0F);
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
|
||||
const vec4 v = d * float(int(sl | (sh << 4)) - 32) * vec4(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y], kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
|
||||
#elif defined(DATA_A_IQ4_NL)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
|
||||
|
|
|
|||
|
|
@ -554,7 +554,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
|||
std::string load_vec_quant = "2";
|
||||
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
|
||||
load_vec_quant = "8";
|
||||
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
|
||||
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4"))
|
||||
load_vec_quant = "4";
|
||||
|
||||
if (tname == "bf16") {
|
||||
|
|
@ -987,7 +987,9 @@ void process_shaders() {
|
|||
|
||||
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}}));
|
||||
string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "1"}}));
|
||||
string_to_spv("gated_delta_net_f32_nocluster", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "0"}}));
|
||||
string_to_spv("gated_delta_net_f32_shmem", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "0"}, {"USE_SUBGROUP_CLUSTERED", "0"}}));
|
||||
|
||||
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
|
|
|||
|
|
@ -95,6 +95,11 @@ struct ggml_webgpu_generic_shader_decisions {
|
|||
uint32_t wg_size = 0;
|
||||
};
|
||||
|
||||
struct ggml_webgpu_ssm_conv_shader_decisions {
|
||||
uint32_t block_size;
|
||||
uint32_t tokens_per_wg;
|
||||
};
|
||||
|
||||
/** Argsort **/
|
||||
|
||||
struct ggml_webgpu_argsort_shader_lib_context {
|
||||
|
|
@ -131,6 +136,26 @@ struct ggml_webgpu_set_rows_shader_decisions {
|
|||
uint32_t wg_size;
|
||||
};
|
||||
|
||||
/** Set **/
|
||||
|
||||
struct ggml_webgpu_set_pipeline_key {
|
||||
ggml_type type;
|
||||
bool inplace;
|
||||
|
||||
bool operator==(const ggml_webgpu_set_pipeline_key & other) const {
|
||||
return type == other.type && inplace == other.inplace;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_set_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_set_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.type);
|
||||
ggml_webgpu_hash_combine(seed, key.inplace);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
/** Get Rows **/
|
||||
|
||||
struct ggml_webgpu_get_rows_pipeline_key {
|
||||
|
|
@ -151,6 +176,26 @@ struct ggml_webgpu_get_rows_pipeline_key_hash {
|
|||
}
|
||||
};
|
||||
|
||||
/** Row Norm **/
|
||||
|
||||
struct ggml_webgpu_row_norm_pipeline_key {
|
||||
ggml_op op;
|
||||
bool inplace;
|
||||
|
||||
bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const {
|
||||
return op == other.op && inplace == other.inplace;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_row_norm_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.op);
|
||||
ggml_webgpu_hash_combine(seed, key.inplace);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
/** Pad **/
|
||||
struct ggml_webgpu_pad_pipeline_key {
|
||||
bool circular;
|
||||
|
|
@ -166,6 +211,67 @@ struct ggml_webgpu_pad_pipeline_key_hash {
|
|||
}
|
||||
};
|
||||
|
||||
/** Solve Tri **/
|
||||
struct ggml_webgpu_solve_tri_pipeline_key {
|
||||
int type;
|
||||
int n;
|
||||
int k;
|
||||
|
||||
bool operator==(const ggml_webgpu_solve_tri_pipeline_key & other) const {
|
||||
return type == other.type && n == other.n && k == other.k;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_solve_tri_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_solve_tri_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.type);
|
||||
ggml_webgpu_hash_combine(seed, key.n);
|
||||
ggml_webgpu_hash_combine(seed, key.k);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
/** SSM Conv **/
|
||||
struct ggml_webgpu_ssm_conv_pipeline_key {
|
||||
int type;
|
||||
int vectorized;
|
||||
|
||||
bool operator==(const ggml_webgpu_ssm_conv_pipeline_key & other) const {
|
||||
return type == other.type && vectorized == other.vectorized;
|
||||
}
|
||||
};
|
||||
|
||||
/** Gated Delta Net **/
|
||||
struct ggml_webgpu_gated_delta_net_pipeline_key {
|
||||
int type;
|
||||
int s_v;
|
||||
int kda;
|
||||
|
||||
bool operator==(const ggml_webgpu_gated_delta_net_pipeline_key & other) const {
|
||||
return type == other.type && s_v == other.s_v && kda == other.kda;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_gated_delta_net_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_gated_delta_net_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.type);
|
||||
ggml_webgpu_hash_combine(seed, key.s_v);
|
||||
ggml_webgpu_hash_combine(seed, key.kda);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_ssm_conv_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_ssm_conv_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.type);
|
||||
ggml_webgpu_hash_combine(seed, key.vectorized);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
/** Scale **/
|
||||
|
||||
struct ggml_webgpu_scale_pipeline_key {
|
||||
|
|
@ -244,13 +350,15 @@ struct ggml_webgpu_binary_pipeline_key_hash {
|
|||
/** Unary **/
|
||||
|
||||
struct ggml_webgpu_unary_pipeline_key {
|
||||
int type;
|
||||
int op;
|
||||
bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
|
||||
bool inplace;
|
||||
int type;
|
||||
int op;
|
||||
bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
|
||||
bool inplace;
|
||||
ggml_tri_type ttype; // only used for GGML_OP_TRI
|
||||
|
||||
bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
|
||||
return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
|
||||
return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace &&
|
||||
ttype == other.ttype;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -261,6 +369,7 @@ struct ggml_webgpu_unary_pipeline_key_hash {
|
|||
ggml_webgpu_hash_combine(seed, key.op);
|
||||
ggml_webgpu_hash_combine(seed, key.is_unary);
|
||||
ggml_webgpu_hash_combine(seed, key.inplace);
|
||||
ggml_webgpu_hash_combine(seed, key.ttype);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
|
@ -435,20 +544,30 @@ class ggml_webgpu_shader_lib {
|
|||
std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order
|
||||
std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order
|
||||
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
|
||||
std::unordered_map<ggml_webgpu_row_norm_pipeline_key, webgpu_pipeline, ggml_webgpu_row_norm_pipeline_key_hash>
|
||||
row_norm_pipelines; // op/inplace
|
||||
std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
|
||||
get_rows_pipelines; // src_type, vectorized
|
||||
std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
|
||||
unary_pipelines; // type/op/inplace
|
||||
std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
|
||||
scale_pipelines; // inplace
|
||||
std::unordered_map<ggml_webgpu_solve_tri_pipeline_key, webgpu_pipeline, ggml_webgpu_solve_tri_pipeline_key_hash>
|
||||
solve_tri_pipelines; // type
|
||||
std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
|
||||
ssm_conv_pipelines; // type/vectorized
|
||||
std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key,
|
||||
webgpu_pipeline,
|
||||
ggml_webgpu_gated_delta_net_pipeline_key_hash>
|
||||
gated_delta_net_pipelines; // type/S_v/kda
|
||||
std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
|
||||
pad_pipelines; // circular/non-circular
|
||||
pad_pipelines; // circular/non-circular
|
||||
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
|
||||
binary_pipelines; // type/op/inplace/overlap
|
||||
binary_pipelines; // type/op/inplace/overlap
|
||||
std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
|
||||
concat_pipelines; // type
|
||||
concat_pipelines; // type
|
||||
std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
|
||||
repeat_pipelines; // type
|
||||
repeat_pipelines; // type
|
||||
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
|
||||
flash_attn_pipelines;
|
||||
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
|
||||
|
|
@ -462,6 +581,7 @@ class ggml_webgpu_shader_lib {
|
|||
|
||||
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
|
||||
set_rows_pipelines;
|
||||
std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines;
|
||||
|
||||
public:
|
||||
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
|
||||
|
|
@ -479,6 +599,45 @@ class ggml_webgpu_shader_lib {
|
|||
return sum_rows_pipelines[1];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_row_norm_pipeline_key key = {
|
||||
.op = context.dst->op,
|
||||
.inplace = context.inplace,
|
||||
};
|
||||
|
||||
auto it = row_norm_pipelines.find(key);
|
||||
if (it != row_norm_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
std::vector<std::string> defines;
|
||||
std::string variant;
|
||||
|
||||
switch (key.op) {
|
||||
case GGML_OP_RMS_NORM:
|
||||
defines.push_back("RMS_NORM");
|
||||
variant = "rms_norm";
|
||||
break;
|
||||
case GGML_OP_L2_NORM:
|
||||
defines.push_back("L2_NORM");
|
||||
variant = "l2_norm";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported op for row_norm shader");
|
||||
}
|
||||
|
||||
if (key.inplace) {
|
||||
defines.push_back("INPLACE");
|
||||
variant += "_inplace";
|
||||
}
|
||||
|
||||
const uint32_t row_norm_wg_size = 128u;
|
||||
uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size);
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
|
||||
row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
return row_norm_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
bool vec4 = context.src0->ne[0] % 4 == 0;
|
||||
|
||||
|
|
@ -546,6 +705,46 @@ class ggml_webgpu_shader_lib {
|
|||
return set_rows_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_set_pipeline_key key = { .type = context.dst->type, .inplace = context.inplace };
|
||||
|
||||
auto it = set_pipelines.find(key);
|
||||
if (it != set_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "set";
|
||||
|
||||
switch (key.type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("TYPE_F32");
|
||||
variant += "_f32";
|
||||
break;
|
||||
case GGML_TYPE_I32:
|
||||
defines.push_back("TYPE_I32");
|
||||
variant += "_i32";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported type for set shader");
|
||||
}
|
||||
|
||||
if (key.inplace) {
|
||||
defines.push_back("INPLACE");
|
||||
variant += "_inplace";
|
||||
}
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_set, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
decisions->wg_size = context.max_wg_size;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
set_pipelines[key] = pipeline;
|
||||
return set_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
auto it = cumsum_pipelines.find(1);
|
||||
if (it != cumsum_pipelines.end()) {
|
||||
|
|
@ -632,6 +831,7 @@ class ggml_webgpu_shader_lib {
|
|||
|
||||
switch (key.src_type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("FLOAT_PARALLEL");
|
||||
if (key.vectorized) {
|
||||
defines.push_back("F32_VEC");
|
||||
defines.push_back("SRC_TYPE=vec4<f32>");
|
||||
|
|
@ -646,6 +846,7 @@ class ggml_webgpu_shader_lib {
|
|||
variant += "_f32";
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("FLOAT_PARALLEL");
|
||||
defines.push_back("F16");
|
||||
defines.push_back("SRC_TYPE=f16");
|
||||
defines.push_back("DST_TYPE=f32");
|
||||
|
|
@ -653,6 +854,7 @@ class ggml_webgpu_shader_lib {
|
|||
variant += "_f16";
|
||||
break;
|
||||
case GGML_TYPE_I32:
|
||||
defines.push_back("FLOAT_PARALLEL");
|
||||
defines.push_back("I32");
|
||||
defines.push_back("SRC_TYPE=i32");
|
||||
defines.push_back("DST_TYPE=i32");
|
||||
|
|
@ -731,6 +933,128 @@ class ggml_webgpu_shader_lib {
|
|||
return scale_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_solve_tri_pipeline_key key = {
|
||||
.type = context.dst->type,
|
||||
.n = (int) context.src0->ne[0],
|
||||
.k = (int) context.src1->ne[0],
|
||||
};
|
||||
|
||||
auto it = solve_tri_pipelines.find(key);
|
||||
if (it != solve_tri_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "solve_tri";
|
||||
|
||||
switch (key.type) {
|
||||
case GGML_TYPE_F32:
|
||||
variant += "_f32";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported type for solve_tri shader");
|
||||
}
|
||||
|
||||
const uint32_t wg_size = std::min((uint32_t) key.n, context.max_wg_size);
|
||||
const uint32_t k_tile = wg_size;
|
||||
const uint32_t bytes_per_row = ((uint32_t) key.n + wg_size) * GGML_WEBGPU_F32_SIZE_BYTES;
|
||||
const uint32_t batch_n = (uint32_t) (context.wg_mem_limit_bytes / bytes_per_row);
|
||||
|
||||
defines.push_back(std::string("N=") + std::to_string(key.n));
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
defines.push_back(std::string("K_TILE=") + std::to_string(k_tile));
|
||||
defines.push_back(std::string("BATCH_N=") + std::to_string(batch_n));
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_solve_tri, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
decisions->wg_size = wg_size;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
solve_tri_pipelines[key] = pipeline;
|
||||
return solve_tri_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_ssm_conv_pipeline_key key = {
|
||||
.type = context.dst->type,
|
||||
.vectorized = context.src1->ne[0] == 4,
|
||||
};
|
||||
|
||||
auto it = ssm_conv_pipelines.find(key);
|
||||
if (it != ssm_conv_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "ssm_conv";
|
||||
|
||||
switch (key.type) {
|
||||
case GGML_TYPE_F32:
|
||||
variant += "_f32";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported type for ssm_conv shader");
|
||||
}
|
||||
|
||||
if (key.vectorized) {
|
||||
defines.push_back("VECTORIZED");
|
||||
variant += "_vec4";
|
||||
}
|
||||
|
||||
constexpr uint32_t block_size = 32u;
|
||||
constexpr uint32_t tokens_per_wg = 8u;
|
||||
|
||||
defines.push_back("BLOCK_SIZE=" + std::to_string(block_size) + "u");
|
||||
defines.push_back("TOKENS_PER_WG=" + std::to_string(tokens_per_wg) + "u");
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_ssm_conv, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_ssm_conv_shader_decisions>();
|
||||
decisions->block_size = block_size;
|
||||
decisions->tokens_per_wg = tokens_per_wg;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
ssm_conv_pipelines[key] = pipeline;
|
||||
return ssm_conv_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_gated_delta_net_pipeline_key key = {
|
||||
.type = context.dst->type,
|
||||
.s_v = (int) context.src2->ne[0],
|
||||
.kda = context.src3->ne[0] == context.src2->ne[0],
|
||||
};
|
||||
|
||||
auto it = gated_delta_net_pipelines.find(key);
|
||||
if (it != gated_delta_net_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "gated_delta_net";
|
||||
|
||||
switch (key.type) {
|
||||
case GGML_TYPE_F32:
|
||||
variant += "_f32";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported type for gated_delta_net shader");
|
||||
}
|
||||
|
||||
if (key.kda) {
|
||||
defines.push_back("KDA");
|
||||
variant += "_kda";
|
||||
}
|
||||
|
||||
defines.push_back("S_V=" + std::to_string(key.s_v) + "u");
|
||||
defines.push_back("WG_SIZE=" + std::to_string(key.s_v) + "u");
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_gated_delta_net, defines);
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
gated_delta_net_pipelines[key] = pipeline;
|
||||
return gated_delta_net_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 };
|
||||
|
||||
|
|
@ -1058,6 +1382,7 @@ class ggml_webgpu_shader_lib {
|
|||
.op = op,
|
||||
.is_unary = is_unary,
|
||||
.inplace = context.inplace,
|
||||
.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0),
|
||||
};
|
||||
|
||||
auto it = unary_pipelines.find(key);
|
||||
|
|
@ -1088,6 +1413,29 @@ class ggml_webgpu_shader_lib {
|
|||
variant += "_inplace";
|
||||
}
|
||||
|
||||
if (op == GGML_OP_TRI) {
|
||||
switch (key.ttype) {
|
||||
case GGML_TRI_TYPE_LOWER:
|
||||
defines.push_back("TRI_TYPE_LOWER");
|
||||
variant += "_tri_type_lower";
|
||||
break;
|
||||
case GGML_TRI_TYPE_LOWER_DIAG:
|
||||
defines.push_back("TRI_TYPE_LOWER_DIAG");
|
||||
variant += "_tri_type_lower_diag";
|
||||
break;
|
||||
case GGML_TRI_TYPE_UPPER:
|
||||
defines.push_back("TRI_TYPE_UPPER");
|
||||
variant += "_tri_type_upper";
|
||||
break;
|
||||
case GGML_TRI_TYPE_UPPER_DIAG:
|
||||
defines.push_back("TRI_TYPE_UPPER_DIAG");
|
||||
variant += "_tri_upper_diag";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported ggml_tri_type for unary shader");
|
||||
}
|
||||
}
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_unary, defines);
|
||||
|
|
|
|||
|
|
@ -366,7 +366,6 @@ struct webgpu_context_struct {
|
|||
|
||||
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
|
||||
|
||||
std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
|
||||
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
|
||||
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
|
||||
|
||||
|
|
@ -509,50 +508,39 @@ static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context &
|
|||
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
|
||||
std::vector<webgpu_submission> & subs,
|
||||
bool block = true) {
|
||||
// If we have too many in-flight submissions, wait on the oldest one first.
|
||||
if (subs.empty()) {
|
||||
return;
|
||||
}
|
||||
while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
|
||||
auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX);
|
||||
if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
|
||||
|
||||
bool blocking_wait = block || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD;
|
||||
while (blocking_wait) {
|
||||
auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, 0);
|
||||
if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);
|
||||
#endif
|
||||
subs.erase(subs.begin());
|
||||
}
|
||||
blocking_wait = (block && !subs.empty()) || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD;
|
||||
}
|
||||
|
||||
if (subs.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (block) {
|
||||
for (auto & sub : subs) {
|
||||
while (!sub.submit_done.completed) {
|
||||
auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX);
|
||||
ggml_backend_webgpu_handle_wait_status(waitStatus);
|
||||
}
|
||||
// Poll each submit future once and remove completed submissions.
|
||||
for (auto sub = subs.begin(); sub != subs.end();) {
|
||||
auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0);
|
||||
bool success = ggml_backend_webgpu_handle_wait_status(waitStatus, true);
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true);
|
||||
#endif
|
||||
}
|
||||
subs.clear();
|
||||
} else {
|
||||
// Poll each submit future once and remove completed submissions.
|
||||
for (auto sub = subs.begin(); sub != subs.end();) {
|
||||
auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0);
|
||||
ggml_backend_webgpu_handle_wait_status(waitStatus, true);
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false);
|
||||
if (sub->submit_done.completed && sub->profile_futures.empty()) {
|
||||
ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false);
|
||||
if (success && sub->profile_futures.empty()) {
|
||||
#else
|
||||
if (sub->submit_done.completed) {
|
||||
if (success) {
|
||||
#endif
|
||||
sub = subs.erase(sub);
|
||||
} else {
|
||||
++sub;
|
||||
}
|
||||
sub = subs.erase(sub);
|
||||
} else {
|
||||
++sub;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -892,6 +880,68 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g
|
|||
params, entries, wg_x);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const bool inplace = ggml_webgpu_tensor_equal(src0, dst);
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src0,
|
||||
.src1 = src1,
|
||||
.dst = dst,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
.inplace = inplace,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx);
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
const uint32_t ne = inplace ? (uint32_t) ggml_nelements(src1) : (uint32_t) ggml_nelements(dst);
|
||||
const uint32_t dst_type_size = (uint32_t) ggml_type_size(dst->type);
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
ne,
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (((const int32_t *) dst->op_params)[3] / dst_type_size),
|
||||
|
||||
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
|
||||
|
||||
1u,
|
||||
(uint32_t) (((const int32_t *) dst->op_params)[0] / dst_type_size),
|
||||
(uint32_t) (((const int32_t *) dst->op_params)[1] / dst_type_size),
|
||||
(uint32_t) (((const int32_t *) dst->op_params)[2] / dst_type_size),
|
||||
|
||||
(uint32_t) src1->ne[0],
|
||||
(uint32_t) src1->ne[1],
|
||||
(uint32_t) src1->ne[2],
|
||||
(uint32_t) src1->ne[3],
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries;
|
||||
uint32_t binding_index = 0;
|
||||
if (!inplace) {
|
||||
entries.push_back({ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) });
|
||||
binding_index++;
|
||||
}
|
||||
entries.push_back({ .binding = binding_index,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) });
|
||||
entries.push_back({ .binding = binding_index + 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||
|
||||
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
|
||||
|
|
@ -947,6 +997,208 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g
|
|||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_solve_tri(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src0,
|
||||
.src1 = src1,
|
||||
.dst = dst,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_solve_tri_pipeline(shader_lib_ctx);
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
|
||||
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
|
||||
|
||||
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
|
||||
|
||||
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
||||
|
||||
(uint32_t) src1->ne[0],
|
||||
(uint32_t) dst->ne[2],
|
||||
(uint32_t) dst->ne[3],
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size);
|
||||
const uint32_t wg_y = (uint32_t) (dst->ne[2] * dst->ne[3]);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_ssm_conv(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src0,
|
||||
.src1 = src1,
|
||||
.dst = dst,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_conv_pipeline(shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_ssm_conv_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
const uint32_t token_tiles = CEIL_DIV((uint32_t) dst->ne[1], decisions->tokens_per_wg);
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
|
||||
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
|
||||
|
||||
(uint32_t) src1->ne[0],
|
||||
(uint32_t) src0->ne[1],
|
||||
(uint32_t) dst->ne[1],
|
||||
(uint32_t) dst->ne[2],
|
||||
token_tiles,
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size);
|
||||
const uint32_t wg_y = token_tiles * (uint32_t) dst->ne[2];
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_gated_delta_net(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * src2,
|
||||
ggml_tensor * src3,
|
||||
ggml_tensor * src4,
|
||||
ggml_tensor * src5,
|
||||
ggml_tensor * dst) {
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src0,
|
||||
.src1 = src1,
|
||||
.src2 = src2,
|
||||
.src3 = src3,
|
||||
.src4 = src4,
|
||||
.dst = dst,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_gated_delta_net_pipeline(shader_lib_ctx);
|
||||
|
||||
const uint32_t s_v = (uint32_t) src2->ne[0];
|
||||
const uint32_t h = (uint32_t) src2->ne[1];
|
||||
const uint32_t n_tokens = (uint32_t) src2->ne[2];
|
||||
const uint32_t n_seqs = (uint32_t) src2->ne[3];
|
||||
const float scale = 1.0f / sqrtf((float) s_v);
|
||||
uint32_t scale_u32;
|
||||
memcpy(&scale_u32, &scale, sizeof(scale_u32));
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
h,
|
||||
n_tokens,
|
||||
n_seqs,
|
||||
s_v * h * n_tokens * n_seqs,
|
||||
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
|
||||
|
||||
(uint32_t) (src2->nb[1] / ggml_type_size(src2->type)),
|
||||
(uint32_t) (src2->nb[2] / ggml_type_size(src2->type)),
|
||||
(uint32_t) (src2->nb[3] / ggml_type_size(src2->type)),
|
||||
|
||||
(uint32_t) (src4->nb[1] / ggml_type_size(src4->type)),
|
||||
(uint32_t) (src4->nb[2] / ggml_type_size(src4->type)),
|
||||
(uint32_t) (src4->nb[3] / ggml_type_size(src4->type)),
|
||||
|
||||
(uint32_t) src0->ne[1],
|
||||
(uint32_t) (src2->ne[3] / src0->ne[3]),
|
||||
scale_u32,
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(src2),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src2),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src2) },
|
||||
{ .binding = 3,
|
||||
.buffer = ggml_webgpu_tensor_buf(src3),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src3),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src3) },
|
||||
{ .binding = 4,
|
||||
.buffer = ggml_webgpu_tensor_buf(src4),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src4),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src4) },
|
||||
{ .binding = 5,
|
||||
.buffer = ggml_webgpu_tensor_buf(src5),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src5),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src5) },
|
||||
{ .binding = 6,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, h, n_seqs);
|
||||
}
|
||||
|
||||
static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
|
||||
ggml_tensor * src,
|
||||
ggml_tensor * idx,
|
||||
|
|
@ -1028,6 +1280,8 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
|
|||
ggml_tensor * src,
|
||||
ggml_tensor * idx,
|
||||
ggml_tensor * dst) {
|
||||
const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32;
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src,
|
||||
.src1 = nullptr,
|
||||
|
|
@ -1072,7 +1326,10 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
|
|||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size);
|
||||
uint32_t blocks_per_row = (uint32_t) (dst->ne[0] / (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0 ? 4 : 1));
|
||||
uint32_t total_rows = (uint32_t) (dst->ne[1] * dst->ne[2] * dst->ne[3]);
|
||||
uint32_t total_threads = float_parallel ? blocks_per_row * total_rows : total_rows;
|
||||
uint32_t wg_x = CEIL_DIV(total_threads, decisions->wg_size);
|
||||
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
|
@ -1609,8 +1866,8 @@ static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src
|
|||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
int inplace = ggml_webgpu_tensor_equal(src, dst);
|
||||
static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
bool inplace = ggml_webgpu_tensor_equal(src, dst);
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||
|
|
@ -1641,8 +1898,15 @@ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * s
|
|||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||
}
|
||||
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params,
|
||||
entries, ggml_nrows(src));
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src,
|
||||
.dst = dst,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
.inplace = inplace,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(src));
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
|
||||
|
|
@ -2181,6 +2445,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
|||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
return ggml_webgpu_cpy(ctx, src0, node);
|
||||
case GGML_OP_SET:
|
||||
return ggml_webgpu_set(ctx, src0, src1, node);
|
||||
case GGML_OP_SET_ROWS:
|
||||
return ggml_webgpu_set_rows(ctx, src0, src1, node);
|
||||
case GGML_OP_GET_ROWS:
|
||||
|
|
@ -2203,7 +2469,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
|||
case GGML_OP_REPEAT:
|
||||
return ggml_webgpu_repeat(ctx, src0, node);
|
||||
case GGML_OP_RMS_NORM:
|
||||
return ggml_webgpu_rms_norm(ctx, src0, node);
|
||||
case GGML_OP_L2_NORM:
|
||||
return ggml_webgpu_row_norm(ctx, src0, node);
|
||||
case GGML_OP_ROPE:
|
||||
return ggml_webgpu_rope(ctx, src0, src1, src2, node);
|
||||
case GGML_OP_GLU:
|
||||
|
|
@ -2220,7 +2487,15 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
|||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_DIAG:
|
||||
case GGML_OP_TRI:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
return ggml_webgpu_solve_tri(ctx, src0, src1, node);
|
||||
case GGML_OP_SSM_CONV:
|
||||
return ggml_webgpu_ssm_conv(ctx, src0, src1, node);
|
||||
case GGML_OP_GATED_DELTA_NET:
|
||||
return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node);
|
||||
case GGML_OP_PAD:
|
||||
return ggml_webgpu_pad(ctx, src0, node);
|
||||
case GGML_OP_ARGMAX:
|
||||
|
|
@ -2625,15 +2900,6 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
|
|||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
|
||||
|
||||
webgpu_ctx->rms_norm_pipelines[0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants);
|
||||
webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline(
|
||||
webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
|
||||
|
||||
|
|
@ -2918,7 +3184,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
|
|||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
|
||||
|
||||
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_rope_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_glu_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
|
||||
|
|
@ -2961,17 +3226,16 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
|
|||
|
||||
static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
|
||||
/* .iface = */ {
|
||||
/* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
|
||||
/* .alloc_buffer = */
|
||||
ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */
|
||||
ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */
|
||||
ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */
|
||||
ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false
|
||||
/* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
|
||||
/* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
|
||||
/* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment,
|
||||
/* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size,
|
||||
/* .get_alloc_size = */ ggml_backend_webgpu_buffer_type_get_alloc_size,
|
||||
/* .is_host = */ NULL, // defaults to false
|
||||
},
|
||||
/* .device = */
|
||||
dev,
|
||||
/* .context = */
|
||||
NULL
|
||||
/* .context = */ NULL
|
||||
};
|
||||
|
||||
return &ggml_backend_webgpu_buffer_type;
|
||||
|
|
@ -3053,6 +3317,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) ||
|
||||
(op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32);
|
||||
break;
|
||||
case GGML_OP_SET:
|
||||
supports_op = src0->type == src1->type && src0->type == op->type &&
|
||||
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32);
|
||||
break;
|
||||
case GGML_OP_SET_ROWS:
|
||||
supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 &&
|
||||
(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
|
||||
|
|
@ -3130,6 +3398,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
break;
|
||||
}
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_OP_ROPE:
|
||||
|
|
@ -3192,6 +3461,27 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
}
|
||||
}
|
||||
break;
|
||||
case GGML_OP_TRI:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_OP_DIAG:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_OP_SSM_CONV:
|
||||
supports_op = op->type == GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_OP_GATED_DELTA_NET:
|
||||
{
|
||||
const uint32_t s_v = (uint32_t) src2->ne[0];
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 &&
|
||||
src2->type == GGML_TYPE_F32 && op->src[3]->type == GGML_TYPE_F32 &&
|
||||
op->src[4]->type == GGML_TYPE_F32 && op->src[5]->type == GGML_TYPE_F32 &&
|
||||
s_v <= ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_CLAMP:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,132 @@
|
|||
@group(0) @binding(0)
|
||||
var<storage, read_write> src_q: array<f32>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> src_k: array<f32>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<storage, read_write> src_v: array<f32>;
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<storage, read_write> src_g: array<f32>;
|
||||
|
||||
@group(0) @binding(4)
|
||||
var<storage, read_write> src_beta: array<f32>;
|
||||
|
||||
@group(0) @binding(5)
|
||||
var<storage, read_write> src_state: array<f32>;
|
||||
|
||||
@group(0) @binding(6)
|
||||
var<storage, read_write> dst: array<f32>;
|
||||
|
||||
struct Params {
|
||||
h: u32,
|
||||
n_tokens: u32,
|
||||
n_seqs: u32,
|
||||
s_off: u32,
|
||||
|
||||
sq1: u32,
|
||||
sq2: u32,
|
||||
sq3: u32,
|
||||
|
||||
sv1: u32,
|
||||
sv2: u32,
|
||||
sv3: u32,
|
||||
|
||||
sb1: u32,
|
||||
sb2: u32,
|
||||
sb3: u32,
|
||||
|
||||
neq1: u32,
|
||||
rq3: u32,
|
||||
scale: f32,
|
||||
};
|
||||
|
||||
@group(0) @binding(7)
|
||||
var<uniform> params: Params;
|
||||
|
||||
var<workgroup> sh_k: array<f32, S_V>;
|
||||
var<workgroup> sh_q: array<f32, S_V>;
|
||||
var<workgroup> sh_g: array<f32, S_V>;
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(
|
||||
@builtin(workgroup_id) workgroup_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>
|
||||
) {
|
||||
let head_id = workgroup_id.x;
|
||||
let seq_id = workgroup_id.y;
|
||||
let col = local_id.x;
|
||||
|
||||
let iq1 = head_id % params.neq1;
|
||||
let iq3 = seq_id / params.rq3;
|
||||
|
||||
let state_size = S_V * S_V;
|
||||
let state_base = (seq_id * params.h + head_id) * state_size;
|
||||
|
||||
var state: array<f32, S_V>;
|
||||
for (var i = 0u; i < S_V; i++) {
|
||||
state[i] = src_state[state_base + col * S_V + i];
|
||||
}
|
||||
|
||||
var attn_off = (seq_id * params.n_tokens * params.h + head_id) * S_V;
|
||||
|
||||
for (var t = 0u; t < params.n_tokens; t++) {
|
||||
let q_off = iq3 * params.sq3 + t * params.sq2 + iq1 * params.sq1;
|
||||
let k_off = q_off;
|
||||
let v_off = seq_id * params.sv3 + t * params.sv2 + head_id * params.sv1;
|
||||
let gb_off = seq_id * params.sb3 + t * params.sb2 + head_id * params.sb1;
|
||||
|
||||
sh_q[col] = src_q[q_off + col];
|
||||
sh_k[col] = src_k[k_off + col];
|
||||
|
||||
#ifdef KDA
|
||||
let g_base = gb_off * S_V;
|
||||
sh_g[col] = exp(src_g[g_base + col]);
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
let v_val = src_v[v_off + col];
|
||||
let beta_val = src_beta[gb_off];
|
||||
|
||||
var kv_col = 0.0;
|
||||
var delta_col = 0.0;
|
||||
var attn_col = 0.0;
|
||||
|
||||
#ifdef KDA
|
||||
for (var i = 0u; i < S_V; i++) {
|
||||
kv_col += (sh_g[i] * state[i]) * sh_k[i];
|
||||
}
|
||||
|
||||
delta_col = (v_val - kv_col) * beta_val;
|
||||
|
||||
for (var i = 0u; i < S_V; i++) {
|
||||
state[i] = sh_g[i] * state[i] + sh_k[i] * delta_col;
|
||||
attn_col += state[i] * sh_q[i];
|
||||
}
|
||||
#else
|
||||
let g_val = exp(src_g[gb_off]);
|
||||
|
||||
for (var i = 0u; i < S_V; i++) {
|
||||
kv_col += state[i] * sh_k[i];
|
||||
}
|
||||
|
||||
delta_col = (v_val - g_val * kv_col) * beta_val;
|
||||
|
||||
for (var i = 0u; i < S_V; i++) {
|
||||
state[i] = g_val * state[i] + sh_k[i] * delta_col;
|
||||
attn_col += state[i] * sh_q[i];
|
||||
}
|
||||
#endif
|
||||
|
||||
dst[attn_off + col] = attn_col * params.scale;
|
||||
attn_off += S_V * params.h;
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
for (var i = 0u; i < S_V; i++) {
|
||||
dst[params.s_off + state_base + col * S_V + i] = state[i];
|
||||
}
|
||||
}
|
||||
|
|
@ -640,6 +640,35 @@ var<uniform> params: Params;
|
|||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
#ifdef FLOAT_PARALLEL
|
||||
let blocks_per_row = params.ne0 / BLOCK_SIZE;
|
||||
let row_count = params.n_rows * params.ne2 * params.ne3;
|
||||
|
||||
if (gid.x >= blocks_per_row * row_count) {
|
||||
return;
|
||||
}
|
||||
|
||||
let block_idx = gid.x % blocks_per_row;
|
||||
var row_idx = gid.x / blocks_per_row;
|
||||
let i_dst3 = row_idx / (params.ne2 * params.n_rows);
|
||||
|
||||
row_idx = row_idx % (params.ne2 * params.n_rows);
|
||||
let i_dst2 = row_idx / params.n_rows;
|
||||
let i_dst1 = row_idx % params.n_rows;
|
||||
|
||||
let i_idx2 = i_dst3 % params.idx2;
|
||||
let i_idx1 = i_dst2 % params.idx1;
|
||||
let i_idx0 = i_dst1;
|
||||
|
||||
let i_idx = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;
|
||||
|
||||
let idx_val = u32(idx[i_idx]);
|
||||
|
||||
let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3;
|
||||
let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3;
|
||||
|
||||
copy_elements(i_src_row, i_dst_row, block_idx);
|
||||
#else
|
||||
if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -664,5 +693,5 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|||
for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) {
|
||||
copy_elements(i_src_row, i_dst_row, i);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,21 +1,11 @@
|
|||
#define(VARIANTS)
|
||||
|
||||
[
|
||||
{
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "inplace",
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
]
|
||||
|
||||
#end(VARIANTS)
|
||||
|
||||
#define(DECLS)
|
||||
|
||||
#decl(NOT_INPLACE)
|
||||
#ifdef INPLACE
|
||||
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
|
||||
src[dst_offset] = scale * src[src_offset];
|
||||
}
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<uniform> params: Params;
|
||||
#else
|
||||
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
|
||||
dst[dst_offset] = scale * src[src_offset];
|
||||
}
|
||||
|
|
@ -25,23 +15,7 @@ var<storage, read_write> dst: array<f32>;
|
|||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
|
||||
#enddecl(NOT_INPLACE)
|
||||
|
||||
#decl(INPLACE)
|
||||
|
||||
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
|
||||
src[dst_offset] = scale * src[src_offset];
|
||||
}
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<uniform> params: Params;
|
||||
|
||||
#enddecl(INPLACE)
|
||||
|
||||
#end(DECLS)
|
||||
|
||||
#define(SHADER)
|
||||
#endif
|
||||
|
||||
struct Params {
|
||||
offset_src: u32, // in elements
|
||||
|
|
@ -68,12 +42,9 @@ struct Params {
|
|||
@group(0) @binding(0)
|
||||
var<storage, read_write> src: array<f32>;
|
||||
|
||||
DECLS
|
||||
var<workgroup> scratch: array<f32, WG_SIZE>;
|
||||
|
||||
override wg_size: u32;
|
||||
var<workgroup> scratch: array<f32, wg_size>;
|
||||
|
||||
@compute @workgroup_size(wg_size)
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(workgroup_id) wid: vec3<u32>,
|
||||
@builtin(local_invocation_id) lid: vec3<u32>) {
|
||||
|
||||
|
|
@ -86,7 +57,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
|
|||
let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
|
||||
let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
|
||||
|
||||
let elems = (params.ne0 + wg_size - 1) / wg_size;
|
||||
let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
|
||||
|
||||
var sum = 0.0f;
|
||||
var col = lid.x;
|
||||
|
|
@ -95,12 +66,12 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
|
|||
break;
|
||||
}
|
||||
sum += pow(src[i_src_row + col], 2.0);
|
||||
col += wg_size;
|
||||
col += WG_SIZE;
|
||||
}
|
||||
|
||||
scratch[lid.x] = sum;
|
||||
workgroupBarrier();
|
||||
var offset = wg_size / 2;
|
||||
var offset: u32 = WG_SIZE / 2;
|
||||
while (offset > 0) {
|
||||
if (lid.x < offset) {
|
||||
scratch[lid.x] += scratch[lid.x + offset];
|
||||
|
|
@ -110,14 +81,18 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
|
|||
}
|
||||
sum = scratch[0];
|
||||
|
||||
#ifdef RMS_NORM
|
||||
let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
|
||||
#elif defined(L2_NORM)
|
||||
let scale = 1.0/max(sqrt(sum), params.eps);
|
||||
#endif
|
||||
|
||||
col = lid.x;
|
||||
for (var j: u32 = 0; j < elems; j++) {
|
||||
if (col >= params.ne0) {
|
||||
break;
|
||||
}
|
||||
update(i_src_row + col, i_dst_row + col, scale);
|
||||
col += wg_size;
|
||||
col += WG_SIZE;
|
||||
}
|
||||
}
|
||||
#end(SHADER)
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
#ifdef TYPE_I32
|
||||
#define TYPE i32
|
||||
#else
|
||||
#define TYPE f32
|
||||
#endif
|
||||
|
||||
#ifndef INPLACE
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src0: array<TYPE>;
|
||||
#define SRC1_BINDING 1
|
||||
#else
|
||||
#define SRC1_BINDING 0
|
||||
#endif
|
||||
|
||||
#define DST_BINDING SRC1_BINDING + 1
|
||||
#define PARAMS_BINDING SRC1_BINDING + 2
|
||||
|
||||
@group(0) @binding(SRC1_BINDING)
|
||||
var<storage, read_write> src1: array<TYPE>;
|
||||
|
||||
@group(0) @binding(DST_BINDING)
|
||||
var<storage, read_write> dst: array<TYPE>;
|
||||
|
||||
struct Params {
|
||||
ne: u32,
|
||||
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
offset_view: u32,
|
||||
|
||||
stride_src10: u32,
|
||||
stride_src11: u32,
|
||||
stride_src12: u32,
|
||||
stride_src13: u32,
|
||||
|
||||
stride_dst10: u32,
|
||||
stride_dst11: u32,
|
||||
stride_dst12: u32,
|
||||
stride_dst13: u32,
|
||||
|
||||
src1_ne0: u32,
|
||||
src1_ne1: u32,
|
||||
src1_ne2: u32,
|
||||
src1_ne3: u32,
|
||||
};
|
||||
|
||||
@group(0) @binding(PARAMS_BINDING)
|
||||
var<uniform> params: Params;
|
||||
|
||||
fn decode_src1_coords(idx: u32) -> vec4<u32> {
|
||||
var i = idx;
|
||||
let plane = params.src1_ne2 * params.src1_ne1 * params.src1_ne0;
|
||||
let i3 = i / plane;
|
||||
i = i % plane;
|
||||
let row = params.src1_ne1 * params.src1_ne0;
|
||||
let i2 = i / row;
|
||||
i = i % row;
|
||||
let i1 = i / params.src1_ne0;
|
||||
let i0 = i % params.src1_ne0;
|
||||
return vec4<u32>(i0, i1, i2, i3);
|
||||
}
|
||||
|
||||
fn decode_view_coords(rel: u32) -> vec4<u32> {
|
||||
let i3 = rel / params.stride_dst13;
|
||||
let rem3 = rel % params.stride_dst13;
|
||||
let i2 = rem3 / params.stride_dst12;
|
||||
let rem2 = rem3 % params.stride_dst12;
|
||||
let i1 = rem2 / params.stride_dst11;
|
||||
let i0 = rem2 % params.stride_dst11;
|
||||
return vec4<u32>(i0, i1, i2, i3);
|
||||
}
|
||||
|
||||
fn view_rel_from_coords(coords: vec4<u32>) -> u32 {
|
||||
return coords.x * params.stride_dst10 + coords.y * params.stride_dst11 +
|
||||
coords.z * params.stride_dst12 + coords.w * params.stride_dst13;
|
||||
}
|
||||
|
||||
fn src1_idx_from_coords(coords: vec4<u32>) -> u32 {
|
||||
return coords.x * params.stride_src10 + coords.y * params.stride_src11 +
|
||||
coords.z * params.stride_src12 + coords.w * params.stride_src13;
|
||||
}
|
||||
|
||||
fn in_set_view(rel: u32, coords: vec4<u32>) -> bool {
|
||||
return view_rel_from_coords(coords) == rel;
|
||||
}
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x >= params.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef INPLACE
|
||||
let coords = decode_src1_coords(gid.x);
|
||||
|
||||
let src1_idx = params.offset_src1 + src1_idx_from_coords(coords);
|
||||
let dst_idx = params.offset_view + view_rel_from_coords(coords);
|
||||
|
||||
dst[dst_idx] = src1[src1_idx];
|
||||
#else
|
||||
let rel = select(params.ne, gid.x - params.offset_view, gid.x >= params.offset_view);
|
||||
let coords = decode_view_coords(rel);
|
||||
|
||||
if (rel < params.stride_dst13 * params.src1_ne3 && in_set_view(rel, coords)) {
|
||||
dst[gid.x] = src1[params.offset_src1 + src1_idx_from_coords(coords)];
|
||||
} else {
|
||||
dst[gid.x] = src0[params.offset_src0 + gid.x];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
@group(0) @binding(0)
|
||||
var<storage, read_write> src0: array<f32>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> src1: array<f32>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<storage, read_write> dst: array<f32>;
|
||||
|
||||
struct Params {
|
||||
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
offset_dst: u32,
|
||||
|
||||
stride_src00: u32,
|
||||
stride_src01: u32,
|
||||
stride_src02: u32,
|
||||
stride_src03: u32,
|
||||
|
||||
stride_src10: u32,
|
||||
stride_src11: u32,
|
||||
stride_src12: u32,
|
||||
stride_src13: u32,
|
||||
|
||||
stride_dst0: u32,
|
||||
stride_dst1: u32,
|
||||
stride_dst2: u32,
|
||||
stride_dst3: u32,
|
||||
|
||||
k: u32,
|
||||
ne2: u32,
|
||||
ne3: u32,
|
||||
};
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<uniform> params: Params;
|
||||
|
||||
var<workgroup> shA: array<f32, BATCH_N * N>;
|
||||
var<workgroup> shB: array<f32, BATCH_N * K_TILE>;
|
||||
|
||||
fn src0_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
|
||||
return params.offset_src0 +
|
||||
col * params.stride_src00 +
|
||||
row * params.stride_src01 +
|
||||
i2 * params.stride_src02 +
|
||||
i3 * params.stride_src03;
|
||||
}
|
||||
|
||||
fn src1_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
|
||||
return params.offset_src1 +
|
||||
col * params.stride_src10 +
|
||||
row * params.stride_src11 +
|
||||
i2 * params.stride_src12 +
|
||||
i3 * params.stride_src13;
|
||||
}
|
||||
|
||||
fn dst_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
|
||||
return params.offset_dst +
|
||||
col * params.stride_dst0 +
|
||||
row * params.stride_dst1 +
|
||||
i2 * params.stride_dst2 +
|
||||
i3 * params.stride_dst3;
|
||||
}
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(
|
||||
@builtin(workgroup_id) workgroup_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>
|
||||
) {
|
||||
let batch = workgroup_id.y;
|
||||
let col = workgroup_id.x * WG_SIZE + local_id.x;
|
||||
let i3 = batch / params.ne2;
|
||||
let i2 = batch % params.ne2;
|
||||
let active_lane = local_id.x < K_TILE;
|
||||
let active_col = active_lane && col < params.k;
|
||||
|
||||
var X: array<f32, N>;
|
||||
|
||||
for (var row_base = 0u; row_base < N; row_base += BATCH_N) {
|
||||
let cur_n = min(BATCH_N, N - row_base);
|
||||
|
||||
for (var i = local_id.x; i < cur_n * N; i += WG_SIZE) {
|
||||
let tile_row = i / N;
|
||||
let tile_col = i % N;
|
||||
shA[i] = src0[src0_idx(row_base + tile_row, tile_col, i2, i3)];
|
||||
}
|
||||
|
||||
for (var i = local_id.x; i < cur_n * K_TILE; i += WG_SIZE) {
|
||||
let tile_row = i / K_TILE;
|
||||
let tile_col = i % K_TILE;
|
||||
let global_col = workgroup_id.x * WG_SIZE + tile_col;
|
||||
let sh_idx = tile_row * K_TILE + tile_col;
|
||||
|
||||
if (global_col < params.k) {
|
||||
shB[sh_idx] = src1[src1_idx(row_base + tile_row, global_col, i2, i3)];
|
||||
} else {
|
||||
shB[sh_idx] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
if (active_col) {
|
||||
for (var row_offset = 0u; row_offset < cur_n; row_offset++) {
|
||||
let r = row_base + row_offset;
|
||||
var b = shB[row_offset * K_TILE + local_id.x];
|
||||
let a_row = row_offset * N;
|
||||
|
||||
for (var t = 0u; t < r; t++) {
|
||||
b -= shA[a_row + t] * X[t];
|
||||
}
|
||||
|
||||
let x = b / shA[a_row + r];
|
||||
X[r] = x;
|
||||
dst[dst_idx(r, col, i2, i3)] = x;
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
@group(0) @binding(0)
|
||||
var<storage, read_write> src0: array<f32>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> src1: array<f32>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<storage, read_write> dst: array<f32>;
|
||||
|
||||
struct Params {
|
||||
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
offset_dst: u32,
|
||||
|
||||
stride_src01: u32,
|
||||
stride_src02: u32,
|
||||
stride_src11: u32,
|
||||
|
||||
stride_dst0: u32,
|
||||
stride_dst1: u32,
|
||||
stride_dst2: u32,
|
||||
|
||||
nc: u32,
|
||||
nr: u32,
|
||||
n_t: u32,
|
||||
n_s: u32,
|
||||
token_tiles: u32,
|
||||
};
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<uniform> params: Params;
|
||||
|
||||
@compute @workgroup_size(BLOCK_SIZE, TOKENS_PER_WG)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
let i1 = gid.x;
|
||||
let tile_y = gid.y / TOKENS_PER_WG;
|
||||
let local_token = gid.y % TOKENS_PER_WG;
|
||||
let i3 = tile_y / params.token_tiles;
|
||||
let token_tile = tile_y % params.token_tiles;
|
||||
let i2 = token_tile * TOKENS_PER_WG + local_token;
|
||||
|
||||
if (i1 >= params.nr || i2 >= params.n_t || i3 >= params.n_s) {
|
||||
return;
|
||||
}
|
||||
|
||||
let src0_base = params.offset_src0 + i3 * params.stride_src02 + i2 + i1 * params.stride_src01;
|
||||
let src1_base = params.offset_src1 + i1 * params.stride_src11;
|
||||
|
||||
var sum = 0.0;
|
||||
|
||||
#ifdef VECTORIZED
|
||||
sum =
|
||||
src0[src0_base + 0u] * src1[src1_base + 0u] +
|
||||
src0[src0_base + 1u] * src1[src1_base + 1u] +
|
||||
src0[src0_base + 2u] * src1[src1_base + 2u] +
|
||||
src0[src0_base + 3u] * src1[src1_base + 3u];
|
||||
#else
|
||||
for (var i0 = 0u; i0 < params.nc; i0++) {
|
||||
sum += src0[src0_base + i0] * src1[src1_base + i0];
|
||||
}
|
||||
#endif
|
||||
|
||||
let dst_idx = params.offset_dst + i3 * params.stride_dst2 + i2 * params.stride_dst1 + i1 * params.stride_dst0;
|
||||
dst[dst_idx] = sum;
|
||||
}
|
||||
|
|
@ -5,7 +5,6 @@ enable f16;
|
|||
#define TYPE f32
|
||||
#endif
|
||||
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src: array<TYPE>;
|
||||
|
||||
|
|
@ -57,12 +56,20 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|||
return;
|
||||
}
|
||||
var i = gid.x;
|
||||
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
|
||||
i = i % (params.ne2 * params.ne1 * params.ne0);
|
||||
let i2 = i / (params.ne1 * params.ne0);
|
||||
i = i % (params.ne1 * params.ne0);
|
||||
let i1 = i / params.ne0;
|
||||
let i0 = i % params.ne0;
|
||||
let ne2 = params.ne2;
|
||||
#ifdef DIAG
|
||||
let ne1 = params.ne0;
|
||||
#else
|
||||
let ne1 = params.ne1;
|
||||
#endif
|
||||
let ne0 = params.ne0;
|
||||
|
||||
let i3 = i / (ne2 * ne1 * ne0);
|
||||
i = i % (ne2 * ne1 * ne0);
|
||||
let i2 = i / (ne1 * ne0);
|
||||
i = i % (ne1 * ne0);
|
||||
let i1 = i / ne0;
|
||||
let i0 = i % ne0;
|
||||
|
||||
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
|
||||
i2 * params.stride_src2 + i3 * params.stride_src3;
|
||||
|
|
@ -184,6 +191,20 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|||
let res_f32 = cos(f32(src[params.offset_src + src_idx]));
|
||||
let res = TYPE(res_f32);
|
||||
#endif
|
||||
#ifdef DIAG
|
||||
let res = select(0.0, src[params.offset_src + i0 + i2 * params.stride_src2 + i3 * params.stride_src3], i0 == i1);
|
||||
#endif
|
||||
#ifdef TRI
|
||||
#ifdef TRI_TYPE_LOWER
|
||||
let res = select(0.0, src[params.offset_src + src_idx], i0 < i1);
|
||||
#elif TRI_TYPE_LOWER_DIAG
|
||||
let res = select(0.0, src[params.offset_src + src_idx], i0 <= i1);
|
||||
#elif TRI_TYPE_UPPER
|
||||
let res = select(0.0, src[params.offset_src + src_idx], i0 > i1);
|
||||
#elif TRI_TYPE_UPPER_DIAG
|
||||
let res = select(0.0, src[params.offset_src + src_idx], i0 >= i1);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef INPLACE
|
||||
src[params.offset_src + src_idx] = res;
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue