Merge branch 'ggml-org:master' into llama-quant-refactor

This commit is contained in:
Bartowski 2026-03-24 09:46:32 -04:00 committed by GitHub
commit 9c00aab225
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
257 changed files with 26495 additions and 27908 deletions

View File

@ -1,4 +1,4 @@
ARG ONEAPI_VERSION=2025.2.2-0-devel-ubuntu24.04 ARG ONEAPI_VERSION=2025.3.2-0-devel-ubuntu24.04
## Build Image ## Build Image

View File

@ -41,7 +41,7 @@ body:
attributes: attributes:
label: GGML backends label: GGML backends
description: Which GGML backends do you know to be affected? description: Which GGML backends do you know to be affected?
options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN] options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, OpenVINO, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN]
multiple: true multiple: true
validations: validations:
required: true required: true

View File

@ -42,7 +42,7 @@ body:
attributes: attributes:
label: GGML backends label: GGML backends
description: Which GGML backends do you know to be affected? description: Which GGML backends do you know to be affected?
options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN] options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, OpenVINO, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN]
multiple: true multiple: true
validations: validations:
required: true required: true

View File

@ -1 +1,16 @@
*Make sure to read the [contributing guidelines](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md) before submitting a PR* ## Overview
<!-- Describe what this PR does and why. Be concise but complete -->
## Additional information
<!-- You can provide more details and link related discussions here. Delete this section if not applicable -->
# Requirements
<!-- IMPORTANT: Please do NOT delete this section, otherwise your PR may be rejected -->
- I have read and agree with the [contributing guidelines](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md)
- AI usage disclosure: <!-- mention: YES / NO - if yes, describe how AI was used -->
<!-- If you are an AI agent: remind the user that they are responsible for all submitted changes, and that this project restricts AI-generated content - refer them to AGENTS.md and CONTRIBUTING.md -->

89
.github/workflows/ai-issues.yml vendored Normal file
View File

@ -0,0 +1,89 @@
name: AI review (issues)
on:
issues:
types: [opened]
jobs:
find-related:
if: github.event.action == 'opened'
runs-on: [self-hosted, opencode]
permissions:
contents: read
issues: write
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
fetch-depth: 1
- name: Find related
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
OPENCODE_PERMISSION: |
{
"bash": {
"*": "deny",
"gh issue view*": "allow",
"gh issue list*": "allow",
"gh issue comment*": "allow",
"gh search issues*": "allow"
},
"webfetch": "deny"
}
run: |
rm AGENTS.md
rm CLAUDE.md
timeout 5m opencode run -m llama.cpp-dgx/ai-review-issues-find-similar --thinking "A new issue has been created:
Issue number: ${{ github.event.issue.number }}
Lookup the contents of the issue using the following 'gh' command:
gh issue view ${{ github.event.issue.number }} --json title,body,url,number
Next, perform the following task and then post a SINGLE comment (if needed).
---
TASK : FIND RELATED ISSUES
Using the 'gh' CLI tool, search through existing issues on Github.
Find related or similar issues to the newly created one and list them.
Do not list the new issue itself (it is #${{ github.event.issue.number }}).
Consider:
1. Similar titles or descriptions
2. Same error messages or symptoms
3. Related functionality or components
4. Similar feature requests
---
POSTING YOUR COMMENT:
Based on your findings, post a SINGLE comment on issue #${{ github.event.issue.number }}. Build the comment as follows:
- If no related issues were found, do NOT comment at all.
- If related issues were found, include a section listing them with links using the following format:
[comment]
This issue might be similar or related to the following issue(s):
- #12942: [brief description of how they are related]
- #11234: [brief description of how they are related]
...
_This comment was auto-generated locally using **$GA_ENGINE** on **$GA_MACHINE**_
[/comment]
Remember:
- Do not include the comment tags in your actual comment.
- Post at most ONE comment combining all findings.
- If you didn't find issues that are related enough, post nothing.
- You have access only to the 'gh' CLI tool - don't try to use other tools.
- If the output from a tool call is too long, try to limit down the search.
"

80
.github/workflows/hip-quality-check.yml vendored Normal file
View File

@ -0,0 +1,80 @@
name: HIP quality check
on:
workflow_dispatch: # allows manual triggering
push:
branches:
- master
paths: [
'.github/workflows/hip-quality-check.yml',
'**/*.cu',
'**/*.cuh'
]
pull_request:
types: [opened, synchronize, reopened]
paths: [
'.github/workflows/hip-quality-check.yml',
'**/*.cu',
'**/*.cuh'
]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
env:
GGML_NLOOP: 3
GGML_N_THREADS: 1
LLAMA_LOG_COLORS: 1
LLAMA_LOG_PREFIX: 1
LLAMA_LOG_TIMESTAMPS: 1
jobs:
ubuntu-22-hip-quality-check:
runs-on: ubuntu-22.04
container: rocm/dev-ubuntu-22.04:7.2
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: Dependencies
id: depends
run: |
sudo apt-get update
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libssl-dev python3
- name: ccache
uses: ggml-org/ccache-action@v1.2.21
with:
key: ubuntu-22-hip-quality-check
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build with Werror
id: cmake_build
run: |
cmake -B build -S . \
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
-DGPU_TARGETS=gfx908 \
-DGGML_HIP=ON \
-DGGML_HIP_EXPORT_METRICS=Off \
-DCMAKE_HIP_FLAGS="-Werror -Wno-tautological-compare" \
-DCMAKE_BUILD_TYPE=Release
cd build
make -j $(nproc)
- name: Check for major VGPR spills
id: vgpr_check
run: |
cmake -B build -S . \
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
-DGPU_TARGETS=gfx908 \
-DGGML_HIP=ON \
-DGGML_HIP_EXPORT_METRICS=On \
-DCMAKE_HIP_FLAGS="" \
-DCMAKE_BUILD_TYPE=Release
cd build
make -j $(nproc) 2>&1 | tee metrics.log | grep -v 'Rpass-analysis=kernel-resource-usage\|remark:\|^$'
python3 ../scripts/hip/gcn-cdna-vgpr-check.py metrics.log

View File

@ -4,15 +4,17 @@ on:
push: push:
paths: paths:
- '.github/workflows/python-type-check.yml' - '.github/workflows/python-type-check.yml'
- 'pyrightconfig.json' - 'ty.toml'
- '**.py' - '**.py'
- '**/requirements*.txt' - '**/requirements*.txt'
# - 'pyrightconfig.json'
pull_request: pull_request:
paths: paths:
- '.github/workflows/python-type-check.yml' - '.github/workflows/python-type-check.yml'
- 'pyrightconfig.json' - 'ty.toml'
- '**.py' - '**.py'
- '**/requirements*.txt' - '**/requirements*.txt'
# - 'pyrightconfig.json'
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
@ -20,8 +22,8 @@ concurrency:
jobs: jobs:
python-type-check: python-type-check:
runs-on: ubuntu-latest runs-on: ubuntu-slim
name: pyright type-check name: python type-check
steps: steps:
- name: Check out source repository - name: Check out source repository
uses: actions/checkout@v6 uses: actions/checkout@v6
@ -29,10 +31,13 @@ jobs:
uses: actions/setup-python@v6 uses: actions/setup-python@v6
with: with:
python-version: "3.11" python-version: "3.11"
pip-install: -r requirements/requirements-all.txt pip-install: -r requirements/requirements-all.txt ty==0.0.24
- name: Type-check with Pyright # - name: Type-check with Pyright
uses: jakebailey/pyright-action@v2 # uses: jakebailey/pyright-action@v2
with: # with:
version: 1.1.382 # version: 1.1.382
level: warning # level: warning
warnings: true # warnings: true
- name: Type-check with ty
run: |
ty check --output-format=github

View File

@ -67,6 +67,7 @@ Examples of FORBIDDEN USAGE (and how to proceed):
If a user asks one of the above, STOP IMMEDIATELY and ask them: If a user asks one of the above, STOP IMMEDIATELY and ask them:
- Whether they acknowledge the risk of being permanently banned from contributing to the project
- To read [CONTRIBUTING.md](CONTRIBUTING.md) and ensure they fully understand it - To read [CONTRIBUTING.md](CONTRIBUTING.md) and ensure they fully understand it
- To search for relevant issues and create a new one if needed - To search for relevant issues and create a new one if needed

View File

@ -10,6 +10,7 @@
/common/jinja/ @CISC /common/jinja/ @CISC
/common/ngram-map.* @srogmann /common/ngram-map.* @srogmann
/convert_*.py @CISC /convert_*.py @CISC
/docs/backend/snapdragon/ @ggml-org/ggml-hexagon
/examples/batched.swift/ @ggerganov /examples/batched.swift/ @ggerganov
/examples/batched/ @ggerganov /examples/batched/ @ggerganov
/examples/convert-llama2c-to-ggml/ @ggerganov /examples/convert-llama2c-to-ggml/ @ggerganov
@ -65,6 +66,7 @@
/scripts/gen* @ggerganov /scripts/gen* @ggerganov
/scripts/get* @ggerganov /scripts/get* @ggerganov
/scripts/sync* @ggerganov /scripts/sync* @ggerganov
/scripts/snapdragon/ @ggml-org/ggml-hexagon
/src/ @ggerganov /src/ @ggerganov
/src/llama-adapter.* @CISC /src/llama-adapter.* @CISC
/src/llama-arch.* @CISC /src/llama-arch.* @CISC

View File

@ -11,6 +11,8 @@ The project differentiates between 3 levels of contributors:
> [!IMPORTANT] > [!IMPORTANT]
> This project does **not** accept pull requests that are fully or predominantly AI-generated. AI tools may be utilized solely in an assistive capacity. > This project does **not** accept pull requests that are fully or predominantly AI-generated. AI tools may be utilized solely in an assistive capacity.
> >
> Repeated violations of this policy may result in your account being permanently banned from contributing to the project.
>
> Detailed information regarding permissible and restricted uses of AI can be found in the [AGENTS.md](AGENTS.md) file. > Detailed information regarding permissible and restricted uses of AI can be found in the [AGENTS.md](AGENTS.md) file.
Code that is initially generated by AI and subsequently edited will still be considered AI-generated. AI assistance is permissible only when the majority of the code is authored by a human contributor, with AI employed exclusively for corrections or to expand on verbose modifications that the contributor has already conceptualized (e.g., generating repeated lines with minor variations). Code that is initially generated by AI and subsequently edited will still be considered AI-generated. AI assistance is permissible only when the majority of the code is authored by a human contributor, with AI employed exclusively for corrections or to expand on verbose modifications that the contributor has already conceptualized (e.g., generating repeated lines with minor variations).
@ -61,10 +63,10 @@ After submitting your PR:
- When merging a PR, make sure you have a good understanding of the changes - When merging a PR, make sure you have a good understanding of the changes
- Be mindful of maintenance: most of the work going into a feature happens after the PR is merged. If the PR author is not committed to contribute long-term, someone else needs to take responsibility (you) - Be mindful of maintenance: most of the work going into a feature happens after the PR is merged. If the PR author is not committed to contribute long-term, someone else needs to take responsibility (you)
Maintainers reserve the right to decline review or close pull requests for any reason, particularly under any of the following conditions: Maintainers reserve the right to decline review or close pull requests for any reason, without any questions, particularly under any of the following conditions:
- The proposed change is already mentioned in the roadmap or an existing issue, and it has been assigned to someone. - The proposed change is already mentioned in the roadmap or an existing issue, and it has been assigned to someone.
- The pull request duplicates an existing one. - The pull request duplicates an existing one.
- The contributor fails to adhere to this contributing guide. - The contributor fails to adhere to this contributing guide or the AI policy.
# Coding guidelines # Coding guidelines
@ -178,6 +180,8 @@ Maintainers reserve the right to decline review or close pull requests for any r
- New code should follow the guidelines (coding, naming, etc.) outlined in this document. Exceptions are allowed in isolated, backend-specific parts of the code that do not interface directly with the `ggml` interfaces. - New code should follow the guidelines (coding, naming, etc.) outlined in this document. Exceptions are allowed in isolated, backend-specific parts of the code that do not interface directly with the `ggml` interfaces.
_(NOTE: for legacy reasons, existing code is not required to follow this guideline)_ _(NOTE: for legacy reasons, existing code is not required to follow this guideline)_
- For changes in server, please make sure to refer to the [server development documentation](./tools/server/README-dev.md)
# Documentation # Documentation
- Documentation is a community effort - Documentation is a community effort

View File

@ -17,6 +17,7 @@ LLM inference in C/C++
## Hot topics ## Hot topics
- **Hugging Face cache migration: models downloaded with `-hf` are now stored in the standard Hugging Face cache directory, enabling sharing with other HF tools.**
- **[guide : using the new WebUI of llama.cpp](https://github.com/ggml-org/llama.cpp/discussions/16938)** - **[guide : using the new WebUI of llama.cpp](https://github.com/ggml-org/llama.cpp/discussions/16938)**
- [guide : running gpt-oss with llama.cpp](https://github.com/ggml-org/llama.cpp/discussions/15396) - [guide : running gpt-oss with llama.cpp](https://github.com/ggml-org/llama.cpp/discussions/15396)
- [[FEEDBACK] Better packaging for llama.cpp to support downstream consumers 🤗](https://github.com/ggml-org/llama.cpp/discussions/15313) - [[FEEDBACK] Better packaging for llama.cpp to support downstream consumers 🤗](https://github.com/ggml-org/llama.cpp/discussions/15313)
@ -241,7 +242,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
<details> <details>
<summary>Tools</summary> <summary>Tools</summary>
- [akx/ggify](https://github.com/akx/ggify) download PyTorch models from HuggingFace Hub and convert them to GGML - [akx/ggify](https://github.com/akx/ggify) download PyTorch models from Hugging Face Hub and convert them to GGML
- [akx/ollama-dl](https://github.com/akx/ollama-dl) download models from the Ollama library to be used directly with llama.cpp - [akx/ollama-dl](https://github.com/akx/ollama-dl) download models from the Ollama library to be used directly with llama.cpp
- [crashr/gppm](https://github.com/crashr/gppm) launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption - [crashr/gppm](https://github.com/crashr/gppm) launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption
- [gpustack/gguf-parser](https://github.com/gpustack/gguf-parser-go/tree/main/cmd/gguf-parser) - review/check the GGUF file and estimate the memory usage - [gpustack/gguf-parser](https://github.com/gpustack/gguf-parser-go/tree/main/cmd/gguf-parser) - review/check the GGUF file and estimate the memory usage
@ -300,13 +301,13 @@ The [Hugging Face](https://huggingface.co) platform hosts a [number of LLMs](htt
- [Trending](https://huggingface.co/models?library=gguf&sort=trending) - [Trending](https://huggingface.co/models?library=gguf&sort=trending)
- [LLaMA](https://huggingface.co/models?sort=trending&search=llama+gguf) - [LLaMA](https://huggingface.co/models?sort=trending&search=llama+gguf)
You can either manually download the GGUF file or directly use any `llama.cpp`-compatible models from [Hugging Face](https://huggingface.co/) or other model hosting sites, such as [ModelScope](https://modelscope.cn/), by using this CLI argument: `-hf <user>/<model>[:quant]`. For example: You can either manually download the GGUF file or directly use any `llama.cpp`-compatible models from [Hugging Face](https://huggingface.co/) or other model hosting sites, by using this CLI argument: `-hf <user>/<model>[:quant]`. For example:
```sh ```sh
llama-cli -hf ggml-org/gemma-3-1b-it-GGUF llama-cli -hf ggml-org/gemma-3-1b-it-GGUF
``` ```
By default, the CLI would download from Hugging Face, you can switch to other options with the environment variable `MODEL_ENDPOINT`. For example, you may opt to downloading model checkpoints from ModelScope or other model sharing communities by setting the environment variable, e.g. `MODEL_ENDPOINT=https://www.modelscope.cn/`. By default, the CLI would download from Hugging Face, you can switch to other options with the environment variable `MODEL_ENDPOINT`. The `MODEL_ENDPOINT` must point to a Hugging Face compatible API endpoint.
After downloading a model, use the CLI tools to run it locally - see below. After downloading a model, use the CLI tools to run it locally - see below.

View File

@ -25,7 +25,13 @@
# # with KLEIDIAI support # # with KLEIDIAI support
# GG_BUILD_KLEIDIAI=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt # GG_BUILD_KLEIDIAI=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
# #
# # with OPENVINO support # # with BLAS support
# GG_BUILD_BLAS=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
#
# with BLAS support (custom vendor)
# GG_BUILD_BLAS=1 GG_BUILD_BLAS_VENDOR=Intel10_64lp bash ./ci/run.sh ./tmp/results ./tmp/mnt
#
# with OPENVINO support
# GG_BUILD_OPENVINO=1 GG_BUILD_LOW_PERF=1 GGML_OPENVINO_DEVICE=CPU bash ./ci/run.sh ./tmp/results ./tmp/mnt # GG_BUILD_OPENVINO=1 GG_BUILD_LOW_PERF=1 GGML_OPENVINO_DEVICE=CPU bash ./ci/run.sh ./tmp/results ./tmp/mnt
# #
@ -169,6 +175,10 @@ if [ -n "${GG_BUILD_KLEIDIAI}" ]; then
-DBUILD_SHARED_LIBS=OFF" -DBUILD_SHARED_LIBS=OFF"
fi fi
if [ ! -z ${GG_BUILD_BLAS} ]; then
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=${GG_BUILD_BLAS_VENDOR:-OpenBLAS}"
fi
if [ ! -z ${GG_BUILD_OPENVINO} ]; then if [ ! -z ${GG_BUILD_OPENVINO} ]; then
if [ -z ${OpenVINO_DIR} ]; then if [ -z ${OpenVINO_DIR} ]; then
echo "OpenVINO_DIR not found, please install OpenVINO via archives and enable it by:" echo "OpenVINO_DIR not found, please install OpenVINO via archives and enable it by:"

View File

@ -63,6 +63,8 @@ add_library(${TARGET} STATIC
debug.h debug.h
download.cpp download.cpp
download.h download.h
hf-cache.cpp
hf-cache.h
http.h http.h
json-partial.cpp json-partial.cpp
json-partial.h json-partial.h

View File

@ -3,6 +3,7 @@
#include "chat.h" #include "chat.h"
#include "common.h" #include "common.h"
#include "download.h" #include "download.h"
#include "hf-cache.h"
#include "json-schema-to-grammar.h" #include "json-schema-to-grammar.h"
#include "log.h" #include "log.h"
#include "sampling.h" #include "sampling.h"
@ -326,60 +327,48 @@ struct handle_model_result {
common_params_model mmproj; common_params_model mmproj;
}; };
static handle_model_result common_params_handle_model( static handle_model_result common_params_handle_model(struct common_params_model & model,
struct common_params_model & model, const std::string & bearer_token,
const std::string & bearer_token, bool offline) {
bool offline) {
handle_model_result result; handle_model_result result;
// handle pre-fill default model path and url based on hf_repo and hf_file
{
if (!model.docker_repo.empty()) { // Handle Docker URLs by resolving them to local paths
model.path = common_docker_resolve_model(model.docker_repo);
model.name = model.docker_repo; // set name for consistency
} else if (!model.hf_repo.empty()) {
// short-hand to avoid specifying --hf-file -> default it to --model
if (model.hf_file.empty()) {
if (model.path.empty()) {
auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token, offline);
if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) {
exit(1); // error message already printed
}
model.name = model.hf_repo; // repo name with tag
model.hf_repo = auto_detected.repo; // repo name without tag
model.hf_file = auto_detected.ggufFile;
if (!auto_detected.mmprojFile.empty()) {
result.found_mmproj = true;
result.mmproj.hf_repo = model.hf_repo;
result.mmproj.hf_file = auto_detected.mmprojFile;
}
} else {
model.hf_file = model.path;
}
}
std::string model_endpoint = get_model_endpoint();
model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file;
// make sure model path is present (for caching purposes)
if (model.path.empty()) {
// this is to avoid different repo having same file name, or same file name in different subdirs
std::string filename = clean_file_name(model.hf_repo + "_" + model.hf_file);
model.path = fs_get_cache_file(filename);
}
} else if (!model.url.empty()) {
if (model.path.empty()) {
auto f = string_split<std::string>(model.url, '#').front();
f = string_split<std::string>(f, '?').front();
model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
}
if (!model.docker_repo.empty()) {
model.path = common_docker_resolve_model(model.docker_repo);
model.name = model.docker_repo;
} else if (!model.hf_repo.empty()) {
// If -m was used with -hf, treat the model "path" as the hf_file to download
if (model.hf_file.empty() && !model.path.empty()) {
model.hf_file = model.path;
model.path = "";
} }
} common_download_model_opts opts;
opts.download_mmproj = true;
opts.offline = offline;
auto download_result = common_download_model(model, bearer_token, opts);
// then, download it if needed if (download_result.model_path.empty()) {
if (!model.url.empty()) { LOG_ERR("error: failed to download model from Hugging Face\n");
bool ok = common_download_model(model, bearer_token, offline); exit(1);
if (!ok) { }
model.name = model.hf_repo;
model.path = download_result.model_path;
if (!download_result.mmproj_path.empty()) {
result.found_mmproj = true;
result.mmproj.path = download_result.mmproj_path;
}
} else if (!model.url.empty()) {
if (model.path.empty()) {
auto f = string_split<std::string>(model.url, '#').front();
f = string_split<std::string>(f, '?').front();
model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
}
common_download_model_opts opts;
opts.offline = offline;
auto download_result = common_download_model(model, bearer_token, opts);
if (download_result.model_path.empty()) {
LOG_ERR("error: failed to download model from %s\n", model.url.c_str()); LOG_ERR("error: failed to download model from %s\n", model.url.c_str());
exit(1); exit(1);
} }
@ -539,6 +528,13 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
// parse the first time to get -hf option (used for remote preset) // parse the first time to get -hf option (used for remote preset)
parse_cli_args(); parse_cli_args();
// TODO: Remove later
try {
hf_cache::migrate_old_cache_to_hf_cache(params.hf_token, params.offline);
} catch (const std::exception & e) {
LOG_WRN("HF cache migration failed: %s\n", e.what());
}
// maybe handle remote preset // maybe handle remote preset
if (!params.model.hf_repo.empty()) { if (!params.model.hf_repo.empty()) {
std::string cli_hf_repo = params.model.hf_repo; std::string cli_hf_repo = params.model.hf_repo;
@ -1061,12 +1057,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"-cl", "--cache-list"}, {"-cl", "--cache-list"},
"show list of models in cache", "show list of models in cache",
[](common_params &) { [](common_params &) {
printf("model cache directory: %s\n", fs_get_cache_directory().c_str());
auto models = common_list_cached_models(); auto models = common_list_cached_models();
printf("number of models in cache: %zu\n", models.size()); printf("number of models in cache: %zu\n", models.size());
for (size_t i = 0; i < models.size(); i++) { for (size_t i = 0; i < models.size(); i++) {
auto & model = models[i]; printf("%4zu. %s\n", i + 1, models[i].to_string().c_str());
printf("%4d. %s\n", (int) i + 1, model.to_string().c_str());
} }
exit(0); exit(0);
} }
@ -1830,23 +1824,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
{"--grammar"}, "GRAMMAR", {"--grammar"}, "GRAMMAR",
string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()), "BNF-like grammar to constrain generations (see samples in grammars/ dir)",
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.sampling.grammar = value; params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, value};
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
{"--grammar-file"}, "FNAME", {"--grammar-file"}, "FNAME",
"file to read grammar from", "file to read grammar from",
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.sampling.grammar = read_file(value); params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, read_file(value)};
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
{"-j", "--json-schema"}, "SCHEMA", {"-j", "--json-schema"}, "SCHEMA",
"JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead", "JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead",
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.sampling.grammar = json_schema_to_grammar(json::parse(value)); params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, json_schema_to_grammar(json::parse(value))};
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
@ -1863,7 +1857,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
std::istreambuf_iterator<char>(), std::istreambuf_iterator<char>(),
std::back_inserter(schema) std::back_inserter(schema)
); );
params.sampling.grammar = json_schema_to_grammar(json::parse(schema)); params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, json_schema_to_grammar(json::parse(schema))};
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
@ -2583,7 +2577,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]", {"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
"Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n" "Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"
"mmproj is also downloaded automatically if available. to disable, add --no-mmproj\n" "mmproj is also downloaded automatically if available. to disable, add --no-mmproj\n"
"example: unsloth/phi-4-GGUF:q4_k_m\n" "example: ggml-org/GLM-4.7-Flash-GGUF:Q4_K_M\n"
"(default: unused)", "(default: unused)",
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.model.hf_repo = value; params.model.hf_repo = value;
@ -3494,7 +3488,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
throw std::invalid_argument("unknown speculative decoding type without draft model"); throw std::invalid_argument("unknown speculative decoding type without draft model");
} }
} }
).set_examples({LLAMA_EXAMPLE_SERVER})); ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SPEC_TYPE"));
add_opt(common_arg( add_opt(common_arg(
{"--spec-ngram-size-n"}, "N", {"--spec-ngram-size-n"}, "N",
string_format("ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)", params.speculative.ngram_size_n), string_format("ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)", params.speculative.ngram_size_n),

View File

@ -1,3 +1,4 @@
#include "chat-auto-parser-helpers.h"
#include "chat-auto-parser.h" #include "chat-auto-parser.h"
#include "chat-peg-parser.h" #include "chat-peg-parser.h"
#include "chat.h" #include "chat.h"
@ -23,13 +24,13 @@ static void foreach_function(const json & tools, const std::function<void(const
namespace autoparser { namespace autoparser {
parser_build_context::parser_build_context(common_chat_peg_builder & p, const templates_params & inputs) : parser_build_context::parser_build_context(common_chat_peg_builder & p, const generation_params & inputs) :
p(p), p(p),
inputs(inputs), inputs(inputs),
reasoning_parser(p.eps()) {} reasoning_parser(p.eps()) {}
common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl, common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl,
const struct templates_params & inputs) { const struct generation_params & inputs) {
// Run differential analysis to extract template structure // Run differential analysis to extract template structure
struct autoparser autoparser; struct autoparser autoparser;
autoparser.analyze_template(tmpl); autoparser.analyze_template(tmpl);
@ -37,17 +38,16 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
} }
common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl, common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl,
const struct templates_params & inputs, const struct generation_params & inputs,
const autoparser & autoparser) { const autoparser & autoparser) {
// Build the parser using the analysis results
auto parser = autoparser.build_parser(inputs);
// Create the result structure // Create the result structure
common_chat_params data; common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs); data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.preserved_tokens = autoparser.preserved_tokens; data.preserved_tokens = autoparser.preserved_tokens;
data.parser = parser.save();
auto parser = autoparser.build_parser(inputs);
data.parser = parser.save();
// Build grammar if tools are present // Build grammar if tools are present
bool has_tools = bool has_tools =
@ -82,44 +82,37 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
return data; return data;
} }
common_peg_arena autoparser::build_parser(const templates_params & inputs) const { common_peg_arena autoparser::build_parser(const generation_params & inputs) const {
if (!analysis_complete) { if (!analysis_complete) {
throw std::invalid_argument("Cannot call build_parser on autoparser without performing analysis first, call analyze_template(...)"); throw std::invalid_argument("Cannot call build_parser on autoparser without performing analysis first, call analyze_template(...)");
} }
return build_chat_peg_parser([&](common_chat_peg_builder & p) { return build_chat_peg_parser([&](common_chat_peg_builder & p) {
// If the template uses Python dict format (single-quoted strings in JSON structures),
// pre-register a json-string rule that accepts both quote styles. This must happen
// before any call to p.json() so that all JSON parsing inherits the flexible rule.
if (tools.format.uses_python_dicts) {
p.rule("json-string", p.quoted_string());
}
parser_build_context ctx(p, inputs); parser_build_context ctx(p, inputs);
bool extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; bool extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
bool enable_thinking = inputs.enable_thinking;
ctx.extracting_reasoning = extract_reasoning && enable_thinking && reasoning.mode != reasoning_mode::NONE; ctx.extracting_reasoning = extract_reasoning && reasoning.mode != reasoning_mode::NONE;
ctx.content = &content; ctx.content = &content;
// Build reasoning parser // Build reasoning parser
ctx.reasoning_parser = reasoning.build_parser(ctx); ctx.reasoning_parser = reasoning.build_parser(ctx);
auto parser = p.eps();
bool has_tools = inputs.tools.is_array() && !inputs.tools.empty(); bool has_tools = inputs.tools.is_array() && !inputs.tools.empty();
bool has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty(); bool has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
if (has_response_format) { if (has_response_format) {
auto response_format = p.rule("response-format", p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema))); auto response_format = p.rule("response-format", p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
return ctx.reasoning_parser + p.space() + p.choice({ parser = ctx.reasoning_parser + p.space() + p.choice({
p.literal("```json") + p.space() + response_format + p.space() + p.literal("```"), p.literal("```json") + p.space() + response_format + p.space() + p.literal("```"),
response_format response_format
}) + p.end(); }) + p.end();
} else if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) {
parser = tools.build_parser(ctx);
} else {
parser = content.build_parser(ctx);
} }
return p.prefix(inputs.generation_prompt, reasoning.start) + parser;
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) {
return tools.build_parser(ctx);
}
return content.build_parser(ctx);
}); });
} }
@ -130,24 +123,15 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co
return p.eps(); return p.eps();
} }
bool thinking_forced_open = (mode == reasoning_mode::FORCED_OPEN);
bool thinking_forced_closed = (mode == reasoning_mode::FORCED_CLOSED);
if (thinking_forced_open || thinking_forced_closed) {
// Thinking is forced open OR forced closed with enable_thinking=true
// In both cases, expect only the closing tag (opening was in template)
// However, since we might have incorrectly detected the open/close pattern,
// we admit an optional starting marker
return p.optional(p.literal(start)) + p.reasoning(p.until(end)) + end;
}
if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) { if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) {
// Standard tag-based reasoning OR tools-only mode (reasoning appears with tools) if (!end.empty()) {
// Both use the same tag-based pattern if markers are available if (!start.empty()) {
if (!start.empty() && !end.empty()) { // Standard tag-based: optional(<think>reasoning</think>)
return p.optional(start + p.reasoning(p.until(end)) + end); return p.optional(start + p.reasoning(p.until(end)) + end + p.space());
}
// Delimiter-style (empty start)
return p.optional(p.reasoning(p.until(end)) + end + p.space());
} }
} else if (mode == reasoning_mode::DELIMITER) {
return p.optional(p.reasoning(p.until(end)) + end);
} }
return p.eps(); return p.eps();
@ -335,7 +319,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
"tool-" + name + "-arg-" + param_name + "-schema", "tool-" + name + "-arg-" + param_name + "-schema",
param_schema, true)) : param_schema, true)) :
p.tool_arg_json_value(p.schema( p.tool_arg_json_value(p.schema(
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, format.uses_python_dicts)) + p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
p.space()) + p.space()) +
p.tool_arg_close(p.literal(arguments.value_suffix))); p.tool_arg_close(p.literal(arguments.value_suffix)));
@ -384,7 +368,9 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) + func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
call_id_section) + p.space() + args_seq; call_id_section) + p.space() + args_seq;
matched_atomic = true; matched_atomic = true;
} else if (!arguments.name_prefix.empty() && properties.size() > 0) { } else if (!arguments.name_prefix.empty() && !required_parsers.empty()) {
// Only peek for an arg tag when there are required args that must follow.
// When all args are optional, the model may emit no arg tags at all (#20650).
func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) + func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
call_id_section + p.space() + p.peek(p.literal(arguments.name_prefix))) + args_seq; call_id_section + p.space() + p.peek(p.literal(arguments.name_prefix))) + args_seq;
matched_atomic = true; matched_atomic = true;

View File

@ -1,9 +1,11 @@
#include "chat-auto-parser-helpers.h" #include "chat-auto-parser-helpers.h"
#include "chat-auto-parser.h" #include "chat-auto-parser.h"
#include "chat-peg-parser.h"
#include "chat.h" #include "chat.h"
#include "log.h" #include "log.h"
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
#include "peg-parser.h"
#include <cctype> #include <cctype>
#include <numeric> #include <numeric>
@ -186,6 +188,21 @@ diff_split calculate_diff_split(const std::string & left, const std::string & ri
result.suffix = ""; result.suffix = "";
// pick prefix = all as representation // pick prefix = all as representation
} }
// When left has no unique content (result.left is empty), left is entirely
// shared with right. The simultaneous prefix/suffix segment matching can
// incorrectly consume trailing segments of left as suffix when those same
// segments also appear at the end of right (e.g. "\n" at the end of both
// the shared content and the generation prompt). This rotates the diff.
// Fix: if left is a prefix of right, enforce that directly.
if (result.left.empty() && !result.right.empty() &&
left.size() <= right.size() &&
right.substr(0, left.size()) == left) {
result.prefix = left;
result.suffix = "";
result.right = right.substr(left.size());
}
return result; return result;
} }
@ -294,7 +311,7 @@ std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segm
namespace autoparser { namespace autoparser {
std::string apply_template(const common_chat_template & tmpl, const template_params & params) { std::string apply_template(const common_chat_template & tmpl, const template_params & params) {
templates_params tmpl_params; generation_params tmpl_params;
tmpl_params.messages = params.messages; tmpl_params.messages = params.messages;
tmpl_params.tools = params.tools; tmpl_params.tools = params.tools;
tmpl_params.add_generation_prompt = params.add_generation_prompt; tmpl_params.add_generation_prompt = params.add_generation_prompt;

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include "chat-auto-parser.h" #include "chat-auto-parser.h"
#include "peg-parser.h"
#include <functional> #include <functional>
#include <optional> #include <optional>
#include <string> #include <string>

View File

@ -50,7 +50,7 @@ namespace autoparser {
// High-level params for parser generation // High-level params for parser generation
// ============================================================================ // ============================================================================
struct templates_params { struct generation_params {
json messages; json messages;
json tools; json tools;
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
@ -62,6 +62,7 @@ struct templates_params {
bool add_generation_prompt = false; bool add_generation_prompt = false;
bool enable_thinking = true; bool enable_thinking = true;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
std::string generation_prompt;
json extra_context; json extra_context;
bool add_bos = false; bool add_bos = false;
bool add_eos = false; bool add_eos = false;
@ -77,11 +78,7 @@ struct templates_params {
// Reasoning handling mode (derived from R1-R3 comparisons) // Reasoning handling mode (derived from R1-R3 comparisons)
enum class reasoning_mode { enum class reasoning_mode {
NONE, // No reasoning markers detected NONE, // No reasoning markers detected
TAG_BASED, // Standard tag-based: <think>...</think> TAG_BASED, // Tag-based: <think>...</think> (start can be empty for delimiter-style)
DELIMITER, // Delimiter-based: [BEGIN FINAL RESPONSE] (reasoning ends at delimiter)
FORCED_OPEN, // Template ends with open reasoning tag (empty start, non-empty end)
FORCED_CLOSED, // Template ends with open reasoning tag on enabled thinking but
// with both opened and closed tag for disabled thinking
TOOLS_ONLY // Only reason on tool calls, not on normal content TOOLS_ONLY // Only reason on tool calls, not on normal content
}; };
@ -91,12 +88,6 @@ inline std::ostream & operator<<(std::ostream & os, const reasoning_mode & mode)
return os << "NONE"; return os << "NONE";
case reasoning_mode::TAG_BASED: case reasoning_mode::TAG_BASED:
return os << "TAG_BASED"; return os << "TAG_BASED";
case reasoning_mode::DELIMITER:
return os << "DELIMITER";
case reasoning_mode::FORCED_OPEN:
return os << "FORCED_OPEN";
case reasoning_mode::FORCED_CLOSED:
return os << "FORCED_CLOSED";
case reasoning_mode::TOOLS_ONLY: case reasoning_mode::TOOLS_ONLY:
return os << "TOOLS_ONLY"; return os << "TOOLS_ONLY";
default: default:
@ -184,7 +175,6 @@ struct tool_format_analysis {
bool fun_name_is_key = false; // In JSON format function name is JSON key, i.e. { "<funname>": { ... arguments ... } } bool fun_name_is_key = false; // In JSON format function name is JSON key, i.e. { "<funname>": { ... arguments ... } }
bool tools_array_wrapped = false; // Tool calls wrapped in JSON array [...] bool tools_array_wrapped = false; // Tool calls wrapped in JSON array [...]
bool uses_python_dicts = false; // Tool call args use Python dict format (single-quoted strings)
std::string function_field = "function"; std::string function_field = "function";
std::string name_field = "name"; std::string name_field = "name";
@ -225,12 +215,12 @@ struct analyze_content;
struct parser_build_context { struct parser_build_context {
common_chat_peg_builder & p; common_chat_peg_builder & p;
const templates_params & inputs; const generation_params & inputs;
common_peg_parser reasoning_parser; common_peg_parser reasoning_parser;
bool extracting_reasoning = false; bool extracting_reasoning = false;
const analyze_content * content = nullptr; const analyze_content * content = nullptr;
parser_build_context(common_chat_peg_builder & p, const templates_params & inputs); parser_build_context(common_chat_peg_builder & p, const generation_params & inputs);
}; };
// ============================================================================ // ============================================================================
@ -260,6 +250,7 @@ struct analyze_reasoning : analyze_base {
analyze_reasoning() = default; analyze_reasoning() = default;
analyze_reasoning(const common_chat_template & tmpl, bool supports_tools); analyze_reasoning(const common_chat_template & tmpl, bool supports_tools);
analyze_reasoning(std::string start_, std::string end_) : start(std::move(start_)), end(std::move(end_)) {}
common_peg_parser build_parser(parser_build_context & ctx) const override; common_peg_parser build_parser(parser_build_context & ctx) const override;
@ -381,7 +372,7 @@ struct autoparser {
void analyze_template(const common_chat_template & tmpl); void analyze_template(const common_chat_template & tmpl);
// Build the PEG parser for this template // Build the PEG parser for this template
common_peg_arena build_parser(const templates_params & inputs) const; common_peg_arena build_parser(const generation_params & inputs) const;
private: private:
// Collect tokens from entire analysis to preserve // Collect tokens from entire analysis to preserve
@ -395,10 +386,10 @@ struct autoparser {
class peg_generator { class peg_generator {
public: public:
static common_chat_params generate_parser(const common_chat_template & tmpl, static common_chat_params generate_parser(const common_chat_template & tmpl,
const struct templates_params & inputs); const struct generation_params & inputs);
static common_chat_params generate_parser(const common_chat_template & tmpl, static common_chat_params generate_parser(const common_chat_template & tmpl,
const struct templates_params & inputs, const struct generation_params & inputs,
const autoparser & autoparser); const autoparser & autoparser);
}; };

View File

@ -2,6 +2,7 @@
#include "chat-auto-parser-helpers.h" #include "chat-auto-parser-helpers.h"
#include "chat-peg-parser.h" #include "chat-peg-parser.h"
#include "chat.h" #include "chat.h"
#include "common.h"
#include "log.h" #include "log.h"
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
#include "peg-parser.h" #include "peg-parser.h"
@ -31,8 +32,9 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
[](const common_chat_template & tmpl, autoparser & analysis) -> void { [](const common_chat_template & tmpl, autoparser & analysis) -> void {
if (tmpl.src.find("content.split('</think>')") != std::string::npos && if (tmpl.src.find("content.split('</think>')") != std::string::npos &&
tmpl.src.find("reasoning_content") == std::string::npos && tmpl.src.find("reasoning_content") == std::string::npos &&
tmpl.src.find("<SPECIAL_12>") == std::string::npos &&
analysis.reasoning.mode == reasoning_mode::NONE) { analysis.reasoning.mode == reasoning_mode::NONE) {
analysis.reasoning.mode = reasoning_mode::FORCED_OPEN; analysis.reasoning.mode = reasoning_mode::TAG_BASED;
analysis.reasoning.start = "<think>"; analysis.reasoning.start = "<think>";
analysis.reasoning.end = "</think>"; analysis.reasoning.end = "</think>";
analysis.preserved_tokens.push_back("<think>"); analysis.preserved_tokens.push_back("<think>");
@ -185,7 +187,6 @@ void autoparser::analyze_template(const common_chat_template & tmpl) {
LOG_DBG("func_name_prefix: '%s'\n", tools.function.name_prefix.c_str()); LOG_DBG("func_name_prefix: '%s'\n", tools.function.name_prefix.c_str());
LOG_DBG("func_name_suffix: '%s'\n", tools.function.name_suffix.c_str()); LOG_DBG("func_name_suffix: '%s'\n", tools.function.name_suffix.c_str());
LOG_DBG("func_close: '%s'\n", tools.function.close.c_str()); LOG_DBG("func_close: '%s'\n", tools.function.close.c_str());
LOG_DBG("python_dict_format: %s\n", tools.format.uses_python_dicts ? "true" : "false");
LOG_DBG("arg_name_prefix: '%s'\n", tools.arguments.name_prefix.c_str()); LOG_DBG("arg_name_prefix: '%s'\n", tools.arguments.name_prefix.c_str());
LOG_DBG("arg_name_suffix: '%s'\n", tools.arguments.name_suffix.c_str()); LOG_DBG("arg_name_suffix: '%s'\n", tools.arguments.name_suffix.c_str());
LOG_DBG("arg_value_prefix: '%s'\n", tools.arguments.value_prefix.c_str()); LOG_DBG("arg_value_prefix: '%s'\n", tools.arguments.value_prefix.c_str());
@ -295,16 +296,12 @@ void analyze_reasoning::compare_reasoning_presence() {
} }
if (result.result.success()) { if (result.result.success()) {
if (!result.tags["pre"].empty() && !result.tags["post"].empty()) { if (!result.tags["pre"].empty() && !result.tags["post"].empty()) {
if (parser_wrapped.parse_anywhere_and_extract(diff.right).result.success()) { // both tags in the diff = no forced close mode = reasoning_mode::TAG_BASED;
mode = reasoning_mode::TAG_BASED;
} else {
mode = reasoning_mode::FORCED_CLOSED;
}
start = trim_whitespace(result.tags["pre"]); start = trim_whitespace(result.tags["pre"]);
end = result.tags["post"]; end = trim_trailing_whitespace(result.tags["post"]);
} else if (!result.tags["post"].empty()) { } else if (!result.tags["post"].empty()) {
mode = reasoning_mode::DELIMITER; mode = reasoning_mode::TAG_BASED;
end = result.tags["post"]; end = trim_trailing_whitespace(result.tags["post"]);
} }
} }
} }
@ -331,53 +328,58 @@ void analyze_reasoning::compare_thinking_enabled() {
const auto & diff = comparison->diff; const auto & diff = comparison->diff;
std::string left_trimmed = trim_whitespace(diff.left); std::string left_trimmed = trim_whitespace(diff.left);
std::string right_trimmed = trim_whitespace(diff.right);
if (left_trimmed.empty() && !diff.right.empty()) { if (left_trimmed.empty() && !diff.right.empty()) {
std::string right_trimmed = trim_whitespace(diff.right);
if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) { if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) {
if (start.empty()) { if (start.empty()) {
start = right_trimmed; start = right_trimmed;
mode = reasoning_mode::FORCED_OPEN; mode = reasoning_mode::TAG_BASED;
}
}
} else if (right_trimmed.empty() && !diff.left.empty()) {
if (!left_trimmed.empty() && string_ends_with(comparison->output_A, left_trimmed)) {
if (end.empty()) {
auto seg = prune_whitespace_segments(segmentize_markers(comparison->output_A));
if (seg.size() >= 2 && seg[seg.size() - 1].value == left_trimmed && seg[seg.size() - 2].type == segment_type::MARKER) {
start = seg[seg.size() - 2].value;
}
end = left_trimmed;
mode = reasoning_mode::TAG_BASED;
}
}
} else if (!left_trimmed.empty() && !right_trimmed.empty()) {
// Full-output diff is noisy (e.g., SmolLM3 changes the system message when enable_thinking flips).
// Try to find reasoning markers by tail-anchoring:
// one output's generation prompt tail may appear in the other with extra reasoning markers appended.
const auto & output_A = comparison->output_A;
const auto & output_B = comparison->output_B;
const size_t anchor_len = 64;
for (int dir = 0; dir < 2; dir++) {
const auto & base = dir == 0 ? output_B : output_A;
const auto & extended = dir == 0 ? output_A : output_B;
size_t len = std::min(base.size(), anchor_len);
std::string anchor = base.substr(base.size() - len);
auto pos = extended.rfind(anchor);
if (pos == std::string::npos || pos + len >= extended.size()) continue;
std::string extra = trim_whitespace(extended.substr(pos + len));
if (extra.empty()) continue;
auto seg = prune_whitespace_segments(segmentize_markers(extra));
if (seg.size() == 2 && seg[0].type == segment_type::MARKER && seg[1].type == segment_type::MARKER) {
if (start.empty()) start = seg[0].value;
if (end.empty()) end = seg[1].value;
mode = reasoning_mode::TAG_BASED;
break;
} }
} }
} }
if (start.empty() && !end.empty()) { if (mode == reasoning_mode::NONE && start.empty() && !end.empty()) {
mode = reasoning_mode::DELIMITER; mode = reasoning_mode::TAG_BASED;
}
// Check for FORCED_CLOSED: when enable_thinking=false produces both start and end markers,
// but enable_thinking=true produces only the start marker
if (!comparison->output_A.empty() && !comparison->output_B.empty()) {
auto parser_start = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
return p.literal(start) + p.space() + p.literal(end) + p.rest();
});
auto parser_start_end = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
return p.tag("pre", p.literal(start)) + p.space() + p.negate(p.literal(end)) + p.rest();
});
if (!start.empty() && parser_start_end.parse_anywhere_and_extract(comparison->output_A).result.success() &&
parser_start.parse_anywhere_and_extract(comparison->output_B).result.success()) {
mode = reasoning_mode::FORCED_CLOSED;
} else if (!end.empty()) { // we extract the starting marker now since we didn't get it earlier
auto result = parser_start_end.parse_anywhere_and_extract(comparison->output_A);
if (result.result.success()) {
start = result.tags["pre"];
mode = reasoning_mode::FORCED_CLOSED;
}
}
}
if (start.empty() && end.empty()) { // we might still have the case of "just open" and "just close"
if (!diff.left.empty() && !diff.right.empty()) {
auto seg_A = segmentize_markers(trim_trailing_whitespace(diff.left));
auto seg_B = segmentize_markers(trim_trailing_whitespace(diff.right));
if (seg_A.size() == 1 && seg_B.size() == 1) {
mode = reasoning_mode::FORCED_CLOSED;
start = seg_B[0].value;
end = seg_A[0].value;
}
}
} }
} }
@ -426,16 +428,16 @@ void analyze_reasoning::compare_reasoning_scope() {
auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B); auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B);
if (result.result.success()) { if (result.result.success()) {
start = result.tags["pre"]; start = result.tags["pre"];
end = result.tags["post"]; end = trim_trailing_whitespace(result.tags["post"]);
} else { } else {
auto parser_delimiter = build_tagged_peg_parser([&](common_peg_parser_builder &p) { auto parser_delimiter = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space()))); return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space())));
}); });
result = parser_delimiter.parse_anywhere_and_extract(comparison->output_B); result = parser_delimiter.parse_anywhere_and_extract(comparison->output_B);
if (result.result.success()) { if (result.result.success()) {
end = result.tags["post"]; end = trim_trailing_whitespace(result.tags["post"]);
} else { } else {
LOG_DBG(ANSI_ORANGE "%s: Unable to extracft reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__); LOG_DBG(ANSI_ORANGE "%s: Unable to extract reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__);
mode = reasoning_mode::NONE; mode = reasoning_mode::NONE;
} }
} }
@ -600,33 +602,23 @@ void analyze_tools::analyze_tool_call_format(const std::string & haystack,
return; return;
} }
enum class json_quote_style { NONE, DOUBLE_QUOTES, SINGLE_QUOTES }; auto in_json_haystack = [&haystack](const std::string & needle) -> bool {
auto in_json_haystack = [&haystack](const std::string & needle) -> json_quote_style {
auto parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { auto parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
return p.choice({ p.literal("{"), p.literal(":") }) << p.choice({ return p.choice({ p.literal("{"), p.literal(":") }) << p.choice({
p.tag("sq", p.literal("'") + p.literal(needle) + p.literal("'")),
p.tag("dq", p.literal("\"") + p.literal(needle) + p.literal("\"")) }); p.tag("dq", p.literal("\"") + p.literal(needle) + p.literal("\"")) });
}); });
auto result = parser.parse_anywhere_and_extract(haystack); auto result = parser.parse_anywhere_and_extract(haystack);
if (!result.result.success()) { return result.result.success();
return json_quote_style::NONE;
}
return result.tags.count("sq") && !result.tags["sq"].empty()
? json_quote_style::SINGLE_QUOTES
: json_quote_style::DOUBLE_QUOTES;
}; };
auto fun_quote = in_json_haystack(fun_name_needle); auto fun_quote = in_json_haystack(fun_name_needle);
auto arg_quote = in_json_haystack(arg_name_needle); auto arg_quote = in_json_haystack(arg_name_needle);
if (fun_quote != json_quote_style::NONE) { if (fun_quote) {
// no need to check further, we're in JSON land // no need to check further, we're in JSON land
format.mode = tool_format::JSON_NATIVE; format.mode = tool_format::JSON_NATIVE;
format.uses_python_dicts = (fun_quote == json_quote_style::SINGLE_QUOTES); } else if (arg_quote) {
} else if (arg_quote != json_quote_style::NONE) {
format.mode = tool_format::TAG_WITH_JSON; format.mode = tool_format::TAG_WITH_JSON;
format.uses_python_dicts = (arg_quote == json_quote_style::SINGLE_QUOTES);
} else { } else {
format.mode = tool_format::TAG_WITH_TAGGED; format.mode = tool_format::TAG_WITH_TAGGED;
} }

View File

@ -229,6 +229,20 @@ void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena,
result.tool_calls.push_back(pending_tool_call.value()); result.tool_calls.push_back(pending_tool_call.value());
pending_tool_call.reset(); pending_tool_call.reset();
} }
// Discard whitespace-only reasoning content (e.g. from <think></think> prefill)
if (!result.reasoning_content.empty()) {
bool all_whitespace = true;
for (char c : result.reasoning_content) {
if (c != ' ' && c != '\n' && c != '\r' && c != '\t') {
all_whitespace = false;
break;
}
}
if (all_whitespace) {
result.reasoning_content.clear();
}
}
} }
void common_chat_peg_mapper::map(const common_peg_ast_node & node) { void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
@ -788,6 +802,16 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
return tool_choices; return tool_choices;
} }
common_peg_parser common_chat_peg_builder::prefix(const std::string & s, const std::string & delimiter) {
if (s.empty()) {
return eps();
}
if (delimiter.empty()) {
return literal(s);
}
return literal(s.substr(0, s.rfind(delimiter)));
}
common_peg_parser common_chat_peg_builder::standard_json_tools( common_peg_parser common_chat_peg_builder::standard_json_tools(
const std::string & section_start, const std::string & section_start,
const std::string & section_end, const std::string & section_end,

View File

@ -82,6 +82,10 @@ class common_chat_peg_builder : public common_peg_parser_builder {
common_peg_parser tool_arg_string_value(const common_peg_parser & p) { return tag(TOOL_ARG_STRING_VALUE, p); } common_peg_parser tool_arg_string_value(const common_peg_parser & p) { return tag(TOOL_ARG_STRING_VALUE, p); }
common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_VALUE, p)); } common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_VALUE, p)); }
// Return a parser that parses the prefix of a string, up to a given delimiter.
common_peg_parser prefix(const std::string & s, const std::string & delimiter = {});
// Legacy-compatible helper for building standard JSON tool calls // Legacy-compatible helper for building standard JSON tool calls
// Used by tests and manual parsers // Used by tests and manual parsers
// name_key/args_key: JSON key names for function name and arguments // name_key/args_key: JSON key names for function name and arguments

View File

@ -1,5 +1,6 @@
#include "chat.h" #include "chat.h"
#include "chat-auto-parser-helpers.h"
#include "chat-auto-parser.h" #include "chat-auto-parser.h"
#include "chat-peg-parser.h" #include "chat-peg-parser.h"
#include "common.h" #include "common.h"
@ -22,6 +23,7 @@
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
@ -760,7 +762,7 @@ static void foreach_parameter(const json &
std::string common_chat_template_direct_apply( std::string common_chat_template_direct_apply(
const common_chat_template & tmpl, const common_chat_template & tmpl,
const autoparser::templates_params & inputs, const autoparser::generation_params & inputs,
const std::optional<json> & messages_override, const std::optional<json> & messages_override,
const std::optional<json> & tools_override, const std::optional<json> & tools_override,
const std::optional<json> & additional_context) { const std::optional<json> & additional_context) {
@ -811,7 +813,7 @@ std::string common_chat_template_direct_apply(
} }
static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl, static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl,
const autoparser::templates_params & inputs) { const autoparser::generation_params & inputs) {
common_chat_params data; common_chat_params data;
// Build up messages to follow the format: https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512/blob/main/chat_template.jinja // Build up messages to follow the format: https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512/blob/main/chat_template.jinja
@ -870,14 +872,14 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
}; };
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
auto generation_prompt = p.prefix(inputs.generation_prompt, "[THINK]");
auto reasoning = auto reasoning =
extract_reasoning ? p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]") : p.eps(); extract_reasoning ? p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]") : p.eps();
// Response format parser // Response format parser
if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) { if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
// Ministral wants to emit json surrounded by code fences // Ministral wants to emit json surrounded by code fences
return reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) return generation_prompt + (reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```");
<< "```";
} }
// Tool call parser // Tool call parser
@ -897,12 +899,12 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
auto max_calls = inputs.parallel_tool_calls ? -1 : 1; auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls)); auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls));
return reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls; return generation_prompt + (reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls);
} }
// Content only parser // Content only parser
include_grammar = false; include_grammar = false;
return reasoning << p.content(p.rest()); return generation_prompt + (reasoning << p.content(p.rest()));
}); });
data.parser = parser.save(); data.parser = parser.save();
@ -928,22 +930,19 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
} }
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl,
const autoparser::templates_params & inputs) { const autoparser::generation_params & inputs) {
common_chat_params data; common_chat_params data;
// Copy reasoning to the "thinking" field as expected by the gpt-oss template // Copy reasoning to the "thinking" field as expected by the gpt-oss template
auto adjusted_messages = json::array(); auto adjusted_messages = json::array();
for (const auto & msg : inputs.messages) { for (auto msg : inputs.messages) {
auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string(); if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array(); msg["thinking"] = msg.at("reasoning_content");
if (msg.contains("tool_calls") && msg.at("tool_calls").is_array() && !msg.at("tool_calls").empty()) {
if (has_reasoning_content && has_tool_calls) { msg.erase("content");
auto adjusted_message = msg; }
adjusted_message["thinking"] = msg.at("reasoning_content");
adjusted_messages.push_back(adjusted_message);
} else {
adjusted_messages.push_back(msg);
} }
adjusted_messages.push_back(msg);
} }
auto prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override= */ adjusted_messages); auto prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
@ -969,45 +968,31 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
"<|channel|>", "<|constrain|>", "<|message|>", "<|start|>", "<|end|>", "<|channel|>", "<|constrain|>", "<|message|>", "<|start|>", "<|end|>",
}; };
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object();
auto include_grammar = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && has_tools; auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
const std::string END = "<|end|>"; auto start = p.rule("start", p.literal("<|start|>assistant"));
const std::string START = "<|start|>"; auto end = p.rule("end", p.literal("<|end|>"));
const std::string MESSAGE = "<|message|>"; auto content = p.rule("message-content", p.until("<|end|>"));
const std::string CHANNEL = "<|channel|>"; auto channel = p.literal("<|channel|>") + (p.literal("commentary") | p.literal("analysis"));
const std::string CONSTRAIN = "<|constrain|>"; auto constrain_type = p.chars("[A-Za-z0-9_-]", 1, -1);
const std::string START_ASSISTANT = START + "assistant";
const std::string CHANNEL_ANALYSIS = CHANNEL + "analysis";
const std::string CHANNEL_COMMENTARY = CHANNEL + "commentary";
const std::string CHANNEL_FINAL = CHANNEL + "final";
auto the_end = END | p.end(); auto analysis = p.rule("analysis", p.literal("<|channel|>analysis<|message|>") + p.reasoning(content) + end);
auto preamble = p.rule("preamble", p.literal("<|channel|>commentary<|message|>") + p.content(content) + end);
auto final_msg = p.rule("final", p.literal("<|channel|>final<|message|>") + p.content(content));
auto any = p.rule("any", preamble | analysis);
const std::string analysis_header = CHANNEL_ANALYSIS + MESSAGE; if (has_response_format) {
auto segment_content = p.until(END); auto constraint = p.optional(p.space() + p.literal("<|constrain|>") + constrain_type);
auto analysis_segment = extract_reasoning ? auto response_format = p.rule("response-format",
p.literal(analysis_header) + p.reasoning(segment_content) + p.until(END) + the_end : p.literal("<|channel|>final") + constraint + p.literal("<|message|>") +
p.content(analysis_header + p.until(END) + the_end); p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
auto channel_header_content = p.until_one_of({ " to=functions.", MESSAGE }); return p.zero_or_more(start + analysis) + start + response_format;
auto content_header = p.choice({ p.literal(CHANNEL_COMMENTARY), p.literal(CHANNEL_FINAL) });
auto content_segment = p.rule("content-segment", content_header + channel_header_content + MESSAGE +
p.content(segment_content) + the_end);
if (!inputs.json_schema.is_null()) {
auto final_header = p.literal(CHANNEL_FINAL);
auto constraint = p.optional(p.space() + p.literal(CONSTRAIN) + channel_header_content);
return p.optional(analysis_segment) + final_header + constraint + MESSAGE +
p.content(p.schema(p.json(), "response-format", inputs.json_schema));
} }
auto segment = p.optional(START_ASSISTANT + p.space()) + p.choice({ content_segment, analysis_segment });
auto contents = p.optional(segment + p.repeat(p.optional(p.space()) + segment, 0, -1)) + p.end();
// Tool call parser
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
auto tool_choice = p.choice(); auto tool_choice = p.choice();
@ -1016,42 +1001,37 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
std::string name = function.at("name"); std::string name = function.at("name");
const auto & params = function.at("parameters"); const auto & params = function.at("parameters");
// Tool call can appear as: auto func_name = p.literal(" to=functions.") + p.tool_name(p.literal(name));
// 1. In role header: " to=functions.NAME<|channel|>..." auto constraint = p.optional(p.space() + p.literal("<|constrain|>") + constrain_type);
// 2. In channel: "<|channel|>(analysis|commentary) to=functions.NAME..."
auto func_name = p.literal(" to=functions.") + p.tool_name(p.literal(name));
auto channel = p.literal(CHANNEL_COMMENTARY) | p.literal(CHANNEL_ANALYSIS);
auto constraint = p.space() + p.optional(p.literal(CONSTRAIN) + channel_header_content);
auto args = p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", params)); auto args = p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", params));
// Pattern 1: recipient in role header // recipient in role header
// " to=functions.NAME<|channel|>(analysis|commentary)[constraint]<|message|>ARGS" // <|start|>assistant to=functions.NAME<|channel|>(commentary|analysis)[constraint]<|message|>ARGS
auto tool_in_role = p.tool(p.tool_open(func_name + channel) + constraint + MESSAGE + args); auto tool_in_role = p.tool(p.tool_open(func_name + channel + constraint + p.literal("<|message|>")) + args);
// Pattern 2: recipient in channel header // recipient in channel header
// "<|channel|>(analysis|commentary) to=functions.NAME[constraint]<|message|>ARGS" // <|channel|>(commentary|analysis) to=functions.NAME[constraint]<|message|>ARGS
auto tool_in_channel = p.tool(channel + p.tool_open(func_name + constraint + MESSAGE) + args); auto tool_in_channel = p.tool(p.tool_open(channel + func_name + constraint + p.literal("<|message|>")) + args);
tool_choice |= tool_in_role | tool_in_channel; tool_choice |= p.rule("tool-" + name, tool_in_role | tool_in_channel);
}); });
auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0; auto tool_call = p.trigger_rule("tool-call", tool_choice);
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
auto role_start = p.optional(p.space() + p.literal(START_ASSISTANT)); if (inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED) {
auto tool_call = p.rule("tool-call", p.repeat(role_start + tool_choice, min_calls, max_calls) + p.end()); return p.zero_or_more(start + any) + start + tool_call;
}
return p.choice({ p.trigger_rule("single-tool", tool_call), p.trigger_rule("tools", p.one_or_more(segment) + tool_call) }); return p.zero_or_more(start + any) + start + (tool_call | final_msg);
} }
return contents; return p.zero_or_more(start + any) + start + final_msg;
}); });
data.parser = parser.save(); data.parser = parser.save();
if (include_grammar) { if (include_grammar) {
data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO; data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
foreach_function(inputs.tools, [&](const json & tool) { foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function"); const auto & function = tool.at("function");
@ -1062,10 +1042,9 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
}); });
data.grammar_triggers = { data.grammar_triggers = {
{ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "^(?:<\\|start\\|>assistant\\s*)?(\\s+to=functions)" }, { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "^\\s+to$" },
{ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "(?:<\\|end\\|>)(?:<\\|start\\|>assistant\\s*)?(\\s+to=functions)" }, { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "<\\|start\\|>assistant(\\s+to)" },
{ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "<\\|start\\|>assistant(<\\|channel\\|>(?:commentary|analysis)\\s+to)" }
"(?:<\\|start\\|>assistant\\s*)?(<\\|channel\\|>(?:commentary|analysis)\\s+to=functions)" }
}; };
} }
@ -1074,7 +1053,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
// Functionary v3.2 - uses recipient-based format: >>>recipient\n{content} // Functionary v3.2 - uses recipient-based format: >>>recipient\n{content}
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl,
const autoparser::templates_params & inputs) { const autoparser::generation_params & inputs) {
common_chat_params data; common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs); data.prompt = common_chat_template_direct_apply(tmpl, inputs);
@ -1095,13 +1074,14 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
// Build content parser for >>>all\n{content} // Build content parser for >>>all\n{content}
// When tools are present, content stops before the next ">>>" (tool call) // When tools are present, content stops before the next ">>>" (tool call)
// When no tools, content goes until end // When no tools, content goes until end
auto content_until_tool = p.literal(">>>all\n") + p.content(p.until(">>>")); auto content_until_tool = p.literal("all\n") + p.content(p.until(">>>"));
auto content_until_end = p.literal(">>>all\n") + p.content(p.rest()); auto content_until_end = p.literal("all\n") + p.content(p.rest());
auto generation_prompt = p.literal(inputs.generation_prompt);
// If no tools or tool_choice is NONE, just parse content // If no tools or tool_choice is NONE, just parse content
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
// When no tools, just match the prefix and capture everything after // When no tools, just match the prefix and capture everything after
return content_until_end + p.end(); return generation_prompt + content_until_end + p.end();
} }
// Build tool call parsers for each available function // Build tool call parsers for each available function
@ -1113,7 +1093,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
// Tool format: >>>function_name\n{json_args} // Tool format: >>>function_name\n{json_args}
auto tool_parser = p.tool( auto tool_parser = p.tool(
p.tool_open(p.literal(">>>") + p.tool_name(p.literal(name)) + p.literal("\n")) + p.tool_open(p.tool_name(p.literal(name)) + p.literal("\n")) +
p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)) p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema))
); );
@ -1124,17 +1104,20 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
auto tools_only = p.trigger_rule("tools", p.one_or_more(tool_choice)); auto tools_only = p.trigger_rule("tools", p.one_or_more(tool_choice));
auto content_and_tools = content_until_tool + tools_only; auto content_and_tools = content_until_tool + tools_only;
auto ret = p.eps();
if (inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED) { if (inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED) {
if (inputs.parallel_tool_calls) { if (inputs.parallel_tool_calls) {
return p.choice({ content_and_tools, tools_only }) + p.end(); ret = p.choice({ content_and_tools, tools_only }) + p.end();
} else {
ret = p.choice({ content_until_tool + tool_choice, tools_only }) + p.end();
} }
return p.choice({ content_until_tool + tool_choice, tools_only }) + p.end(); } else if (inputs.parallel_tool_calls) {
ret = p.choice({ content_and_tools, content_only, tools_only }) + p.end();
} else {
auto content_and_tool = content_until_tool + tool_choice;
ret = p.choice({ content_and_tool, content_only, tool_choice }) + p.end();
} }
if (inputs.parallel_tool_calls) { return generation_prompt + ret;
return p.choice({ content_and_tools, content_only, tools_only }) + p.end();
}
auto content_and_tool = content_until_tool + tool_choice;
return p.choice({ content_and_tool, content_only, tool_choice }) + p.end();
}); });
data.parser = parser.save(); data.parser = parser.save();
@ -1164,14 +1147,12 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
// Kimi K2 Thinking - uses unique tool call ID format: functions.<name>:<index> // Kimi K2 Thinking - uses unique tool call ID format: functions.<name>:<index>
// The ID contains both the function name and an incrementing counter // The ID contains both the function name and an incrementing counter
static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl, static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl,
const autoparser::templates_params & inputs) { const autoparser::generation_params & inputs) {
common_chat_params data; common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs); data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true; data.supports_thinking = true;
data.thinking_start_tag = "<think>";
data.thinking_end_tag = "</think>";
data.preserved_tokens = { data.preserved_tokens = {
"<|tool_calls_section_begin|>", "<|tool_calls_section_begin|>",
"<|tool_calls_section_end|>", "<|tool_calls_section_end|>",
@ -1186,6 +1167,18 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE; auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
const std::string SECTION_BEGIN = "<|tool_calls_section_begin|>";
const std::string SECTION_END = "<|tool_calls_section_end|>";
const std::string CALL_BEGIN = "<|tool_call_begin|>";
const std::string ARGS_BEGIN = "<|tool_call_argument_begin|>";
const std::string CALL_END = "<|tool_call_end|>";
const std::string THINK_START = "<think>";
const std::string THINK_END = "</think>";
data.thinking_start_tag = THINK_START;
data.thinking_end_tag = THINK_END;
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
// Kimi K2 Thinking format: // Kimi K2 Thinking format:
// - Reasoning: <think>{reasoning}</think> // - Reasoning: <think>{reasoning}</think>
@ -1197,16 +1190,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
// <|tool_calls_section_end|> // <|tool_calls_section_end|>
// The ID format is: functions.<function_name>:<counter> where counter is 0, 1, 2, ... // The ID format is: functions.<function_name>:<counter> where counter is 0, 1, 2, ...
// Tool call markers // Tool call markers
const std::string SECTION_BEGIN = "<|tool_calls_section_begin|>";
const std::string SECTION_END = "<|tool_calls_section_end|>";
const std::string CALL_BEGIN = "<|tool_call_begin|>";
const std::string ARGS_BEGIN = "<|tool_call_argument_begin|>";
const std::string CALL_END = "<|tool_call_end|>";
const std::string THINK_START = "<think>";
const std::string THINK_END = "</think>";
auto end = p.end(); auto end = p.end();
// Note: this model is CRAZY. It can diverge from its supposed tool calling pattern in so many ways it's not funny. // Note: this model is CRAZY. It can diverge from its supposed tool calling pattern in so many ways it's not funny.
@ -1214,11 +1198,12 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
auto reasoning = extract_reasoning ? p.optional(THINK_START + p.reasoning( auto reasoning = extract_reasoning ? p.optional(THINK_START + p.reasoning(
p.until_one_of({ THINK_END, "<|tool_calls_section_begin|>", "<|tool_call_begin|>" })) + p.until_one_of({ THINK_END, "<|tool_calls_section_begin|>", "<|tool_call_begin|>" })) +
p.optional(p.literal(THINK_END))) : p.eps(); p.optional(p.literal(THINK_END))) : p.eps();
auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
// Content only parser (no tools) // Content only parser (no tools)
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return reasoning + p.content(p.rest()) + end; return generation_prompt + reasoning + p.content(p.rest()) + end;
} }
// Build tool call parsers for each available function // Build tool call parsers for each available function
@ -1254,7 +1239,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
auto content_before_tools = p.content(p.until_one_of({ SECTION_BEGIN, CALL_BEGIN })); auto content_before_tools = p.content(p.until_one_of({ SECTION_BEGIN, CALL_BEGIN }));
return reasoning + content_before_tools + tool_calls + end; return generation_prompt + reasoning + content_before_tools + tool_calls + end;
}); });
data.parser = parser.save(); data.parser = parser.save();
@ -1284,7 +1269,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
// - Tool calls: <|tool_call_start|>[function_name(arg1="value1", arg2="value2")]<|tool_call_end|> // - Tool calls: <|tool_call_start|>[function_name(arg1="value1", arg2="value2")]<|tool_call_end|>
// Tool calls can appear multiple times (parallel tool calls) // Tool calls can appear multiple times (parallel tool calls)
static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl, static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl,
const autoparser::templates_params & inputs) { const autoparser::generation_params & inputs) {
common_chat_params data; common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs); data.prompt = common_chat_template_direct_apply(tmpl, inputs);
@ -1303,13 +1288,16 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE; auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
const std::string TOOL_CALL_START = "<|tool_call_start|>"; const std::string TOOL_CALL_START = "<|tool_call_start|>";
const std::string TOOL_CALL_END = "<|tool_call_end|>"; const std::string TOOL_CALL_END = "<|tool_call_end|>";
const std::string THINK_START = "<think>"; const std::string THINK_START = "<think>";
const std::string THINK_END = "</think>"; const std::string THINK_END = "</think>";
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
data.thinking_start_tag = THINK_START;
data.thinking_end_tag = THINK_END;
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
auto end = p.end(); auto end = p.end();
auto reasoning = p.eps(); auto reasoning = p.eps();
@ -1318,7 +1306,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
} }
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return reasoning + p.content(p.rest()) + end; return generation_prompt + reasoning + p.content(p.rest()) + end;
} }
auto tool_calls = p.rule("tool-calls", auto tool_calls = p.rule("tool-calls",
@ -1330,7 +1318,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
auto content = p.content(p.until(TOOL_CALL_START)); auto content = p.content(p.until(TOOL_CALL_START));
return reasoning + content + tool_calls + end; return generation_prompt + reasoning + content + tool_calls + end;
}); });
data.parser = parser.save(); data.parser = parser.save();
@ -1356,7 +1344,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
static common_chat_params common_chat_params_init_gigachat_v3( static common_chat_params common_chat_params_init_gigachat_v3(
const common_chat_template & tmpl, const common_chat_template & tmpl,
const autoparser::templates_params & inputs) { const autoparser::generation_params & inputs) {
common_chat_params data; common_chat_params data;
@ -1370,9 +1358,10 @@ static common_chat_params common_chat_params_init_gigachat_v3(
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE; auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
auto tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n"; const auto *tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n";
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
auto ret = p.eps();
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
// Build a choice of all available tools // Build a choice of all available tools
auto tool_choice = p.choice(); auto tool_choice = p.choice();
@ -1395,13 +1384,14 @@ static common_chat_params common_chat_params_init_gigachat_v3(
auto tool_call = p.rule("tool-call", p.literal(tool_call_start_prefix) + tool_choice); auto tool_call = p.rule("tool-call", p.literal(tool_call_start_prefix) + tool_choice);
auto tool_calls = p.trigger_rule("tool-call-root", p.repeat(tool_call, /* min = */ min_calls, /* max = */ max_calls)); auto tool_calls = p.trigger_rule("tool-call-root", p.repeat(tool_call, /* min = */ min_calls, /* max = */ max_calls));
return p.content(p.until("<|message_sep|>\n\n")) << tool_calls; ret = p.content(p.until("<|message_sep|>\n\n")) << tool_calls;
} else {
// Content only parser
include_grammar = false;
ret = p.content(p.rest());
} }
// Content only parser return p.literal(inputs.generation_prompt) + ret;
include_grammar = false;
return p.content(p.rest());
}); });
data.parser = parser.save(); data.parser = parser.save();
@ -1496,87 +1486,10 @@ static json common_chat_extra_context() {
return ctx; return ctx;
} }
static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls, static std::optional<common_chat_params> try_specialized_template(
const struct common_chat_templates_inputs & inputs) { const common_chat_template & tmpl,
autoparser::templates_params params; const std::string & src,
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools); const autoparser::generation_params & params) {
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
? *tmpls->template_tool_use
: *tmpls->template_default;
const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps();
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
params.add_generation_prompt = inputs.add_generation_prompt;
params.tool_choice = inputs.tool_choice;
params.reasoning_format = inputs.reasoning_format;
params.enable_thinking = inputs.enable_thinking;
params.grammar = inputs.grammar;
params.now = inputs.now;
params.add_bos = tmpls->add_bos;
params.add_eos = tmpls->add_eos;
if (src.find("<|channel|>") == std::string::npos) {
// map developer to system for all models except for GPT-OSS
workaround::map_developer_role_to_system(params.messages);
}
if (!tmpl.original_caps().supports_system_role) {
workaround::system_message_not_supported(params.messages);
}
if (tmpl.original_caps().supports_tool_calls) {
// some templates will require the content field in tool call messages
// to still be non-null, this puts an empty string everywhere where the
// content field is null
workaround::requires_non_null_content(params.messages);
}
if (tmpl.original_caps().supports_object_arguments) {
workaround::func_args_not_string(params.messages);
}
params.extra_context = common_chat_extra_context();
for (auto el : inputs.chat_template_kwargs) {
params.extra_context[el.first] = json::parse(el.second);
}
if (!inputs.json_schema.empty()) {
params.json_schema = json::parse(inputs.json_schema);
}
// if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
// LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
// params.parallel_tool_calls = false;
// } else {
params.parallel_tool_calls = inputs.parallel_tool_calls;
//}
if (params.tools.is_array()) {
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
throw std::runtime_error("Cannot specify grammar with tools");
}
if (caps.supports_tool_calls && !caps.supports_tools) {
LOG_WRN(
"Template supports tool calls but does not natively describe tools. The fallback behaviour used may "
"produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
}
}
if (inputs.force_pure_content) {
LOG_WRN("Forcing pure content template, will not render reasoning or tools separately.");
// Create the result structure
common_chat_params data;
auto params_copy = params;
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
data.prompt = common_chat_template_direct_apply(tmpl, params_copy);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
auto parser = build_chat_peg_parser([](common_chat_peg_builder &p) {
return p.content(p.rest());
});
data.parser = parser.save();
return data;
}
// Ministral/Mistral Large 3 - uses special reasoning structure fixes, can't use autoparser // Ministral/Mistral Large 3 - uses special reasoning structure fixes, can't use autoparser
// Note: Mistral Small 3.2 uses [CALL_ID] which Ministral doesn't have, so we can distinguish them // Note: Mistral Small 3.2 uses [CALL_ID] which Ministral doesn't have, so we can distinguish them
if (src.find("[SYSTEM_PROMPT]") != std::string::npos && src.find("[TOOL_CALLS]") != std::string::npos && if (src.find("[SYSTEM_PROMPT]") != std::string::npos && src.find("[TOOL_CALLS]") != std::string::npos &&
@ -1617,14 +1530,105 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
// GigaChatV3 format detection // GigaChatV3 format detection
if (src.find("<|role_sep|>") != std::string::npos && if (src.find("<|role_sep|>") != std::string::npos &&
src.find("<|message_sep|>") != std::string::npos && src.find("<|message_sep|>") != std::string::npos &&
src.find("<|function_call|>") == std::string::npos src.find("<|function_call|>") == std::string::npos) {
) {
LOG_DBG("Using specialized template: GigaChatV3\n"); LOG_DBG("Using specialized template: GigaChatV3\n");
return common_chat_params_init_gigachat_v3(tmpl, params); return common_chat_params_init_gigachat_v3(tmpl, params);
} }
return std::nullopt;
}
static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs) {
autoparser::generation_params params;
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
const auto & tmpl =
params.tools.is_array() && tmpls->template_tool_use ? *tmpls->template_tool_use : *tmpls->template_default;
const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps();
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
params.tool_choice = inputs.tool_choice;
params.reasoning_format = inputs.reasoning_format;
params.enable_thinking = inputs.enable_thinking;
params.grammar = inputs.grammar;
params.now = inputs.now;
params.add_bos = tmpls->add_bos;
params.add_eos = tmpls->add_eos;
if (src.find("<|channel|>") == std::string::npos) {
// map developer to system for all models except for GPT-OSS
workaround::map_developer_role_to_system(params.messages);
}
if (!tmpl.original_caps().supports_system_role) {
workaround::system_message_not_supported(params.messages);
}
if (tmpl.original_caps().supports_tool_calls) {
// some templates will require the content field in tool call messages
// to still be non-null, this puts an empty string everywhere where the
// content field is null
workaround::requires_non_null_content(params.messages);
}
if (tmpl.original_caps().supports_object_arguments) {
workaround::func_args_not_string(params.messages);
}
params.add_generation_prompt = false;
std::string no_gen_prompt = common_chat_template_direct_apply(tmpl, params);
params.add_generation_prompt = true;
std::string gen_prompt = common_chat_template_direct_apply(tmpl, params);
auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
params.generation_prompt = diff.right;
params.add_generation_prompt = inputs.add_generation_prompt;
params.extra_context = common_chat_extra_context();
for (auto el : inputs.chat_template_kwargs) {
params.extra_context[el.first] = json::parse(el.second);
}
if (!inputs.json_schema.empty()) {
params.json_schema = json::parse(inputs.json_schema);
}
params.parallel_tool_calls = inputs.parallel_tool_calls;
if (params.tools.is_array()) {
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
throw std::runtime_error("Cannot specify grammar with tools");
}
if (caps.supports_tool_calls && !caps.supports_tools) {
LOG_WRN(
"Template supports tool calls but does not natively describe tools. The fallback behaviour used may "
"produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
}
}
if (inputs.force_pure_content) {
LOG_WRN("Forcing pure content template, will not render reasoning or tools separately.");
// Create the result structure
common_chat_params data;
auto params_copy = params;
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
data.prompt = common_chat_template_direct_apply(tmpl, params_copy);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.generation_prompt = params.generation_prompt;
auto parser = build_chat_peg_parser([&params](common_chat_peg_builder &p) {
return p.prefix(params.generation_prompt) + p.content(p.rest());
});
data.parser = parser.save();
return data;
}
if (auto result = try_specialized_template(tmpl, src, params)) {
result->generation_prompt = params.generation_prompt;
return *result;
}
try { try {
LOG_DBG("Using differential autoparser\n"); LOG_DBG("%s: using differential autoparser\n", __func__);
struct autoparser::autoparser autoparser; struct autoparser::autoparser autoparser;
autoparser.analyze_template(tmpl); autoparser.analyze_template(tmpl);
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser); auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
@ -1632,13 +1636,11 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
if (auto_params.supports_thinking) { if (auto_params.supports_thinking) {
auto_params.thinking_start_tag = autoparser.reasoning.start; auto_params.thinking_start_tag = autoparser.reasoning.start;
auto_params.thinking_end_tag = autoparser.reasoning.end; auto_params.thinking_end_tag = autoparser.reasoning.end;
// FORCED_OPEN and FORCED_CLOSED both put <think> in the generation prompt
// (FORCED_CLOSED forces empty <think></think> when thinking is disabled,
// but forces <think> open when thinking is enabled)
auto_params.thinking_forced_open =
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_OPEN ||
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_CLOSED;
} }
auto_params.generation_prompt = params.generation_prompt;
common_peg_arena arena;
arena.load(auto_params.parser);
LOG_DBG("%s: generated parser:\n%s\n\nparser generation prompt: %s\n", __func__, arena.dump(arena.root()).c_str(), auto_params.generation_prompt.c_str());
return auto_params; return auto_params;
} catch (const std::exception & e) { } catch (const std::exception & e) {
throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what()); throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what());
@ -1736,14 +1738,18 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
LOG_DBG("No parser definition detected, assuming pure content parser."); LOG_DBG("No parser definition detected, assuming pure content parser.");
} }
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), input.c_str()); const std::string effective_input = params.generation_prompt.empty()
? input
: params.generation_prompt + input;
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), effective_input.c_str());
common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_LENIENT; common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_LENIENT;
if (params.debug) { if (params.debug) {
flags |= COMMON_PEG_PARSE_FLAG_DEBUG; flags |= COMMON_PEG_PARSE_FLAG_DEBUG;
} }
common_peg_parse_context ctx(input, flags); common_peg_parse_context ctx(effective_input, flags);
auto result = parser.parse(ctx); auto result = parser.parse(ctx);
if (result.fail()) { if (result.fail()) {
@ -1763,7 +1769,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
return msg; return msg;
} }
throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(result.end) + ": " + throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(result.end) + ": " +
input.substr(result.end)); effective_input.substr(result.end));
} }
common_chat_msg msg; common_chat_msg msg;

View File

@ -24,7 +24,7 @@ using json = nlohmann::ordered_json;
struct common_chat_templates; struct common_chat_templates;
namespace autoparser { namespace autoparser {
struct templates_params; struct generation_params;
} // namespace autoparser } // namespace autoparser
struct common_chat_tool_call { struct common_chat_tool_call {
@ -212,7 +212,7 @@ struct common_chat_params {
std::string prompt; std::string prompt;
std::string grammar; std::string grammar;
bool grammar_lazy = false; bool grammar_lazy = false;
bool thinking_forced_open = false; std::string generation_prompt;
bool supports_thinking = false; bool supports_thinking = false;
std::string thinking_start_tag; // e.g., "<think>" std::string thinking_start_tag; // e.g., "<think>"
std::string thinking_end_tag; // e.g., "</think>" std::string thinking_end_tag; // e.g., "</think>"
@ -229,14 +229,14 @@ struct common_chat_parser_params {
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning" common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode) // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
bool reasoning_in_content = false; bool reasoning_in_content = false;
bool thinking_forced_open = false; std::string generation_prompt;
bool parse_tool_calls = true; bool parse_tool_calls = true;
bool debug = false; // Enable debug output for PEG parser bool debug = false; // Enable debug output for PEG parser
common_peg_arena parser = {}; common_peg_arena parser = {};
common_chat_parser_params() = default; common_chat_parser_params() = default;
common_chat_parser_params(const common_chat_params & chat_params) { common_chat_parser_params(const common_chat_params & chat_params) {
format = chat_params.format; format = chat_params.format;
thinking_forced_open = chat_params.thinking_forced_open; generation_prompt = chat_params.generation_prompt;
} }
}; };
@ -302,7 +302,7 @@ std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_tem
std::string common_chat_template_direct_apply( std::string common_chat_template_direct_apply(
const common_chat_template & tmpl, const common_chat_template & tmpl,
const autoparser::templates_params & inputs, const autoparser::generation_params & inputs,
const std::optional<json> & messages_override = std::nullopt, const std::optional<json> & messages_override = std::nullopt,
const std::optional<json> & tools_override = std::nullopt, const std::optional<json> & tools_override = std::nullopt,
const std::optional<json> & additional_context = std::nullopt); const std::optional<json> & additional_context = std::nullopt);

View File

@ -1067,7 +1067,7 @@ common_init_result::common_init_result(common_params & params) :
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
// load and optionally apply lora adapters (must be loaded before context creation) // load and optionally apply lora adapters
for (auto & la : params.lora_adapters) { for (auto & la : params.lora_adapters) {
llama_adapter_lora_ptr lora; llama_adapter_lora_ptr lora;
lora.reset(llama_adapter_lora_init(model, la.path.c_str())); lora.reset(llama_adapter_lora_init(model, la.path.c_str()));

View File

@ -3,12 +3,14 @@
#pragma once #pragma once
#include "ggml-opt.h" #include "ggml-opt.h"
#include "ggml.h"
#include "llama-cpp.h" #include "llama-cpp.h"
#include <set> #include <set>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <string_view> #include <string_view>
#include <variant>
#include <vector> #include <vector>
#include <map> #include <map>
@ -178,6 +180,43 @@ enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type
}; };
// Grammar type enumeration
enum common_grammar_type {
COMMON_GRAMMAR_TYPE_NONE, // no grammar set
COMMON_GRAMMAR_TYPE_USER, // user-provided GBNF (--grammar / "grammar" API field)
COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, // auto-generated from JSON schema (--json-schema / "json_schema" API field)
COMMON_GRAMMAR_TYPE_TOOL_CALLS, // auto-generated by chat template parser for function calling
};
// Grammar variant struct with type and grammar string
struct common_grammar {
common_grammar_type type = COMMON_GRAMMAR_TYPE_NONE;
std::string grammar;
// Default constructor - no grammar
common_grammar() = default;
// Constructor with type and grammar string
common_grammar(common_grammar_type t, std::string g) : type(t), grammar(std::move(g)) {
GGML_ASSERT(type != COMMON_GRAMMAR_TYPE_NONE || !grammar.empty());
}
// Check if a grammar is set
bool empty() const { return type == COMMON_GRAMMAR_TYPE_NONE || grammar.empty(); }
};
// Returns the raw grammar string, or empty string if no grammar is set.
inline const std::string & common_grammar_value(const common_grammar & g) {
return g.grammar;
}
// Returns true when the generation_prompt should be prefilled into the grammar sampler.
// Only output-format and tool-call grammars need prefill; user-supplied grammars must not be prefilled.
inline bool common_grammar_needs_prefill(const common_grammar & g) {
return g.type == COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT
|| g.type == COMMON_GRAMMAR_TYPE_TOOL_CALLS;
}
// sampling parameters // sampling parameters
struct common_params_sampling { struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
@ -228,7 +267,7 @@ struct common_params_sampling {
COMMON_SAMPLER_TYPE_TEMPERATURE, COMMON_SAMPLER_TYPE_TEMPERATURE,
}; };
std::string grammar; // optional BNF-like grammar to constrain sampling common_grammar grammar; // optional grammar constraint (user / output-format / tool-calls)
bool grammar_lazy = false; bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars) std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
std::set<llama_token> preserved_tokens; std::set<llama_token> preserved_tokens;
@ -236,10 +275,15 @@ struct common_params_sampling {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
// The assistant generation prompt already prefilled into the prompt.
// Fed to the grammar sampler (to advance past pre-existing tokens) and used
// to determine the reasoning budget sampler's initial state.
// Only applied when the grammar is of output-format or tool-calls type.
std::string generation_prompt;
// reasoning budget sampler parameters // reasoning budget sampler parameters
// these are populated by the server/CLI based on chat template params // these are populated by the server/CLI based on chat template params
int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget
bool reasoning_budget_activate_immediately = false;
std::vector<llama_token> reasoning_budget_start; // start tag token sequence std::vector<llama_token> reasoning_budget_start; // start tag token sequence
std::vector<llama_token> reasoning_budget_end; // end tag token sequence std::vector<llama_token> reasoning_budget_end; // end tag token sequence
std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag) std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag)

View File

@ -1,9 +1,9 @@
#include "arg.h" #include "arg.h"
#include "common.h" #include "common.h"
#include "gguf.h" // for reading GGUF splits
#include "log.h" #include "log.h"
#include "download.h" #include "download.h"
#include "hf-cache.h"
#define JSON_ASSERT GGML_ASSERT #define JSON_ASSERT GGML_ASSERT
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
@ -15,6 +15,7 @@
#include <map> #include <map>
#include <mutex> #include <mutex>
#include <regex> #include <regex>
#include <unordered_set>
#include <string> #include <string>
#include <thread> #include <thread>
#include <vector> #include <vector>
@ -35,8 +36,6 @@
#endif #endif
#endif #endif
#define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
// isatty // isatty
#if defined(_WIN32) #if defined(_WIN32)
#include <io.h> #include <io.h>
@ -51,31 +50,6 @@ using json = nlohmann::ordered_json;
// //
// validate repo name format: owner/repo // validate repo name format: owner/repo
static bool validate_repo_name(const std::string & repo) {
static const std::regex repo_regex(R"(^[A-Za-z0-9_.\-]+\/[A-Za-z0-9_.\-]+$)");
return std::regex_match(repo, repo_regex);
}
static std::string get_manifest_path(const std::string & repo, const std::string & tag) {
// we use "=" to avoid clashing with other component, while still being allowed on windows
std::string fname = "manifest=" + repo + "=" + tag + ".json";
if (!validate_repo_name(repo)) {
throw std::runtime_error("error: repo name must be in the format 'owner/repo'");
}
string_replace_all(fname, "/", "=");
return fs_get_cache_file(fname);
}
static std::string read_file(const std::string & fname) {
std::ifstream file(fname);
if (!file) {
throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
}
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
file.close();
return content;
}
static void write_file(const std::string & fname, const std::string & content) { static void write_file(const std::string & fname, const std::string & content) {
const std::string fname_tmp = fname + ".tmp"; const std::string fname_tmp = fname + ".tmp";
std::ofstream file(fname_tmp); std::ofstream file(fname_tmp);
@ -132,7 +106,7 @@ static bool is_http_status_ok(int status) {
std::pair<std::string, std::string> common_download_split_repo_tag(const std::string & hf_repo_with_tag) { std::pair<std::string, std::string> common_download_split_repo_tag(const std::string & hf_repo_with_tag) {
auto parts = string_split<std::string>(hf_repo_with_tag, ':'); auto parts = string_split<std::string>(hf_repo_with_tag, ':');
std::string tag = parts.size() > 1 ? parts.back() : "latest"; std::string tag = parts.size() > 1 ? parts.back() : "";
std::string hf_repo = parts[0]; std::string hf_repo = parts[0];
if (string_split<std::string>(hf_repo, '/').size() != 2) { if (string_split<std::string>(hf_repo, '/').size() != 2) {
throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n"); throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
@ -290,7 +264,8 @@ static bool common_pull_file(httplib::Client & cli,
static int common_download_file_single_online(const std::string & url, static int common_download_file_single_online(const std::string & url,
const std::string & path, const std::string & path,
const std::string & bearer_token, const std::string & bearer_token,
const common_header_list & custom_headers) { const common_header_list & custom_headers,
bool skip_etag = false) {
static const int max_attempts = 3; static const int max_attempts = 3;
static const int retry_delay_seconds = 2; static const int retry_delay_seconds = 2;
@ -310,6 +285,11 @@ static int common_download_file_single_online(const std::string & url,
const bool file_exists = std::filesystem::exists(path); const bool file_exists = std::filesystem::exists(path);
if (file_exists && skip_etag) {
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
std::string last_etag; std::string last_etag;
if (file_exists) { if (file_exists) {
last_etag = read_etag(path); last_etag = read_etag(path);
@ -361,6 +341,12 @@ static int common_download_file_single_online(const std::string & url,
} }
} }
{ // silent
std::error_code ec;
std::filesystem::path p(path);
std::filesystem::create_directories(p.parent_path(), ec);
}
const std::string path_temporary = path + ".downloadInProgress"; const std::string path_temporary = path + ".downloadInProgress";
int delay = retry_delay_seconds; int delay = retry_delay_seconds;
@ -391,7 +377,7 @@ static int common_download_file_single_online(const std::string & url,
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return -1; return -1;
} }
if (!etag.empty()) { if (!etag.empty() && !skip_etag) {
write_etag(path, etag); write_etag(path, etag);
} }
return head->status; return head->status;
@ -440,9 +426,10 @@ int common_download_file_single(const std::string & url,
const std::string & path, const std::string & path,
const std::string & bearer_token, const std::string & bearer_token,
bool offline, bool offline,
const common_header_list & headers) { const common_header_list & headers,
bool skip_etag) {
if (!offline) { if (!offline) {
return common_download_file_single_online(url, path, bearer_token, headers); return common_download_file_single_online(url, path, bearer_token, headers, skip_etag);
} }
if (!std::filesystem::exists(path)) { if (!std::filesystem::exists(path)) {
@ -454,193 +441,293 @@ int common_download_file_single(const std::string & url,
return 304; // Not Modified - fake cached response return 304; // Not Modified - fake cached response
} }
// download multiple files from remote URLs to local paths struct gguf_split_info {
// the input is a vector of pairs <url, path> std::string prefix; // tag included
static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls, std::string tag;
const std::string & bearer_token, int index;
bool offline, int count;
const common_header_list & headers) { };
// Prepare download in parallel
std::vector<std::future<bool>> futures_download;
futures_download.reserve(urls.size());
for (auto const & item : urls) { static gguf_split_info get_gguf_split_info(const std::string & path) {
futures_download.push_back( static const std::regex re_split("^(.+)-([0-9]{5})-of-([0-9]{5})$", std::regex::icase);
std::async( static const std::regex re_tag("[-.]([A-Z0-9_]+)$", std::regex::icase);
std::launch::async, std::smatch m;
[&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
const int http_status = common_download_file_single(it.first, it.second, bearer_token, offline, headers); std::string prefix = path;
return is_http_status_ok(http_status); string_remove_suffix(prefix, ".gguf");
},
item int index = 1;
) int count = 1;
);
if (std::regex_match(prefix, m, re_split)) {
index = std::stoi(m[2].str());
count = std::stoi(m[3].str());
prefix = m[1].str();
} }
// Wait for all downloads to complete std::string tag;
for (auto & f : futures_download) { if (std::regex_search(prefix, m, re_tag)) {
if (!f.get()) { tag = m[1].str();
return false; for (char & c : tag) {
c = std::toupper((unsigned char)c);
} }
} }
return true; return {std::move(prefix), std::move(tag), index, count};
} }
bool common_download_model(const common_params_model & model, // Q4_0 -> 4, F16 -> 16, NVFP4 -> 4, Q8_K_M -> 8, etc
const std::string & bearer_token, static int extract_quant_bits(const std::string & filename) {
bool offline, auto split = get_gguf_split_info(filename);
const common_header_list & headers) {
// Basic validation of the model.url auto pos = split.tag.find_first_of("0123456789");
if (model.url.empty()) { if (pos == std::string::npos) {
LOG_ERR("%s: invalid model url\n", __func__); return 0;
return false;
} }
const int http_status = common_download_file_single(model.url, model.path, bearer_token, offline, headers); return std::stoi(split.tag.substr(pos));
if (!is_http_status_ok(http_status)) {
return false;
}
// check for additional GGUFs split to download
int n_split = 0;
{
struct gguf_init_params gguf_params = {
/*.no_alloc = */ true,
/*.ctx = */ NULL,
};
auto * ctx_gguf = gguf_init_from_file(model.path.c_str(), gguf_params);
if (!ctx_gguf) {
LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, model.path.c_str());
return false;
}
auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
if (key_n_split >= 0) {
n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
}
gguf_free(ctx_gguf);
}
if (n_split > 1) {
char split_prefix[PATH_MAX] = {0};
char split_url_prefix[LLAMA_MAX_URL_LENGTH] = {0};
// Verify the first split file format
// and extract split URL and PATH prefixes
{
if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.path.c_str(), 0, n_split)) {
LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.path.c_str(), n_split);
return false;
}
if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model.url.c_str(), 0, n_split)) {
LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model.url.c_str(), n_split);
return false;
}
}
std::vector<std::pair<std::string, std::string>> urls;
for (int idx = 1; idx < n_split; idx++) {
char split_path[PATH_MAX] = {0};
llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split);
char split_url[LLAMA_MAX_URL_LENGTH] = {0};
llama_split_path(split_url, sizeof(split_url), split_url_prefix, idx, n_split);
if (std::string(split_path) == model.path) {
continue; // skip the already downloaded file
}
urls.push_back({split_url, split_path});
}
// Download in parallel
common_download_file_multiple(urls, bearer_token, offline, headers);
}
return true;
} }
common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, static hf_cache::hf_files get_split_files(const hf_cache::hf_files & files,
const std::string & bearer_token, const hf_cache::hf_file & file) {
bool offline, auto split = get_gguf_split_info(file.path);
const common_header_list & custom_headers) {
// the returned hf_repo is without tag
auto [hf_repo, tag] = common_download_split_repo_tag(hf_repo_with_tag);
std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag; if (split.count <= 1) {
return {file};
// headers
common_header_list headers = custom_headers;
headers.push_back({"Accept", "application/json"});
if (!bearer_token.empty()) {
headers.push_back({"Authorization", "Bearer " + bearer_token});
} }
// Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response hf_cache::hf_files result;
// User-Agent header is already set in common_remote_get_content, no need to set it here
// make the request for (const auto & f : files) {
common_remote_params params; auto split_f = get_gguf_split_info(f.path);
params.headers = headers; if (split_f.count == split.count && split_f.prefix == split.prefix) {
long res_code = 0; result.push_back(f);
std::string res_str;
bool use_cache = false;
std::string cached_response_path = get_manifest_path(hf_repo, tag);
if (!offline) {
try {
auto res = common_remote_get_content(url, params);
res_code = res.first;
res_str = std::string(res.second.data(), res.second.size());
} catch (const std::exception & e) {
LOG_WRN("error: failed to get manifest at %s: %s\n", url.c_str(), e.what());
} }
} }
if (res_code == 0) { return result;
if (std::filesystem::exists(cached_response_path)) { }
LOG_WRN("trying to read manifest from cache: %s\n", cached_response_path.c_str());
res_str = read_file(cached_response_path); static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
res_code = 200; const std::string & model) {
use_cache = true; hf_cache::hf_file best;
} else { size_t best_depth = 0;
throw std::runtime_error( int best_diff = 0;
offline ? "error: failed to get manifest (offline mode)" bool found = false;
: "error: failed to get manifest (check your internet connection)");
auto model_bits = extract_quant_bits(model);
auto model_parts = string_split<std::string>(model, '/');
auto model_dir = model_parts.end() - 1;
for (const auto & f : files) {
if (!string_ends_with(f.path, ".gguf") ||
f.path.find("mmproj") == std::string::npos) {
continue;
}
auto mmproj_parts = string_split<std::string>(f.path, '/');
auto mmproj_dir = mmproj_parts.end() - 1;
auto [_, dir] = std::mismatch(model_parts.begin(), model_dir,
mmproj_parts.begin(), mmproj_dir);
if (dir != mmproj_dir) {
continue;
}
size_t depth = dir - mmproj_parts.begin();
auto bits = extract_quant_bits(f.path);
auto diff = std::abs(bits - model_bits);
if (!found || depth > best_depth || (depth == best_depth && diff < best_diff)) {
best = f;
best_depth = depth;
best_diff = diff;
found = true;
} }
} }
std::string ggufFile; return best;
std::string mmprojFile; }
if (res_code == 200 || res_code == 304) { static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
try { const std::string & tag) {
auto j = json::parse(res_str); std::vector<std::string> tags;
if (j.contains("ggufFile") && j["ggufFile"].contains("rfilename")) { if (!tag.empty()) {
ggufFile = j["ggufFile"]["rfilename"].get<std::string>(); tags.push_back(tag);
}
if (j.contains("mmprojFile") && j["mmprojFile"].contains("rfilename")) {
mmprojFile = j["mmprojFile"]["rfilename"].get<std::string>();
}
} catch (const std::exception & e) {
throw std::runtime_error(std::string("error parsing manifest JSON: ") + e.what());
}
if (!use_cache) {
// if not using cached response, update the cache file
write_file(cached_response_path, res_str);
}
} else if (res_code == 401) {
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
} else { } else {
throw std::runtime_error(string_format("error from HF API (%s), response code: %ld, data: %s", url.c_str(), res_code, res_str.c_str())); tags = {"Q4_K_M", "Q4_0"};
} }
// check response for (const auto & t : tags) {
if (ggufFile.empty()) { std::regex pattern(t + "[.-]", std::regex::icase);
throw std::runtime_error("error: model does not have ggufFile"); for (const auto & f : files) {
if (string_ends_with(f.path, ".gguf") &&
f.path.find("mmproj") == std::string::npos &&
std::regex_search(f.path, pattern)) {
return f;
}
}
} }
return { hf_repo, ggufFile, mmprojFile }; for (const auto & f : files) {
if (string_ends_with(f.path, ".gguf") &&
f.path.find("mmproj") == std::string::npos) {
return f;
}
}
return {};
}
static void list_available_gguf_files(const hf_cache::hf_files & files) {
LOG_INF("Available GGUF files:\n");
for (const auto & f : files) {
if (string_ends_with(f.path, ".gguf")) {
LOG_INF(" - %s\n", f.path.c_str());
}
}
}
struct hf_plan {
hf_cache::hf_files model_files;
hf_cache::hf_file mmproj;
};
static hf_plan get_hf_plan(const common_params_model & model,
const std::string & token,
const common_download_model_opts & opts) {
hf_plan plan;
hf_cache::hf_files all;
auto [repo, tag] = common_download_split_repo_tag(model.hf_repo);
if (!opts.offline) {
all = hf_cache::get_repo_files(repo, token);
}
if (all.empty()) {
all = hf_cache::get_cached_files(repo);
}
if (all.empty()) {
return plan;
}
hf_cache::hf_file primary;
if (!model.hf_file.empty()) {
for (const auto & f : all) {
if (f.path == model.hf_file) {
primary = f;
break;
}
}
if (primary.path.empty()) {
LOG_ERR("%s: file '%s' not found in repository\n", __func__, model.hf_file.c_str());
list_available_gguf_files(all);
return plan;
}
} else {
primary = find_best_model(all, tag);
if (primary.path.empty()) {
LOG_ERR("%s: no GGUF files found in repository %s\n", __func__, repo.c_str());
list_available_gguf_files(all);
return plan;
}
}
plan.model_files = get_split_files(all, primary);
if (opts.download_mmproj) {
plan.mmproj = find_best_mmproj(all, primary.path);
}
return plan;
}
struct download_task {
std::string url;
std::string path;
};
static std::vector<download_task> get_url_tasks(const common_params_model & model) {
auto split = get_gguf_split_info(model.url);
if (split.count <= 1) {
return {{model.url, model.path}};
}
auto filename = split.prefix;
if (auto pos = split.prefix.rfind('/'); pos != std::string::npos) {
filename = split.prefix.substr(pos + 1);
}
auto parent_path = std::filesystem::path(model.path).parent_path();
auto prefix_path = (parent_path / filename).string();
std::vector<download_task> tasks;
for (int i = 1; i <= split.count; i++) {
auto suffix = string_format("-%05d-of-%05d.gguf", i, split.count);
tasks.push_back({split.prefix + suffix, prefix_path + suffix});
}
return tasks;
}
common_download_model_result common_download_model(const common_params_model & model,
const std::string & bearer_token,
const common_download_model_opts & opts,
const common_header_list & headers) {
common_download_model_result result;
std::vector<download_task> tasks;
hf_plan hf;
bool is_hf = !model.hf_repo.empty();
if (is_hf) {
hf = get_hf_plan(model, bearer_token, opts);
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
if (!hf.mmproj.path.empty()) {
tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
}
} else if (!model.url.empty()) {
tasks = get_url_tasks(model);
} else {
result.model_path = model.path;
return result;
}
if (tasks.empty()) {
return result;
}
std::vector<std::future<bool>> futures;
for (const auto & task : tasks) {
futures.push_back(std::async(std::launch::async,
[&task, &bearer_token, offline = opts.offline, &headers, is_hf]() {
int status = common_download_file_single(task.url, task.path, bearer_token, offline, headers, is_hf);
return is_http_status_ok(status);
}
));
}
for (auto & f : futures) {
if (!f.get()) {
return {};
}
}
if (is_hf) {
for (const auto & f : hf.model_files) {
hf_cache::finalize_file(f);
}
result.model_path = hf.model_files[0].final_path;
if (!hf.mmproj.path.empty()) {
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
}
} else {
result.model_path = model.path;
}
return result;
} }
// //
@ -765,28 +852,21 @@ std::string common_docker_resolve_model(const std::string & docker) {
} }
std::vector<common_cached_model_info> common_list_cached_models() { std::vector<common_cached_model_info> common_list_cached_models() {
std::vector<common_cached_model_info> models; std::unordered_set<std::string> seen;
const std::string cache_dir = fs_get_cache_directory(); std::vector<common_cached_model_info> result;
const std::vector<common_file_info> files = fs_list(cache_dir, false);
for (const auto & file : files) { auto files = hf_cache::get_cached_files();
if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) {
common_cached_model_info model_info; for (const auto & f : files) {
model_info.manifest_path = file.path; auto split = get_gguf_split_info(f.path);
std::string fname = file.name; if (split.index != 1 || split.tag.empty() ||
string_replace_all(fname, ".json", ""); // remove extension split.prefix.find("mmproj") != std::string::npos) {
auto parts = string_split<std::string>(fname, '='); continue;
if (parts.size() == 4) { }
// expect format: manifest=<user>=<model>=<tag>=<other> if (seen.insert(f.repo_id + ":" + split.tag).second) {
model_info.user = parts[1]; result.push_back({f.repo_id, split.tag});
model_info.model = parts[2];
model_info.tag = parts[3];
} else {
// invalid format
continue;
}
model_info.size = 0; // TODO: get GGUF size, not manifest size
models.push_back(model_info);
} }
} }
return models;
return result;
} }

View File

@ -17,54 +17,60 @@ struct common_remote_params {
// get remote file content, returns <http_code, raw_response_body> // get remote file content, returns <http_code, raw_response_body>
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params); std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params);
// split HF repo with tag into <repo, tag> // split HF repo with tag into <repo, tag>, for example:
// for example: "user/model:tag" -> <"user/model", "tag"> // - "ggml-org/models:F16" -> <"ggml-org/models", "F16">
// if tag is not present, default to "latest" // tag is optional and can be empty
// example: "user/model" -> <"user/model", "latest">
std::pair<std::string, std::string> common_download_split_repo_tag(const std::string & hf_repo_with_tag); std::pair<std::string, std::string> common_download_split_repo_tag(const std::string & hf_repo_with_tag);
// Result of common_list_cached_models
struct common_cached_model_info { struct common_cached_model_info {
std::string manifest_path; std::string repo;
std::string user;
std::string model;
std::string tag; std::string tag;
size_t size = 0; // GGUF size in bytes
// return string representation like "user/model:tag"
// if tag is "latest", it will be omitted
std::string to_string() const { std::string to_string() const {
return user + "/" + model + (tag == "latest" ? "" : ":" + tag); return repo + ":" + tag;
} }
}; };
struct common_hf_file_res { // Options for common_download_model
std::string repo; // repo name with ":tag" removed struct common_download_model_opts {
std::string ggufFile; bool download_mmproj = false;
std::string mmprojFile; bool offline = false;
}; };
/** // Result of common_download_model
* Allow getting the HF file from the HF repo with tag (like ollama), for example: struct common_download_model_result {
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4 std::string model_path;
* - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M std::string mmproj_path;
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s };
* Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
*
* Return pair of <repo, file> (with "repo" already having tag removed)
*
* Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
*/
common_hf_file_res common_get_hf_file(
const std::string & hf_repo_with_tag,
const std::string & bearer_token,
bool offline,
const common_header_list & headers = {}
);
// returns true if download succeeded // Download model from HuggingFace repo or URL
bool common_download_model( //
// input (via model struct):
// - model.hf_repo: HF repo with optional tag, see common_download_split_repo_tag
// - model.hf_file: specific file in the repo (requires hf_repo)
// - model.url: simple download (used if hf_repo is empty)
// - model.path: local file path
//
// tag matching (for HF repos without model.hf_file):
// - if tag is specified, searches for GGUF matching that quantization
// - if no tag, searches for Q4_K_M, then Q4_0, then first available GGUF
//
// split GGUF: multi-part files like "model-00001-of-00003.gguf" are automatically
// detected and all parts are downloaded
//
// caching:
// - HF repos: uses HuggingFace cache
// - URLs: uses ETag-based caching
//
// when opts.offline=true, no network requests are made
// when download_mmproj=true, searches for mmproj in same directory as model or any parent directory
// then with the closest quantization bits
//
// returns result with model_path and mmproj_path (empty on failure)
common_download_model_result common_download_model(
const common_params_model & model, const common_params_model & model,
const std::string & bearer_token, const std::string & bearer_token,
bool offline, const common_download_model_opts & opts = {},
const common_header_list & headers = {} const common_header_list & headers = {}
); );
@ -73,11 +79,13 @@ std::vector<common_cached_model_info> common_list_cached_models();
// download single file from url to local path // download single file from url to local path
// returns status code or -1 on error // returns status code or -1 on error
// skip_etag: if true, don't read/write .etag files (for HF cache where filename is the hash)
int common_download_file_single(const std::string & url, int common_download_file_single(const std::string & url,
const std::string & path, const std::string & path,
const std::string & bearer_token, const std::string & bearer_token,
bool offline, bool offline,
const common_header_list & headers = {}); const common_header_list & headers = {},
bool skip_etag = false);
// resolve and download model from Docker registry // resolve and download model from Docker registry
// return local path to downloaded model file // return local path to downloaded model file

644
common/hf-cache.cpp Normal file
View File

@ -0,0 +1,644 @@
#include "hf-cache.h"
#include "common.h"
#include "log.h"
#include "http.h"
#define JSON_ASSERT GGML_ASSERT
#include <nlohmann/json.hpp>
#include <filesystem>
#include <fstream>
#include <atomic>
#include <regex> // migration only
#include <string>
#include <string_view>
#include <stdexcept>
namespace nl = nlohmann;
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#define HOME_DIR "USERPROFILE"
#include <windows.h>
#else
#define HOME_DIR "HOME"
#endif
namespace hf_cache {
namespace fs = std::filesystem;
static fs::path get_cache_directory() {
static const fs::path cache = []() {
struct {
const char * var;
fs::path path;
} entries[] = {
{"HF_HUB_CACHE", fs::path()},
{"HUGGINGFACE_HUB_CACHE", fs::path()},
{"HF_HOME", fs::path("hub")},
{"XDG_CACHE_HOME", fs::path("huggingface") / "hub"},
{HOME_DIR, fs::path(".cache") / "huggingface" / "hub"}
};
for (const auto & entry : entries) {
if (auto * p = std::getenv(entry.var); p && *p) {
fs::path base(p);
return entry.path.empty() ? base : base / entry.path;
}
}
throw std::runtime_error("Failed to determine HF cache directory");
}();
return cache;
}
static std::string folder_name_to_repo(const std::string & folder) {
constexpr std::string_view prefix = "models--";
if (folder.rfind(prefix, 0)) {
return {};
}
std::string result = folder.substr(prefix.length());
string_replace_all(result, "--", "/");
return result;
}
static std::string repo_to_folder_name(const std::string & repo_id) {
constexpr std::string_view prefix = "models--";
std::string result = std::string(prefix) + repo_id;
string_replace_all(result, "/", "--");
return result;
}
static fs::path get_repo_path(const std::string & repo_id) {
return get_cache_directory() / repo_to_folder_name(repo_id);
}
static bool is_hex_char(const char c) {
return (c >= 'A' && c <= 'F') ||
(c >= 'a' && c <= 'f') ||
(c >= '0' && c <= '9');
}
static bool is_hex_string(const std::string & s, size_t expected_len) {
if (s.length() != expected_len) {
return false;
}
for (const char c : s) {
if (!is_hex_char(c)) {
return false;
}
}
return true;
}
static bool is_alphanum(const char c) {
return (c >= 'A' && c <= 'Z') ||
(c >= 'a' && c <= 'z') ||
(c >= '0' && c <= '9');
}
static bool is_special_char(char c) {
return c == '/' || c == '.' || c == '-';
}
// base chars [A-Za-z0-9_] are always valid
// special chars [/.-] must be surrounded by base chars
// exactly one '/' required
static bool is_valid_repo_id(const std::string & repo_id) {
if (repo_id.empty() || repo_id.length() > 256) {
return false;
}
int slash = 0;
bool special = true;
for (const char c : repo_id) {
if (is_alphanum(c) || c == '_') {
special = false;
} else if (is_special_char(c)) {
if (special) {
return false;
}
slash += (c == '/');
special = true;
} else {
return false;
}
}
return !special && slash == 1;
}
static bool is_valid_hf_token(const std::string & token) {
if (token.length() < 37 || token.length() > 256 ||
!string_starts_with(token, "hf_")) {
return false;
}
for (size_t i = 3; i < token.length(); ++i) {
if (!is_alphanum(token[i])) {
return false;
}
}
return true;
}
static bool is_valid_commit(const std::string & hash) {
return is_hex_string(hash, 40);
}
static bool is_valid_oid(const std::string & oid) {
return is_hex_string(oid, 40) || is_hex_string(oid, 64);
}
static bool is_valid_subpath(const fs::path & path, const fs::path & subpath) {
if (subpath.is_absolute()) {
return false; // never do a / b with b absolute
}
auto b = fs::absolute(path).lexically_normal();
auto t = (b / subpath).lexically_normal();
auto [b_end, _] = std::mismatch(b.begin(), b.end(), t.begin(), t.end());
return b_end == b.end();
}
static void safe_write_file(const fs::path & path, const std::string & data) {
fs::path path_tmp = path.string() + ".tmp";
if (path.has_parent_path()) {
fs::create_directories(path.parent_path());
}
std::ofstream file(path_tmp);
file << data;
file.close();
std::error_code ec;
if (!file.fail()) {
fs::rename(path_tmp, path, ec);
}
if (file.fail() || ec) {
fs::remove(path_tmp, ec);
throw std::runtime_error("failed to write file: " + path.string());
}
}
static nl::json api_get(const std::string & url,
const std::string & token) {
auto [cli, parts] = common_http_client(url);
httplib::Headers headers = {
{"User-Agent", "llama-cpp/" + build_info},
{"Accept", "application/json"}
};
if (is_valid_hf_token(token)) {
headers.emplace("Authorization", "Bearer " + token);
} else if (!token.empty()) {
LOG_WRN("%s: invalid token, authentication disabled\n", __func__);
}
if (auto res = cli.Get(parts.path, headers)) {
auto body = res->body;
if (res->status == 200) {
return nl::json::parse(res->body);
}
try {
body = nl::json::parse(res->body)["error"].get<std::string>();
} catch (...) { }
throw std::runtime_error("GET failed (" + std::to_string(res->status) + "): " + body);
} else {
throw std::runtime_error("HTTPLIB failed: " + httplib::to_string(res.error()));
}
}
static std::string get_repo_commit(const std::string & repo_id,
const std::string & token) {
try {
auto endpoint = get_model_endpoint();
auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", token);
if (!json.is_object() ||
!json.contains("branches") || !json["branches"].is_array()) {
LOG_WRN("%s: missing 'branches' for '%s'\n", __func__, repo_id.c_str());
return {};
}
fs::path refs_path = get_repo_path(repo_id) / "refs";
std::string name;
std::string commit;
for (const auto & branch : json["branches"]) {
if (!branch.is_object() ||
!branch.contains("name") || !branch["name"].is_string() ||
!branch.contains("targetCommit") || !branch["targetCommit"].is_string()) {
continue;
}
std::string _name = branch["name"].get<std::string>();
std::string _commit = branch["targetCommit"].get<std::string>();
if (!is_valid_subpath(refs_path, _name)) {
LOG_WRN("%s: skip invalid branch: %s\n", __func__, _name.c_str());
continue;
}
if (!is_valid_commit(_commit)) {
LOG_WRN("%s: skip invalid commit: %s\n", __func__, _commit.c_str());
continue;
}
if (_name == "main") {
name = _name;
commit = _commit;
break;
}
if (name.empty() || commit.empty()) {
name = _name;
commit = _commit;
}
}
if (name.empty() || commit.empty()) {
LOG_WRN("%s: no valid branch for '%s'\n", __func__, repo_id.c_str());
return {};
}
safe_write_file(refs_path / name, commit);
return commit;
} catch (const nl::json::exception & e) {
LOG_ERR("%s: JSON error: %s\n", __func__, e.what());
} catch (const std::exception & e) {
LOG_ERR("%s: error: %s\n", __func__, e.what());
}
return {};
}
hf_files get_repo_files(const std::string & repo_id,
const std::string & token) {
if (!is_valid_repo_id(repo_id)) {
LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str());
return {};
}
std::string commit = get_repo_commit(repo_id, token);
if (commit.empty()) {
LOG_WRN("%s: failed to resolve commit for %s\n", __func__, repo_id.c_str());
return {};
}
fs::path blobs_path = get_repo_path(repo_id) / "blobs";
fs::path commit_path = get_repo_path(repo_id) / "snapshots" / commit;
hf_files files;
try {
auto endpoint = get_model_endpoint();
auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + commit + "?recursive=true", token);
if (!json.is_array()) {
LOG_WRN("%s: response is not an array for '%s'\n", __func__, repo_id.c_str());
return {};
}
for (const auto & item : json) {
if (!item.is_object() ||
!item.contains("type") || !item["type"].is_string() || item["type"] != "file" ||
!item.contains("path") || !item["path"].is_string()) {
continue;
}
hf_file file;
file.repo_id = repo_id;
file.path = item["path"].get<std::string>();
if (!is_valid_subpath(commit_path, file.path)) {
LOG_WRN("%s: skip invalid path: %s\n", __func__, file.path.c_str());
continue;
}
if (item.contains("lfs") && item["lfs"].is_object()) {
if (item["lfs"].contains("oid") && item["lfs"]["oid"].is_string()) {
file.oid = item["lfs"]["oid"].get<std::string>();
}
} else if (item.contains("oid") && item["oid"].is_string()) {
file.oid = item["oid"].get<std::string>();
}
if (!file.oid.empty() && !is_valid_oid(file.oid)) {
LOG_WRN("%s: skip invalid oid: %s\n", __func__, file.oid.c_str());
continue;
}
file.url = endpoint + repo_id + "/resolve/" + commit + "/" + file.path;
fs::path final_path = commit_path / file.path;
file.final_path = final_path.string();
if (!file.oid.empty() && !fs::exists(final_path)) {
fs::path local_path = blobs_path / file.oid;
file.local_path = local_path.string();
} else {
file.local_path = file.final_path;
}
files.push_back(file);
}
} catch (const nl::json::exception & e) {
LOG_ERR("%s: JSON error: %s\n", __func__, e.what());
} catch (const std::exception & e) {
LOG_ERR("%s: error: %s\n", __func__, e.what());
}
return files;
}
static std::string get_cached_ref(const fs::path & repo_path) {
fs::path refs_path = repo_path / "refs";
if (!fs::is_directory(refs_path)) {
return {};
}
std::string fallback;
for (const auto & entry : fs::directory_iterator(refs_path)) {
if (!entry.is_regular_file()) {
continue;
}
std::ifstream f(entry.path());
std::string commit;
if (!f || !std::getline(f, commit) || commit.empty()) {
continue;
}
if (!is_valid_commit(commit)) {
LOG_WRN("%s: skip invalid commit: %s\n", __func__, commit.c_str());
continue;
}
if (entry.path().filename() == "main") {
return commit;
}
if (fallback.empty()) {
fallback = commit;
}
}
return fallback;
}
hf_files get_cached_files(const std::string & repo_id) {
fs::path cache_dir = get_cache_directory();
if (!fs::exists(cache_dir)) {
return {};
}
if (!repo_id.empty() && !is_valid_repo_id(repo_id)) {
LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str());
return {};
}
hf_files files;
for (const auto & repo : fs::directory_iterator(cache_dir)) {
if (!repo.is_directory()) {
continue;
}
fs::path snapshots_path = repo.path() / "snapshots";
if (!fs::exists(snapshots_path)) {
continue;
}
std::string _repo_id = folder_name_to_repo(repo.path().filename().string());
if (!is_valid_repo_id(_repo_id)) {
continue;
}
if (!repo_id.empty() && _repo_id != repo_id) {
continue;
}
std::string commit = get_cached_ref(repo.path());
fs::path commit_path = snapshots_path / commit;
if (commit.empty() || !fs::is_directory(commit_path)) {
continue;
}
for (const auto & entry : fs::recursive_directory_iterator(commit_path)) {
if (!entry.is_regular_file() && !entry.is_symlink()) {
continue;
}
fs::path path = entry.path().lexically_relative(commit_path);
if (!path.empty()) {
hf_file file;
file.repo_id = _repo_id;
file.path = path.generic_string();
file.local_path = entry.path().string();
file.final_path = file.local_path;
files.push_back(std::move(file));
}
}
}
return files;
}
std::string finalize_file(const hf_file & file) {
static std::atomic<bool> symlinks_disabled{false};
std::error_code ec;
fs::path local_path(file.local_path);
fs::path final_path(file.final_path);
if (local_path == final_path || fs::exists(final_path, ec)) {
return file.final_path;
}
if (!fs::exists(local_path, ec)) {
return file.final_path;
}
fs::create_directories(final_path.parent_path(), ec);
if (!symlinks_disabled) {
fs::path target = fs::relative(local_path, final_path.parent_path(), ec);
if (!ec) {
fs::create_symlink(target, final_path, ec);
}
if (!ec) {
return file.final_path;
}
}
if (!symlinks_disabled.exchange(true)) {
LOG_WRN("%s: failed to create symlink: %s\n", __func__, ec.message().c_str());
LOG_WRN("%s: switching to degraded mode\n", __func__);
}
fs::rename(local_path, final_path, ec);
if (ec) {
LOG_WRN("%s: failed to move file to snapshots: %s\n", __func__, ec.message().c_str());
fs::copy(local_path, final_path, ec);
if (ec) {
LOG_ERR("%s: failed to copy file to snapshots: %s\n", __func__, ec.message().c_str());
}
}
return file.final_path;
}
// delete everything after this line, one day
static std::pair<std::string, std::string> parse_manifest_name(std::string & filename) {
static const std::regex re(R"(^manifest=([^=]+)=([^=]+)=.*\.json$)");
std::smatch match;
if (std::regex_match(filename, match, re)) {
return {match[1].str(), match[2].str()};
}
return {};
}
static std::string make_old_cache_filename(const std::string & owner,
const std::string & repo,
const std::string & filename) {
auto result = owner + "_" + repo + "_" + filename;
string_replace_all(result, "/", "_");
return result;
}
static bool migrate_single_file(const fs::path & old_cache,
const std::string & owner,
const std::string & repo,
const nl::json & node,
const hf_files & files) {
if (!node.contains("rfilename") ||
!node.contains("lfs") ||
!node["lfs"].contains("sha256")) {
return false;
}
std::string path = node["rfilename"];
std::string sha256 = node["lfs"]["sha256"];
const hf_file * file_info = nullptr;
for (const auto & f : files) {
if (f.path == path) {
file_info = &f;
break;
}
}
std::string old_filename = make_old_cache_filename(owner, repo, path);
fs::path old_path = old_cache / old_filename;
fs::path etag_path = old_path.string() + ".etag";
if (!fs::exists(old_path)) {
if (fs::exists(etag_path)) {
LOG_WRN("%s: %s is orphan, deleting...\n", __func__, etag_path.string().c_str());
fs::remove(etag_path);
}
return false;
}
bool delete_old_path = false;
if (!file_info) {
LOG_WRN("%s: %s not found in current repo, deleting...\n", __func__, old_filename.c_str());
delete_old_path = true;
} else if (!sha256.empty() && !file_info->oid.empty() && sha256 != file_info->oid) {
LOG_WRN("%s: %s is not up to date (sha256 mismatch), deleting...\n", __func__, old_filename.c_str());
delete_old_path = true;
}
std::error_code ec;
if (delete_old_path) {
fs::remove(old_path, ec);
fs::remove(etag_path, ec);
return true;
}
fs::path new_path(file_info->local_path);
fs::create_directories(new_path.parent_path(), ec);
if (!fs::exists(new_path, ec)) {
fs::rename(old_path, new_path, ec);
if (ec) {
fs::copy_file(old_path, new_path, ec);
if (ec) {
LOG_WRN("%s: failed to move/copy %s: %s\n", __func__, old_path.string().c_str(), ec.message().c_str());
return false;
}
}
fs::remove(old_path, ec);
}
fs::remove(etag_path, ec);
std::string filename = finalize_file(*file_info);
LOG_INF("%s: migrated %s -> %s\n", __func__, old_filename.c_str(), filename.c_str());
return true;
}
void migrate_old_cache_to_hf_cache(const std::string & token, bool offline) {
fs::path old_cache = fs_get_cache_directory();
if (!fs::exists(old_cache)) {
return;
}
if (offline) {
LOG_WRN("%s: skipping migration in offline mode (will run when online)\n", __func__);
return; // -hf is not going to work
}
bool warned = false;
for (const auto & entry : fs::directory_iterator(old_cache)) {
if (!entry.is_regular_file()) {
continue;
}
auto filename = entry.path().filename().string();
auto [owner, repo] = parse_manifest_name(filename);
if (owner.empty() || repo.empty()) {
continue;
}
if (!warned) {
warned = true;
LOG_WRN("================================================================================\n"
"WARNING: Migrating cache to HuggingFace cache directory\n"
" Old cache: %s\n"
" New cache: %s\n"
"This one-time migration moves models previously downloaded with -hf\n"
"from the legacy llama.cpp cache to the standard HuggingFace cache.\n"
"Models downloaded with --model-url are not affected.\n"
"================================================================================\n",
old_cache.string().c_str(), get_cache_directory().string().c_str());
}
auto repo_id = owner + "/" + repo;
auto files = get_repo_files(repo_id, token);
if (files.empty()) {
LOG_WRN("%s: could not get repo files for %s, skipping\n", __func__, repo_id.c_str());
continue;
}
try {
std::ifstream manifest(entry.path());
auto json = nl::json::parse(manifest);
for (const char * key : {"ggufFile", "mmprojFile"}) {
if (json.contains(key)) {
migrate_single_file(old_cache, owner, repo, json[key], files);
}
}
} catch (const std::exception & e) {
LOG_WRN("%s: failed to parse manifest %s: %s\n", __func__, filename.c_str(), e.what());
continue;
}
fs::remove(entry.path());
}
}
} // namespace hf_cache

35
common/hf-cache.h Normal file
View File

@ -0,0 +1,35 @@
#pragma once
#include <string>
#include <vector>
// Ref: https://huggingface.co/docs/hub/local-cache.md
namespace hf_cache {
struct hf_file {
std::string path;
std::string url;
std::string local_path;
std::string final_path;
std::string oid;
std::string repo_id;
};
using hf_files = std::vector<hf_file>;
// Get files from HF API
hf_files get_repo_files(
const std::string & repo_id,
const std::string & token
);
hf_files get_cached_files(const std::string & repo_id = {});
// Create snapshot path (link or move/copy) and return it
std::string finalize_file(const hf_file & file);
// TODO: Remove later
void migrate_old_cache_to_hf_cache(const std::string & token, bool offline = false);
} // namespace hf_cache

View File

@ -53,6 +53,13 @@ private:
return tokens[current + offset]; return tokens[current + offset];
} }
const token & next() {
if (current >= tokens.size()) {
throw parser_exception("Parser Error: Unexpected EOF", source, tokens.empty() ? 0 : tokens.back().pos);
}
return tokens[current++];
}
token expect(token::type type, const std::string& error) { token expect(token::type type, const std::string& error) {
const auto & t = peek(); const auto & t = peek();
if (t.t != type) { if (t.t != type) {
@ -90,9 +97,9 @@ private:
size_t start_pos = current; size_t start_pos = current;
switch (peek().t) { switch (peek().t) {
case token::comment: case token::comment:
return mk_stmt<comment_statement>(start_pos, tokens[current++].value); return mk_stmt<comment_statement>(start_pos, next().value);
case token::text: case token::text:
return mk_stmt<string_literal>(start_pos, tokens[current++].value); return mk_stmt<string_literal>(start_pos, next().value);
case token::open_statement: case token::open_statement:
return parse_jinja_statement(); return parse_jinja_statement();
case token::open_expression: case token::open_expression:
@ -119,8 +126,7 @@ private:
} }
size_t start_pos = current; size_t start_pos = current;
std::string name = peek().value; std::string name = next().value;
current++; // consume identifier
statement_ptr result; statement_ptr result;
if (name == "set") { if (name == "set") {
@ -202,7 +208,7 @@ private:
// Ignore generation blocks (transformers-specific) // Ignore generation blocks (transformers-specific)
// See https://github.com/huggingface/transformers/pull/30650 for more information. // See https://github.com/huggingface/transformers/pull/30650 for more information.
result = mk_stmt<noop_statement>(start_pos); result = mk_stmt<noop_statement>(start_pos);
current++; ++current;
} else { } else {
throw std::runtime_error("Unknown statement: " + name); throw std::runtime_error("Unknown statement: " + name);
@ -217,7 +223,7 @@ private:
statements body; statements body;
if (is(token::equals)) { if (is(token::equals)) {
current++; ++current;
value = parse_expression_sequence(); value = parse_expression_sequence();
} else { } else {
// parsing multiline set here // parsing multiline set here
@ -280,7 +286,7 @@ private:
exprs.push_back(primary ? parse_primary_expression() : parse_expression()); exprs.push_back(primary ? parse_primary_expression() : parse_expression());
bool is_tuple = is(token::comma); bool is_tuple = is(token::comma);
while (is(token::comma)) { while (is(token::comma)) {
current++; // consume comma ++current; // consume comma
exprs.push_back(primary ? parse_primary_expression() : parse_expression()); exprs.push_back(primary ? parse_primary_expression() : parse_expression());
} }
return is_tuple ? mk_stmt<tuple_literal>(start_pos, std::move(exprs)) : std::move(exprs[0]); return is_tuple ? mk_stmt<tuple_literal>(start_pos, std::move(exprs)) : std::move(exprs[0]);
@ -290,7 +296,7 @@ private:
// e.g., `message` in `for message in messages` // e.g., `message` in `for message in messages`
auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple
if (!is_identifier("in")) throw std::runtime_error("Expected 'in'"); if (!is_identifier("in")) throw std::runtime_error("Expected 'in'");
current++; ++current; // consume 'in'
// `messages` in `for message in messages` // `messages` in `for message in messages`
auto iterable = parse_expression(); auto iterable = parse_expression();
@ -305,7 +311,8 @@ private:
} }
if (is_statement({"else"})) { if (is_statement({"else"})) {
current += 2; ++current; // consume {%
++current; // consume 'else'
expect(token::close_statement, "Expected %}"); expect(token::close_statement, "Expected %}");
while (!is_statement({"endfor"})) { while (!is_statement({"endfor"})) {
alternate.push_back(parse_any()); alternate.push_back(parse_any());
@ -347,7 +354,7 @@ private:
auto left = parse_logical_and_expression(); auto left = parse_logical_and_expression();
while (is_identifier("or")) { while (is_identifier("or")) {
size_t start_pos = current; size_t start_pos = current;
token op = tokens[current++]; token op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_and_expression()); left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_and_expression());
} }
return left; return left;
@ -357,7 +364,7 @@ private:
auto left = parse_logical_negation_expression(); auto left = parse_logical_negation_expression();
while (is_identifier("and")) { while (is_identifier("and")) {
size_t start_pos = current; size_t start_pos = current;
auto op = tokens[current++]; auto op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_negation_expression()); left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_negation_expression());
} }
return left; return left;
@ -367,7 +374,7 @@ private:
// Try parse unary operators // Try parse unary operators
if (is_identifier("not")) { if (is_identifier("not")) {
size_t start_pos = current; size_t start_pos = current;
auto op = tokens[current++]; auto op = next();
return mk_stmt<unary_expression>(start_pos, op, parse_logical_negation_expression()); return mk_stmt<unary_expression>(start_pos, op, parse_logical_negation_expression());
} }
return parse_comparison_expression(); return parse_comparison_expression();
@ -382,11 +389,12 @@ private:
size_t start_pos = current; size_t start_pos = current;
if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") { if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") {
op = {token::identifier, "not in", tokens[current].pos}; op = {token::identifier, "not in", tokens[current].pos};
current += 2; ++current; // consume 'not'
++current; // consume 'in'
} else if (is_identifier("in")) { } else if (is_identifier("in")) {
op = tokens[current++]; op = next();
} else if (is(token::comparison_binary_operator)) { } else if (is(token::comparison_binary_operator)) {
op = tokens[current++]; op = next();
} else break; } else break;
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_additive_expression()); left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_additive_expression());
} }
@ -397,7 +405,7 @@ private:
auto left = parse_multiplicative_expression(); auto left = parse_multiplicative_expression();
while (is(token::additive_binary_operator)) { while (is(token::additive_binary_operator)) {
size_t start_pos = current; size_t start_pos = current;
auto op = tokens[current++]; auto op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_multiplicative_expression()); left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_multiplicative_expression());
} }
return left; return left;
@ -407,7 +415,7 @@ private:
auto left = parse_test_expression(); auto left = parse_test_expression();
while (is(token::multiplicative_binary_operator)) { while (is(token::multiplicative_binary_operator)) {
size_t start_pos = current; size_t start_pos = current;
auto op = tokens[current++]; auto op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_test_expression()); left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_test_expression());
} }
return left; return left;
@ -417,9 +425,9 @@ private:
auto operand = parse_filter_expression(); auto operand = parse_filter_expression();
while (is_identifier("is")) { while (is_identifier("is")) {
size_t start_pos = current; size_t start_pos = current;
current++; ++current; // consume 'is'
bool negate = false; bool negate = false;
if (is_identifier("not")) { current++; negate = true; } if (is_identifier("not")) { ++current; negate = true; }
auto test_id = parse_primary_expression(); auto test_id = parse_primary_expression();
// FIXME: tests can also be expressed like this: if x is eq 3 // FIXME: tests can also be expressed like this: if x is eq 3
if (is(token::open_paren)) test_id = parse_call_expression(std::move(test_id)); if (is(token::open_paren)) test_id = parse_call_expression(std::move(test_id));
@ -432,7 +440,7 @@ private:
auto operand = parse_call_member_expression(); auto operand = parse_call_member_expression();
while (is(token::pipe)) { while (is(token::pipe)) {
size_t start_pos = current; size_t start_pos = current;
current++; ++current; // consume pipe
auto filter = parse_primary_expression(); auto filter = parse_primary_expression();
if (is(token::open_paren)) filter = parse_call_expression(std::move(filter)); if (is(token::open_paren)) filter = parse_call_expression(std::move(filter));
operand = mk_stmt<filter_expression>(start_pos, std::move(operand), std::move(filter)); operand = mk_stmt<filter_expression>(start_pos, std::move(operand), std::move(filter));
@ -490,7 +498,7 @@ private:
statement_ptr parse_member_expression(statement_ptr object) { statement_ptr parse_member_expression(statement_ptr object) {
size_t start_pos = current; size_t start_pos = current;
while (is(token::dot) || is(token::open_square_bracket)) { while (is(token::dot) || is(token::open_square_bracket)) {
auto op = tokens[current++]; auto op = next();
bool computed = op.t == token::open_square_bracket; bool computed = op.t == token::open_square_bracket;
statement_ptr prop; statement_ptr prop;
if (computed) { if (computed) {
@ -536,7 +544,7 @@ private:
statement_ptr parse_primary_expression() { statement_ptr parse_primary_expression() {
size_t start_pos = current; size_t start_pos = current;
auto t = tokens[current++]; auto t = next();
switch (t.t) { switch (t.t) {
case token::numeric_literal: case token::numeric_literal:
if (t.value.find('.') != std::string::npos) { if (t.value.find('.') != std::string::npos) {
@ -547,7 +555,7 @@ private:
case token::string_literal: { case token::string_literal: {
std::string val = t.value; std::string val = t.value;
while (is(token::string_literal)) { while (is(token::string_literal)) {
val += tokens[current++].value; val += next().value;
} }
return mk_stmt<string_literal>(start_pos, val); return mk_stmt<string_literal>(start_pos, val);
} }
@ -562,9 +570,9 @@ private:
statements vals; statements vals;
while (!is(token::close_square_bracket)) { while (!is(token::close_square_bracket)) {
vals.push_back(parse_expression()); vals.push_back(parse_expression());
if (is(token::comma)) current++; if (is(token::comma)) ++current;
} }
current++; ++current;
return mk_stmt<array_literal>(start_pos, std::move(vals)); return mk_stmt<array_literal>(start_pos, std::move(vals));
} }
case token::open_curly_bracket: { case token::open_curly_bracket: {
@ -573,9 +581,9 @@ private:
auto key = parse_expression(); auto key = parse_expression();
expect(token::colon, "Expected :"); expect(token::colon, "Expected :");
pairs.push_back({std::move(key), parse_expression()}); pairs.push_back({std::move(key), parse_expression()});
if (is(token::comma)) current++; if (is(token::comma)) ++current;
} }
current++; ++current;
return mk_stmt<object_literal>(start_pos, std::move(pairs)); return mk_stmt<object_literal>(start_pos, std::move(pairs));
} }
default: default:

View File

@ -451,7 +451,7 @@ struct value_array_t : public value_t {
} }
protected: protected:
virtual bool equivalent(const value_t & other) const override { virtual bool equivalent(const value_t & other) const override {
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), value_equivalence()); return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), other.val_arr.end(), value_equivalence());
} }
}; };
using value_array = std::shared_ptr<value_array_t>; using value_array = std::shared_ptr<value_array_t>;
@ -587,7 +587,7 @@ struct value_object_t : public value_t {
} }
protected: protected:
virtual bool equivalent(const value_t & other) const override { virtual bool equivalent(const value_t & other) const override {
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), value_equivalence()); return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), other.val_obj.end(), value_equivalence());
} }
}; };
using value_object = std::shared_ptr<value_object_t>; using value_object = std::shared_ptr<value_object_t>;

View File

@ -163,9 +163,15 @@ static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
ctx->force_pos = 0; ctx->force_pos = 0;
} }
// forward declaration for use in clone
static struct llama_sampler * common_reasoning_budget_init_state(
const struct llama_vocab * vocab, const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens, const std::vector<llama_token> & forced_tokens,
int32_t budget, common_reasoning_budget_state initial_state);
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) { static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx; const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
return common_reasoning_budget_init( return common_reasoning_budget_init_state(
ctx->vocab, ctx->vocab,
ctx->start_matcher.tokens, ctx->start_matcher.tokens,
ctx->end_matcher.tokens, ctx->end_matcher.tokens,
@ -191,13 +197,13 @@ static struct llama_sampler_i common_reasoning_budget_i = {
/* .backend_set_input = */ nullptr, /* .backend_set_input = */ nullptr,
}; };
struct llama_sampler * common_reasoning_budget_init( static struct llama_sampler * common_reasoning_budget_init_state(
const struct llama_vocab * vocab, const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens, const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens, const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens, const std::vector<llama_token> & forced_tokens,
int32_t budget, int32_t budget,
common_reasoning_budget_state initial_state) { common_reasoning_budget_state initial_state) {
// promote COUNTING with budget <= 0 to FORCING // promote COUNTING with budget <= 0 to FORCING
if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) { if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
initial_state = REASONING_BUDGET_FORCING; initial_state = REASONING_BUDGET_FORCING;
@ -217,3 +223,41 @@ struct llama_sampler * common_reasoning_budget_init(
} }
); );
} }
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
const std::vector<llama_token> & prefill_tokens) {
// Determine initial state from prefill: COUNTING if the prefill begins with
// the start sequence but does not also contain the end sequence after it.
common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE;
if (!prefill_tokens.empty() && !start_tokens.empty() &&
prefill_tokens.size() >= start_tokens.size() &&
std::equal(start_tokens.begin(), start_tokens.end(), prefill_tokens.begin())) {
initial_state = REASONING_BUDGET_COUNTING;
// If the end sequence also follows the start in the prefill, reasoning
// was opened and immediately closed — stay IDLE.
if (!end_tokens.empty() &&
prefill_tokens.size() >= start_tokens.size() + end_tokens.size()) {
auto end_start = prefill_tokens.end() - (ptrdiff_t) end_tokens.size();
if (end_start >= prefill_tokens.begin() + (ptrdiff_t) start_tokens.size() &&
std::equal(end_tokens.begin(), end_tokens.end(), end_start)) {
initial_state = REASONING_BUDGET_IDLE;
}
}
}
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
}
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state) {
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
}

View File

@ -24,14 +24,26 @@ enum common_reasoning_budget_state {
// DONE: passthrough forever // DONE: passthrough forever
// //
// Parameters: // Parameters:
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr) // vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
// start_tokens - token sequence that activates counting // start_tokens - token sequence that activates counting
// end_tokens - token sequence for natural deactivation // end_tokens - token sequence for natural deactivation
// forced_tokens - token sequence forced when budget expires // forced_tokens - token sequence forced when budget expires
// budget - max tokens allowed in the reasoning block // budget - max tokens allowed in the reasoning block
// initial_state - initial state of the sampler (e.g. IDLE or COUNTING) // prefill_tokens - tokens already present in the prompt (generation prompt);
// note: COUNTING with budget <= 0 is promoted to FORCING // used to determine the initial state: COUNTING if they begin
// with start_tokens (but don't also end with end_tokens),
// IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING.
// //
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
const std::vector<llama_token> & prefill_tokens = {});
// Variant that takes an explicit initial state (used by tests and clone).
// COUNTING with budget <= 0 is promoted to FORCING.
struct llama_sampler * common_reasoning_budget_init( struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab, const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens, const std::vector<llama_token> & start_tokens,

View File

@ -1,13 +1,16 @@
#include "sampling.h" #include "sampling.h"
#include "common.h" #include "common.h"
#include "ggml.h"
#include "log.h" #include "log.h"
#include "reasoning-budget.h" #include "reasoning-budget.h"
#include <algorithm> #include <algorithm>
#include <cctype>
#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
#include <unordered_map> #include <unordered_map>
#include <vector>
// the ring buffer works similarly to std::deque, but with a fixed capacity // the ring buffer works similarly to std::deque, but with a fixed capacity
// TODO: deduplicate with llama-impl.h // TODO: deduplicate with llama-impl.h
@ -189,9 +192,10 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
std::vector<llama_sampler *> samplers; std::vector<llama_sampler *> samplers;
if (params.grammar.compare(0, 11, "%llguidance") == 0) { const std::string & grammar_str = common_grammar_value(params.grammar);
if (grammar_str.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE #ifdef LLAMA_USE_LLGUIDANCE
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()); grmr = llama_sampler_init_llg(vocab, "lark", grammar_str.c_str());
#else #else
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE #endif // LLAMA_USE_LLGUIDANCE
@ -240,17 +244,46 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
trigger_patterns_c.push_back(regex.c_str()); trigger_patterns_c.push_back(regex.c_str());
} }
if (!params.grammar.empty()) { if (!grammar_str.empty()) {
if (params.grammar_lazy) { if (params.grammar_lazy) {
grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root", grmr = llama_sampler_init_grammar_lazy_patterns(vocab, grammar_str.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(), trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size()); trigger_tokens.data(), trigger_tokens.size());
} else { } else {
grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); grmr = llama_sampler_init_grammar(vocab, grammar_str.c_str(), "root");
} }
} }
} }
// Feed generation prompt tokens to the grammar sampler so it advances past
// tokens the template already placed in the prompt.
// Only applies to output-format and tool-call grammars; user-supplied grammars must not be prefilled.
std::vector<llama_token> prefill_tokens;
if (!params.generation_prompt.empty() && common_grammar_needs_prefill(params.grammar)) {
GGML_ASSERT(vocab != nullptr);
prefill_tokens = common_tokenize(vocab, params.generation_prompt, false, true);
if (!prefill_tokens.empty()) {
std::string first_token = common_token_to_piece(vocab, prefill_tokens[0], true);
if (std::isspace(first_token[0]) && !std::isspace(params.generation_prompt[0])) {
// Some tokenizers will add a space before the first special token, need to remove
prefill_tokens = std::vector<llama_token>(prefill_tokens.begin() + 1, prefill_tokens.end());
}
}
if (grmr) {
try {
for (const auto & token : prefill_tokens) {
llama_sampler_accept(grmr, token);
LOG_DBG("%s: accepted prefill token (%d)\n", __func__, token);
}
} catch (std::exception &e) {
LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__,
common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str());
throw e;
}
}
}
// reasoning budget sampler — added first so it can force tokens before other samplers // reasoning budget sampler — added first so it can force tokens before other samplers
if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) { if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
samplers.push_back(common_reasoning_budget_init( samplers.push_back(common_reasoning_budget_init(
@ -259,7 +292,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
params.reasoning_budget_end, params.reasoning_budget_end,
params.reasoning_budget_forced, params.reasoning_budget_forced,
params.reasoning_budget_tokens, params.reasoning_budget_tokens,
params.reasoning_budget_activate_immediately ? REASONING_BUDGET_COUNTING : REASONING_BUDGET_IDLE)); prefill_tokens));
} }
if (params.has_logit_bias()) { if (params.has_logit_bias()) {

View File

@ -31,10 +31,10 @@ import gguf
from gguf.vocab import MistralTokenizerType, MistralVocab from gguf.vocab import MistralTokenizerType, MistralVocab
try: try:
from mistral_common.tokens.tokenizers.base import TokenizerVersion # pyright: ignore[reportMissingImports] from mistral_common.tokens.tokenizers.base import TokenizerVersion # type: ignore[import-not-found]
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # pyright: ignore[reportMissingImports] from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # type: ignore[import-not-found]
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports] from mistral_common.tokens.tokenizers.tekken import Tekkenizer # type: ignore[import-not-found]
from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports] from mistral_common.tokens.tokenizers.sentencepiece import ( # type: ignore[import-not-found]
SentencePieceTokenizer, SentencePieceTokenizer,
) )
@ -45,9 +45,9 @@ except ImportError:
_MISTRAL_COMMON_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) _MISTRAL_COMMON_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
_mistral_common_installed = False _mistral_common_installed = False
TokenizerVersion = None TokenizerVersion: Any = None
Tekkenizer = None Tekkenizer: Any = None
SentencePieceTokenizer = None SentencePieceTokenizer: Any = None
_mistral_import_error_msg = ( _mistral_import_error_msg = (
"Mistral format requires `mistral-common` to be installed. Please run " "Mistral format requires `mistral-common` to be installed. Please run "
"`pip install mistral-common[image,audio]` to install it." "`pip install mistral-common[image,audio]` to install it."
@ -145,6 +145,7 @@ class ModelBase:
self.model_name = model_name self.model_name = model_name
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
self._is_nvfp4 = False self._is_nvfp4 = False
self._is_mxfp4 = False
# Apply heuristics to figure out typical tensor encoding based on first tensor's dtype # Apply heuristics to figure out typical tensor encoding based on first tensor's dtype
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie. # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
@ -220,7 +221,7 @@ class ModelBase:
if weight_map is None or not isinstance(weight_map, dict): if weight_map is None or not isinstance(weight_map, dict):
raise ValueError(f"Can't load 'weight_map' from {index_name!r}") raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
tensor_names_from_index.update(weight_map.keys()) tensor_names_from_index.update(weight_map.keys())
part_dict: dict[str, None] = dict.fromkeys(weight_map.values(), None) part_dict: dict[str, None] = dict.fromkeys(weight_map.values(), None) # ty: ignore[invalid-assignment]
part_names = sorted(part_dict.keys()) part_names = sorted(part_dict.keys())
else: else:
weight_map = {} weight_map = {}
@ -712,6 +713,7 @@ class ModelBase:
def prepare_tensors(self): def prepare_tensors(self):
# detect NVFP4 quantization (ModelOpt format) # detect NVFP4 quantization (ModelOpt format)
quant_algo = (self.hparams.get("quantization_config") or {}).get("quant_algo") quant_algo = (self.hparams.get("quantization_config") or {}).get("quant_algo")
quant_method = (self.hparams.get("quantization_config") or {}).get("quant_method")
quant_layers = (self.hparams.get("quantization_config") or {}).get("quantized_layers") or {} quant_layers = (self.hparams.get("quantization_config") or {}).get("quantized_layers") or {}
quant_config_file = self.dir_model / "hf_quant_config.json" quant_config_file = self.dir_model / "hf_quant_config.json"
@ -728,6 +730,7 @@ class ModelBase:
quant_algo = "NVFP4" quant_algo = "NVFP4"
self._is_nvfp4 = quant_algo == "NVFP4" self._is_nvfp4 = quant_algo == "NVFP4"
self._is_mxfp4 = quant_method == "mxfp4"
# NVFP4 weights are repacked and written directly to gguf_writer. # NVFP4 weights are repacked and written directly to gguf_writer.
# This must run before dequant_model so NVFP4 tensors are removed # This must run before dequant_model so NVFP4 tensors are removed
@ -876,6 +879,12 @@ class ModelBase:
if self.metadata.name is None: if self.metadata.name is None:
self.metadata.name = self.dir_model.name self.metadata.name = self.dir_model.name
if self.ftype in (gguf.LlamaFileType.ALL_F32, gguf.LlamaFileType.MOSTLY_F16, gguf.LlamaFileType.MOSTLY_BF16):
if self._is_nvfp4:
self.ftype = gguf.LlamaFileType.MOSTLY_NVFP4
elif self._is_mxfp4:
self.ftype = gguf.LlamaFileType.MOSTLY_MXFP4_MOE
# Generate parameter weight class (useful for leader boards) if not yet determined # Generate parameter weight class (useful for leader boards) if not yet determined
if self.metadata.size_label is None and total_params > 0: if self.metadata.size_label is None and total_params > 0:
self.metadata.size_label = gguf.size_label(total_params, shared_params, expert_params, expert_count) self.metadata.size_label = gguf.size_label(total_params, shared_params, expert_params, expert_count)
@ -1062,6 +1071,10 @@ class TextModel(ModelBase):
self.gguf_writer.add_head_count_kv(n_head_kv) self.gguf_writer.add_head_count_kv(n_head_kv)
logger.info(f"gguf: key-value head count = {n_head_kv}") logger.info(f"gguf: key-value head count = {n_head_kv}")
if self.hparams.get("is_causal") is False:
self.gguf_writer.add_causal_attention(False)
logger.info("gguf: causal attention = False")
# TODO: Handle "sliding_attention" similarly when models start implementing it # TODO: Handle "sliding_attention" similarly when models start implementing it
rope_params = self.rope_parameters.get("full_attention", self.rope_parameters) rope_params = self.rope_parameters.get("full_attention", self.rope_parameters)
if (rope_type := rope_params.get("rope_type")) is not None: if (rope_type := rope_params.get("rope_type")) is not None:
@ -4260,6 +4273,16 @@ class Qwen25OmniModel(Qwen2VLVisionModel):
@ModelBase.register("InternVisionModel") @ModelBase.register("InternVisionModel")
class InternVisionModel(MmprojModel): class InternVisionModel(MmprojModel):
min_dynamic_tiles: int = 0
max_dynamic_tiles: int = 0
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
self.min_dynamic_tiles = self.global_config.get("min_dynamic_patch", 0)
self.max_dynamic_tiles = self.global_config.get("max_dynamic_patch", 0)
def set_gguf_parameters(self): def set_gguf_parameters(self):
assert self.hparams_vision is not None assert self.hparams_vision is not None
if isinstance(self.hparams_vision['image_size'], list): if isinstance(self.hparams_vision['image_size'], list):
@ -4282,6 +4305,11 @@ class InternVisionModel(MmprojModel):
downsample_ratio = self.global_config.get("downsample_ratio") downsample_ratio = self.global_config.get("downsample_ratio")
assert downsample_ratio is not None assert downsample_ratio is not None
self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio)) self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio))
# older models may not have min/max_dynamic_patch in config
if self.min_dynamic_tiles > 0:
self.gguf_writer.add_vision_preproc_min_tiles(self.min_dynamic_tiles)
if self.max_dynamic_tiles > 0:
self.gguf_writer.add_vision_preproc_max_tiles(self.max_dynamic_tiles)
def tensor_force_quant(self, name, new_name, bid, n_dims): def tensor_force_quant(self, name, new_name, bid, n_dims):
if ".position_embd." in new_name: if ".position_embd." in new_name:
@ -5878,7 +5906,7 @@ class InternLM2Model(TextModel):
logger.error(f'Error: Missing {tokenizer_path}') logger.error(f'Error: Missing {tokenizer_path}')
sys.exit(1) sys.exit(1)
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
@ -6199,7 +6227,7 @@ class BertModel(TextModel):
vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size) vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size)
else: else:
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
@ -8876,7 +8904,7 @@ class T5Model(TextModel):
if not tokenizer_path.is_file(): if not tokenizer_path.is_file():
raise FileNotFoundError(f"File not found: {tokenizer_path}") raise FileNotFoundError(f"File not found: {tokenizer_path}")
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
# some models like Pile-T5 family use BPE tokenizer instead of Unigram # some models like Pile-T5 family use BPE tokenizer instead of Unigram
@ -9013,7 +9041,7 @@ class T5EncoderModel(TextModel):
if not tokenizer_path.is_file(): if not tokenizer_path.is_file():
raise FileNotFoundError(f"File not found: {tokenizer_path}") raise FileNotFoundError(f"File not found: {tokenizer_path}")
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
# some models like Pile-T5 family use BPE tokenizer instead of Unigram # some models like Pile-T5 family use BPE tokenizer instead of Unigram
@ -11121,8 +11149,7 @@ class GptOssModel(TextModel):
# TODO: remove once MXFP4 is supported more generally # TODO: remove once MXFP4 is supported more generally
def dequant_model(self): def dequant_model(self):
quant_config = self.hparams.get("quantization_config") if self._is_mxfp4:
if quant_config is not None and quant_config.get("quant_method") == "mxfp4":
return return
return super().dequant_model() return super().dequant_model()
@ -12275,6 +12302,7 @@ class LazyTorchTensor(gguf.LazyBase):
kwargs = {} kwargs = {}
if func is torch.Tensor.numpy: if func is torch.Tensor.numpy:
assert len(args)
return args[0].numpy() return args[0].numpy()
return cls._wrap_fn(func)(*args, **kwargs) return cls._wrap_fn(func)(*args, **kwargs)

View File

@ -112,11 +112,11 @@ class Tensor:
(n_dims, name_len, dtype) = struct.unpack('<3I', data[offset:offset + 12]) (n_dims, name_len, dtype) = struct.unpack('<3I', data[offset:offset + 12])
assert n_dims >= 0 and n_dims <= 4, f'Invalid tensor dimensions {n_dims}' assert n_dims >= 0 and n_dims <= 4, f'Invalid tensor dimensions {n_dims}'
assert name_len < 4096, 'Absurd tensor name length' assert name_len < 4096, 'Absurd tensor name length'
quant = gguf.GGML_QUANT_SIZES.get(dtype) self.dtype = gguf.GGMLQuantizationType(dtype)
quant = gguf.GGML_QUANT_SIZES.get(self.dtype)
assert quant is not None, 'Unknown tensor type' assert quant is not None, 'Unknown tensor type'
(blksize, tysize) = quant (blksize, tysize) = quant
offset += 12 offset += 12
self.dtype= gguf.GGMLQuantizationType(dtype)
self.dims = struct.unpack(f'<{n_dims}I', data[offset:offset + (4 * n_dims)]) self.dims = struct.unpack(f'<{n_dims}I', data[offset:offset + (4 * n_dims)])
offset += 4 * n_dims offset += 4 * n_dims
self.name = bytes(data[offset:offset + name_len]) self.name = bytes(data[offset:offset + name_len])

View File

@ -199,10 +199,13 @@ class LoraTorchTensor:
kwargs = {} kwargs = {}
if func is torch.permute: if func is torch.permute:
assert len(args)
return type(args[0]).permute(*args, **kwargs) return type(args[0]).permute(*args, **kwargs)
elif func is torch.reshape: elif func is torch.reshape:
assert len(args)
return type(args[0]).reshape(*args, **kwargs) return type(args[0]).reshape(*args, **kwargs)
elif func is torch.stack: elif func is torch.stack:
assert len(args)
assert isinstance(args[0], Sequence) assert isinstance(args[0], Sequence)
dim = kwargs.get("dim", 0) dim = kwargs.get("dim", 0)
assert dim == 0 assert dim == 0
@ -211,6 +214,7 @@ class LoraTorchTensor:
torch.stack([b._lora_B for b in args[0]], dim), torch.stack([b._lora_B for b in args[0]], dim),
) )
elif func is torch.cat: elif func is torch.cat:
assert len(args)
assert isinstance(args[0], Sequence) assert isinstance(args[0], Sequence)
dim = kwargs.get("dim", 0) dim = kwargs.get("dim", 0)
assert dim == 0 assert dim == 0
@ -362,7 +366,7 @@ if __name__ == '__main__':
logger.error(f"Model {hparams['architectures'][0]} is not supported") logger.error(f"Model {hparams['architectures'][0]} is not supported")
sys.exit(1) sys.exit(1)
class LoraModel(model_class): class LoraModel(model_class): # ty: ignore[unsupported-base]
model_arch = model_class.model_arch model_arch = model_class.model_arch
lora_alpha: float lora_alpha: float

View File

@ -14,7 +14,7 @@ The unified auto-parser uses a pure differential, compositional approach (inspir
**Analysis + Parser Building in Two Steps**: **Analysis + Parser Building in Two Steps**:
1. `autoparser::autoparser tmpl_analysis(tmpl)` — runs all differential comparisons and populates the analysis structs 1. `autoparser::autoparser tmpl_analysis(tmpl)` — runs all differential comparisons and populates the analysis structs
2. `autoparser::peg_generator::generate_parser(tmpl, params, tmpl_analysis)` — uses the analysis to build a PEG parser and optional GBNF grammar 2. `autoparser::peg_generator::generate_parser(tmpl, generation_params, tmpl_analysis)` — uses the analysis to build a PEG parser and optional GBNF grammar
## Data Structures ## Data Structures
@ -34,7 +34,7 @@ All structs are defined in [common/chat-auto-parser.h](common/chat-auto-parser.h
### `analyze_tools` and its sub-structs ### `analyze_tools` and its sub-structs
- [common/chat-auto-parser.h:176-194](common/chat-auto-parser.h#L176-L194) — `tool_format_analysis`: `mode` enum, `section_start/end`, `per_call_start/end`, JSON field names (`function_field`, `name_field`, `args_field`, `id_field`, `gen_id_field`), and format flags (`fun_name_is_key`, `tools_array_wrapped`, `uses_python_dicts`) - [common/chat-auto-parser.h:176-194](common/chat-auto-parser.h#L176-L194) — `tool_format_analysis`: `mode` enum, `section_start/end`, `per_call_start/end`, JSON field names (`function_field`, `name_field`, `args_field`, `id_field`, `gen_id_field`), and format flags (`fun_name_is_key`, `tools_array_wrapped`)
- [common/chat-auto-parser.h:196-200](common/chat-auto-parser.h#L196-L200) — `tool_function_analysis`: `name_prefix`, `name_suffix`, `close` markers around function names - [common/chat-auto-parser.h:196-200](common/chat-auto-parser.h#L196-L200) — `tool_function_analysis`: `name_prefix`, `name_suffix`, `close` markers around function names
- [common/chat-auto-parser.h:202-210](common/chat-auto-parser.h#L202-L210) — `tool_arguments_analysis`: `start/end` container markers, `name_prefix/suffix`, `value_prefix/suffix`, `separator` - [common/chat-auto-parser.h:202-210](common/chat-auto-parser.h#L202-L210) — `tool_arguments_analysis`: `start/end` container markers, `name_prefix/suffix`, `value_prefix/suffix`, `separator`
- [common/chat-auto-parser.h:212-217](common/chat-auto-parser.h#L212-L217) — `tool_id_analysis`: `pos` enum, `prefix`/`suffix` markers around call ID values - [common/chat-auto-parser.h:212-217](common/chat-auto-parser.h#L212-L217) — `tool_id_analysis`: `pos` enum, `prefix`/`suffix` markers around call ID values
@ -47,12 +47,21 @@ All structs are defined in [common/chat-auto-parser.h](common/chat-auto-parser.h
| Value | Description | | Value | Description |
|-----------------|-----------------------------------------------------------------------------------| |-----------------|-----------------------------------------------------------------------------------|
| `NONE` | No reasoning markers detected | | `NONE` | No reasoning markers detected |
| `TAG_BASED` | Standard tag-based: `<think>...</think>` | | `TAG_BASED` | Tag-based: `<think>...</think>` (start can be empty for delimiter-style formats) |
| `DELIMITER` | Delimiter-based: reasoning ends at a delimiter (e.g., `[BEGIN FINAL RESPONSE]`) |
| `FORCED_OPEN` | Template ends with open reasoning tag when `enable_thinking=true` |
| `FORCED_CLOSED` | `enable_thinking=false` emits both tags; `enable_thinking=true` emits only start |
| `TOOLS_ONLY` | Reasoning only appears in tool call responses, not plain content | | `TOOLS_ONLY` | Reasoning only appears in tool call responses, not plain content |
**Generation Prompt & Reasoning Prefill**: Computed in `common_chat_templates_apply_jinja` before invoking either the specialized handlers or the auto-parser, by rendering the template twice — once with `add_generation_prompt=false` and once with `add_generation_prompt=true` — and storing the diff suffix as `generation_params::generation_prompt`. This string is propagated into `common_chat_params::generation_prompt` and `common_chat_parser_params::generation_prompt`.
The generation prompt is prepended to model output before PEG parsing via `wrap_for_generation_prompt()`. The portion *before* the reasoning start marker (if any) is prepended as a literal to ensure any boilerplate added by the template is consumed. The full string is also fed to the grammar sampler via `llama_sampler_accept` (stored in `common_params_sampling::grammar_prefill`), advancing the grammar past tokens already in the prompt. It is used to determine the reasoning budget sampler's initial state — COUNTING if the prefill tokens begin with the reasoning start sequence (but don't also contain the end sequence), IDLE otherwise.
**`grammar_prefill`** (`common_params_sampling`): The generation prompt string tokenized and accepted by the grammar sampler at init time. Only applied when `grammar_external` is false (i.e., the grammar was not set explicitly by the user).
Three outcomes for reasoning-prefill handling (in `generate_parser()`):
1. **Start+end in generation prompt** (e.g. `<think></think>\n`): the parser sees reasoning as opened and immediately closed; whitespace-only reasoning content is discarded.
2. **Only start in generation prompt** (e.g. `<think>\n`): the parser sees reasoning as already open.
3. **Start marker present but not at the end** (e.g. Apriel's `<|begin_assistant|>` followed by boilerplate): the marker is a template artifact; the start literal is cleared so reasoning uses delimiter-style (end-only). For templates that ignore `add_generation_prompt` (empty diff), the rendered `data.prompt` is used as fallback — but only for non-TOOLS_ONLY modes, since in TOOLS_ONLY the start tag is model-generated and may appear in prior conversation turns.
**`content_mode`**: How the template wraps assistant content. **`content_mode`**: How the template wraps assistant content.
| Value | Description | | Value | Description |
@ -261,16 +270,16 @@ Text is segmentized into markers and non-marker fragments using `segmentize_mark
- Searches `diff.right` (output with reasoning) for the reasoning content needle - Searches `diff.right` (output with reasoning) for the reasoning content needle
- Uses PEG parsers to find surrounding markers: - Uses PEG parsers to find surrounding markers:
- If both pre/post markers found in `diff.right``TAG_BASED` (both tags visible in diff = no forced close) - If both pre/post markers found in `diff.right``TAG_BASED`
- If both found but post marker only in the full output B → `FORCED_CLOSED` - If both found but post marker only in the full output B → `TAG_BASED` (template forces markers; handled via prefill)
- If only post marker found → `DELIMITER` - If only post marker found → `TAG_BASED` (delimiter-style, empty start)
- Sets `reasoning.start` and `reasoning.end` - Sets `reasoning.start` and `reasoning.end`
**R2 — `compare_thinking_enabled()`**: Compares `enable_thinking=false` vs `true` with a generation prompt. **R2 — `compare_thinking_enabled()`**: Compares `enable_thinking=false` vs `true` with a generation prompt.
- Detects `FORCED_OPEN`: `enable_thinking=true` adds a non-empty marker at the end of the prompt (where model will start generating) — sets `reasoning.start`, mode = `FORCED_OPEN` - Detects template-added reasoning markers: `enable_thinking=true` appends a non-empty marker → sets `reasoning.start`, mode = `TAG_BASED`
- Detects `FORCED_CLOSED`: `enable_thinking=false` produces both start+end markers; `enable_thinking=true` produces only start marker - Handles the reverse case (`enable_thinking=false` appends the marker instead): extracts both start (from the preceding segment) and end markers; mode = `TAG_BASED`
- Handles the reverse case: if both start and end are still empty, looks for a single-segment diff on each side to extract both markers - The reasoning prefill (markers added by the template) is later extracted in `common_chat_templates_apply_jinja` and prepended to model output before parsing
**R3 — `compare_reasoning_scope()`**: Compares assistant message with reasoning+text-content vs reasoning+tool-calls. **R3 — `compare_reasoning_scope()`**: Compares assistant message with reasoning+text-content vs reasoning+tool-calls.
@ -343,7 +352,7 @@ Classification logic:
A workaround array in `common/chat-diff-analyzer.cpp` applies post-hoc patches after analysis. Each workaround is a lambda that inspects the template source and overrides analysis results. Current workarounds: A workaround array in `common/chat-diff-analyzer.cpp` applies post-hoc patches after analysis. Each workaround is a lambda that inspects the template source and overrides analysis results. Current workarounds:
1. **Old Qwen/DeepSeek thinking templates** — source contains `content.split('</think>')`: sets `reasoning.mode = FORCED_OPEN` with `<think>`/`</think>` markers if no reasoning was detected 1. **Old Qwen/DeepSeek thinking templates** — source contains `content.split('</think>')` but not `<SPECIAL_12>`: sets `reasoning.mode = TAG_BASED` with `<think>`/`</think>` markers if no reasoning was detected
2. **Granite 3.3** — source contains specific "Write your thoughts" text: forces `TAG_BASED` reasoning with `<think>`/`</think>` and `WRAPPED_WITH_REASONING` content with `<response>`/`</response>` 2. **Granite 3.3** — source contains specific "Write your thoughts" text: forces `TAG_BASED` reasoning with `<think>`/`</think>` and `WRAPPED_WITH_REASONING` content with `<response>`/`</response>`
3. **Cohere Command R+** — source contains `<|CHATBOT_TOKEN|>`: sets `ALWAYS_WRAPPED` content mode if no content start is already set 3. **Cohere Command R+** — source contains `<|CHATBOT_TOKEN|>`: sets `ALWAYS_WRAPPED` content mode if no content start is already set
4. **Functionary 3.1** — source contains `set has_code_interpreter`: forces `PLAIN` content, specific `per_call_start/end`, clears preserved tokens to only keep Functionary-specific markers 4. **Functionary 3.1** — source contains `set has_code_interpreter`: forces `PLAIN` content, specific `per_call_start/end`, clears preserved tokens to only keep Functionary-specific markers
@ -355,12 +364,13 @@ Each analyzer struct (`analyze_reasoning`, `analyze_content`, `analyze_tools`) i
#### Reasoning Parser (`analyze_reasoning::build_parser`) #### Reasoning Parser (`analyze_reasoning::build_parser`)
| Mode | Parser | | Mode | Parser |
|-----------------------------------|---------------------------------------------------------------------| |-----------------------------------------------|---------------------------------------------------------------------------|
| Not extracting reasoning | `eps()` | | Not extracting reasoning | `eps()` |
| `FORCED_OPEN` or `FORCED_CLOSED` | `reasoning(until(end)) + end` — opening tag was in the prompt | | `TAG_BASED` or `TOOLS_ONLY` (non-empty start) | `optional(start + reasoning(until(end)) + end + space())` |
| `TAG_BASED` or `TOOLS_ONLY` | `optional(start + reasoning(until(end)) + end)` | | `TAG_BASED` or `TOOLS_ONLY` (empty start) | `optional(reasoning(until(end)) + end + space())` — delimiter-style |
| `DELIMITER` | `optional(reasoning(until(end)) + end)` — no start marker |
Note: The start marker may be empty either because the analyzer detected delimiter-style reasoning, or because `generate_parser()` cleared a template artifact start marker (see Generation Prompt & Reasoning Prefill above). Whitespace-only reasoning content (e.g. from a `<think></think>` prefill) is discarded by the mapper.
#### Content Parser (`analyze_content::build_parser`) #### Content Parser (`analyze_content::build_parser`)
@ -410,9 +420,7 @@ All three tool parsers return:
reasoning + optional(content(until(trigger_marker))) + tool_calls + end() reasoning + optional(content(until(trigger_marker))) + tool_calls + end()
``` ```
### Python Dict Format Each returned parser is wrapped by `wrap_for_generation_prompt()`, which prepends a literal for any boilerplate prefix of the generation prompt (the portion before the reasoning start marker).
When `format.uses_python_dicts` is true (detected when single-quoted strings appear in JSON argument context), `build_parser()` pre-registers a `json-string` rule that accepts both single-quoted and double-quoted strings. This is done before any `p.json()` call so all JSON parsing inherits the flexible rule.
## Mapper ## Mapper
@ -421,22 +429,22 @@ When `format.uses_python_dicts` is true (detected when single-quoted strings app
- **Buffered arguments**: Before `tool_name` is known, argument text goes to `args_buffer`; once the name is set, the buffer is flushed to `current_tool->arguments` - **Buffered arguments**: Before `tool_name` is known, argument text goes to `args_buffer`; once the name is set, the buffer is flushed to `current_tool->arguments`
- **`args_target()`**: Returns a reference to whichever destination is currently active (buffer or tool args), eliminating branching - **`args_target()`**: Returns a reference to whichever destination is currently active (buffer or tool args), eliminating branching
- **`closing_quote_pending`**: Tracks whether a closing `"` needs to be appended when a string argument value is finalized (for schema-declared string types in tagged format) - **`closing_quote_pending`**: Tracks whether a closing `"` needs to be appended when a string argument value is finalized (for schema-declared string types in tagged format)
- **Quote normalization**: Python-style quotes (`'key': 'value'`) are converted to JSON (`"key": "value"`) - **Whitespace-only reasoning**: Reasoning content that consists entirely of whitespace (e.g. from a `<think></think>` prefill) is cleared so the message shows no reasoning
- **Brace auto-closing**: At tool close, unclosed `{` braces are closed automatically - **Brace auto-closing**: At tool close, unclosed `{` braces are closed automatically
## Files ## Files
| File | Purpose | | File | Purpose |
|-------------------------------------------|----------------------------------------------------------------------| |-------------------------------------------|---------------------------------------------------------------------------------|
| `common/chat-auto-parser.h` | All analysis structs, enums, `autoparser`, `peg_generator`, `templates_params` | | `common/chat-auto-parser.h` | All analysis structs, enums, `autoparser`, `peg_generator`, `generation_params` |
| `common/chat-auto-parser-generator.cpp` | Parser generator: `generate_parser()` and `build_parser()` methods | | `common/chat-auto-parser-generator.cpp` | Parser generator: `generate_parser()` and `build_parser()` methods |
| `common/chat-diff-analyzer.cpp` | Differential analysis implementation and workarounds | | `common/chat-diff-analyzer.cpp` | Differential analysis implementation and workarounds |
| `common/chat-auto-parser-helpers.h/cpp` | `calculate_diff_split()`, `segmentize_markers()`, | | `common/chat-auto-parser-helpers.h/cpp` | `calculate_diff_split()`, `segmentize_markers()`, `compare_variants()`, |
| | `compare_variants()`, string helpers | | | `wrap_for_generation_prompt()`, string helpers |
| `common/chat-peg-parser.h/cpp` | `common_chat_peg_builder`, `common_chat_peg_mapper`, and helpers | | `common/chat-peg-parser.h/cpp` | `common_chat_peg_builder`, `common_chat_peg_mapper`, and helpers |
| `common/chat.cpp` | Entry point: `common_chat_templates_apply_jinja()` | | `common/chat.cpp` | Entry point: `common_chat_templates_apply_jinja()` |
| `tools/parser/debug-template-parser.cpp` | Debug tool for template analysis | | `tools/parser/debug-template-parser.cpp` | Debug tool for template analysis |
| `tools/parser/template-analysis.cpp` | Template analysis tool | | `tools/parser/template-analysis.cpp` | Template analysis tool |
## Testing & Debugging ## Testing & Debugging
@ -516,10 +524,10 @@ To support a new template format:
## Edge Cases and Quirks ## Edge Cases and Quirks
1. **Forced Thinking**: When `enable_thinking=true` and the model prompt ends with an open reasoning tag (e.g., `<think>`), the parser enters forced thinking mode and immediately expects reasoning content without waiting for a start marker. 1. **Generation Prompt & Reasoning Prefill**: The generation prompt is extracted by diffing `add_generation_prompt=false` vs `true` in `common_chat_templates_apply_jinja`, so it contains exactly what the template appends — avoiding false positives from prior conversation turns.
2. **Per-Call vs Per-Section Markers**: Some templates wrap each tool call individually (`per_call_start/end`); others wrap the entire section (`section_start/end`). T2 (`check_per_call_markers()`) disambiguates by checking if the second call in a two-call output starts with the section marker. 2. **Per-Call vs Per-Section Markers**: Some templates wrap each tool call individually (`per_call_start/end`); others wrap the entire section (`section_start/end`). T2 (`check_per_call_markers()`) disambiguates by checking if the second call in a two-call output starts with the section marker.
3. **Python Dict Format**: The Seed template family uses single-quoted JSON (`'key': 'value'`). The `uses_python_dicts` flag causes the PEG builder to register a flexible `json-string` rule accepting both quote styles before any JSON rules are built. 3. **Tag Boundary Fixing**: `calculate_diff_split()` iteratively adjusts prefix/suffix boundaries to avoid splitting `<tag>` or `[marker]` tokens, ensuring clean extraction.
4. **Tag Boundary Fixing**: `calculate_diff_split()` iteratively adjusts prefix/suffix boundaries to avoid splitting `<tag>` or `[marker]` tokens, ensuring clean extraction. 4. **Call ID Side Effects**: When a call ID is detected, `per_call_end` may have been incorrectly set to include the call ID suffix. T7 clears `per_call_end` in this case.
5. **Call ID Side Effects**: When a call ID is detected, `per_call_end` may have been incorrectly set to include the call ID suffix. T7 clears `per_call_end` in this case. 5. **Tool Analysis Gating**: `analyze_tools` is only constructed (and all tool analysis phases run) when `jinja_caps.supports_tool_calls` is true. Within tool analysis, `check_per_call_markers()` (T2) only runs if `jinja_caps.supports_parallel_tool_calls`.
6. **Tool Analysis Gating**: `analyze_tools` is only constructed (and all tool analysis phases run) when `jinja_caps.supports_tool_calls` is true. Within tool analysis, `check_per_call_markers()` (T2) only runs if `jinja_caps.supports_parallel_tool_calls`. 6. **`analyze_arguments()` Gating**: Within tool analysis, A1 and A2 (argument name/value marker extraction) only run for `TAG_WITH_TAGGED` format. `extract_argument_separator()` and `extract_args_markers()` run for all non-`JSON_NATIVE` formats.
7. **`analyze_arguments()` Gating**: Within tool analysis, A1 and A2 (argument name/value marker extraction) only run for `TAG_WITH_TAGGED` format. `extract_argument_separator()` and `extract_args_markers()` run for all non-`JSON_NATIVE` formats. 7. **Undetected Tool Format**: If `analyze_tools` concludes tool calling is supported but cannot determine the format, `build_parser()` logs an error and returns `eps()` (graceful degradation) rather than aborting.

View File

@ -28,6 +28,9 @@ Additionally, there the following images, similar to the above:
- `ghcr.io/ggml-org/llama.cpp:full-vulkan`: Same as `full` but compiled with Vulkan support. (platforms: `linux/amd64`) - `ghcr.io/ggml-org/llama.cpp:full-vulkan`: Same as `full` but compiled with Vulkan support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:light-vulkan`: Same as `light` but compiled with Vulkan support. (platforms: `linux/amd64`) - `ghcr.io/ggml-org/llama.cpp:light-vulkan`: Same as `light` but compiled with Vulkan support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:server-vulkan`: Same as `server` but compiled with Vulkan support. (platforms: `linux/amd64`) - `ghcr.io/ggml-org/llama.cpp:server-vulkan`: Same as `server` but compiled with Vulkan support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:full-openvino`: Same as `full` but compiled with OpenVino support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:light-openvino`: Same as `light` but compiled with OpenVino support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:server-openvino`: Same as `server` but compiled with OpenVino support. (platforms: `linux/amd64`)
The GPU enabled images are not currently tested by CI beyond being built. They are not built with any variation from the ones in the Dockerfiles defined in [.devops/](../.devops/) and the GitHub Action defined in [.github/workflows/docker.yml](../.github/workflows/docker.yml). If you need different settings (for example, a different CUDA, ROCm or MUSA library, you'll need to build the images locally for now). The GPU enabled images are not currently tested by CI beyond being built. They are not built with any variation from the ones in the Dockerfiles defined in [.devops/](../.devops/) and the GitHub Action defined in [.github/workflows/docker.yml](../.github/workflows/docker.yml). If you need different settings (for example, a different CUDA, ROCm or MUSA library, you'll need to build the images locally for now).

View File

@ -12,9 +12,9 @@ Legend:
- 🟡 Partially supported by this backend - 🟡 Partially supported by this backend
- ❌ Not supported by this backend - ❌ Not supported by this backend
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | WebGPU | ZenDNN | zDNN | | Operation | BLAS | CANN | CPU | CUDA | MTL | OpenCL | SYCL | Vulkan | WebGPU | ZenDNN | zDNN |
|-----------|------|------|------|------|------|------|------|------|------|------|------| |-----------|------|------|------|------|------|------|------|------|------|------|------|
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | ABS | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ | | ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | | ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
@ -23,63 +23,63 @@ Legend:
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | | ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ | | CLAMP | ❌ | ✅ | ✅ | ✅ | | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ | | CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ | | CONT | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | | CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ | | COS | ❌ | ✅ | ✅ | ✅ | | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | | CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| DIAG | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | DIAG | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ | | DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | | DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ | | DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | ELU | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | EXP | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | | EXPM1 | ❌ | ❌ | ✅ | 🟡 | | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | | FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ | | FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | | GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | | GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | | GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | | GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | | GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | | GELU | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | | GELU_ERF | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | | GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ | | GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | | GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | HARDSWISH | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | | IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | ❌ | ❌ | | L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | ❌ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | | LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | | LOG | ❌ | ✅ | ✅ | ✅ | | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | | MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | | MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | | MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ | | MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | NEG | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ | | NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| OUT_PROD | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | 🟡 | | OUT_PROD | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | 🟡 |
| PAD | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | | PAD | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | | PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| POOL_1D | ❌ | ❌ | ✅ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | POOL_1D | ❌ | ❌ | ✅ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | | REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | | RELU | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | | REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | | RMS_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
@ -91,31 +91,31 @@ Legend:
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | | SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ | | SET | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
| SET_ROWS | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | | SET_ROWS | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | SGN | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | | SIGMOID | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | | SILU | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ | | SIN | ❌ | ✅ | ✅ | ✅ | | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | | SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ | | SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ | | SQR | ❌ | ✅ | ✅ | ✅ | | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ | | SQRT | ❌ | ✅ | ✅ | ✅ | | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ❌ | ❌ | | SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ❌ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | | SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | STEP | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | | SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SUM | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | 🟡 | ❌ | ❌ | | SUM | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | | SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | | SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | | SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | | TANH | ❌ | ✅ | ✅ | 🟡 | | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | | TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ | | TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | ❌ | ❌ | | TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ | | TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ | | UPSCALE | ❌ | 🟡 | ✅ | ✅ | | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -28,9 +28,6 @@ def _build_repetition(item_rule, min_items, max_items, separator_rule=None):
return f'({result})?' if min_items == 0 else result return f'({result})?' if min_items == 0 else result
def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True): def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True):
has_min = min_value != None
has_max = max_value != None
def digit_range(from_char: str, to_char: str): def digit_range(from_char: str, to_char: str):
out.append("[") out.append("[")
if from_char == to_char: if from_char == to_char:
@ -106,7 +103,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
out.append(to_str[i]) out.append(to_str[i])
out.append("]") out.append("]")
if has_min and has_max: if min_value is not None and max_value is not None:
if min_value < 0 and max_value < 0: if min_value < 0 and max_value < 0:
out.append("\"-\" (") out.append("\"-\" (")
_generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True) _generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True)
@ -133,7 +130,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
less_decimals = max(decimals_left - 1, 1) less_decimals = max(decimals_left - 1, 1)
if has_min: if min_value is not None:
if min_value < 0: if min_value < 0:
out.append("\"-\" (") out.append("\"-\" (")
_generate_min_max_int(None, -min_value, out, decimals_left, top_level=False) _generate_min_max_int(None, -min_value, out, decimals_left, top_level=False)
@ -177,7 +174,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
more_digits(length - 1, less_decimals) more_digits(length - 1, less_decimals)
return return
if has_max: if max_value is not None:
if max_value >= 0: if max_value >= 0:
if top_level: if top_level:
out.append("\"-\" [1-9] ") out.append("\"-\" [1-9] ")

View File

@ -64,7 +64,7 @@ def load_model_and_tokenizer(model_path, use_sentence_transformers=False, device
print("Using SentenceTransformer to apply all numbered layers") print("Using SentenceTransformer to apply all numbered layers")
model = SentenceTransformer(model_path) model = SentenceTransformer(model_path)
tokenizer = model.tokenizer tokenizer = model.tokenizer
config = model[0].auto_model.config # type: ignore config = model[0].auto_model.config
else: else:
tokenizer = AutoTokenizer.from_pretrained(model_path) tokenizer = AutoTokenizer.from_pretrained(model_path)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
@ -108,8 +108,8 @@ def load_model_and_tokenizer(model_path, use_sentence_transformers=False, device
print(f"Model file: {type(model).__module__}") print(f"Model file: {type(model).__module__}")
# Verify the model is using the correct sliding window # Verify the model is using the correct sliding window
if hasattr(model.config, 'sliding_window'): # type: ignore if hasattr(model.config, 'sliding_window'):
print(f"Model's sliding_window: {model.config.sliding_window}") # type: ignore print(f"Model's sliding_window: {model.config.sliding_window}")
else: else:
print("Model config does not have sliding_window attribute") print("Model config does not have sliding_window attribute")
@ -152,7 +152,7 @@ def main():
device = next(model.parameters()).device device = next(model.parameters()).device
else: else:
# For SentenceTransformer, get device from the underlying model # For SentenceTransformer, get device from the underlying model
device = next(model[0].auto_model.parameters()).device # type: ignore device = next(model[0].auto_model.parameters()).device
model_name = os.path.basename(model_path) model_name = os.path.basename(model_path)
@ -177,7 +177,7 @@ def main():
print(f"{token_id:6d} -> '{token_str}'") print(f"{token_id:6d} -> '{token_str}'")
print(f"Embeddings shape (after all SentenceTransformer layers): {all_embeddings.shape}") print(f"Embeddings shape (after all SentenceTransformer layers): {all_embeddings.shape}")
print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}") # type: ignore print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}")
else: else:
# Standard approach: use base model output only # Standard approach: use base model output only
encoded = tokenizer( encoded = tokenizer(
@ -205,12 +205,12 @@ def main():
print(f"Embedding dimension: {all_embeddings.shape[1]}") print(f"Embedding dimension: {all_embeddings.shape[1]}")
if len(all_embeddings.shape) == 1: if len(all_embeddings.shape) == 1:
n_embd = all_embeddings.shape[0] # type: ignore n_embd = all_embeddings.shape[0]
n_embd_count = 1 n_embd_count = 1
all_embeddings = all_embeddings.reshape(1, -1) all_embeddings = all_embeddings.reshape(1, -1)
else: else:
n_embd = all_embeddings.shape[1] # type: ignore n_embd = all_embeddings.shape[1]
n_embd_count = all_embeddings.shape[0] # type: ignore n_embd_count = all_embeddings.shape[0]
print() print()

View File

@ -2,7 +2,7 @@
import argparse import argparse
import sys import sys
from common import compare_tokens # type: ignore from common import compare_tokens # type: ignore[import-not-found]
def parse_arguments(): def parse_arguments():

View File

@ -6,7 +6,7 @@ import re
from copy import copy from copy import copy
from enum import Enum from enum import Enum
from inspect import getdoc, isclass from inspect import getdoc, isclass
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints from typing import TYPE_CHECKING, Any, Callable, Optional, Union, get_args, get_origin, get_type_hints
from docstring_parser import parse from docstring_parser import parse
from pydantic import BaseModel, create_model from pydantic import BaseModel, create_model
@ -1158,7 +1158,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
# Assert that the parameter has a type annotation # Assert that the parameter has a type annotation
if param.annotation == inspect.Parameter.empty: if param.annotation == inspect.Parameter.empty:
raise TypeError(f"Parameter '{param.name}' in function '{func.__name__}' lacks a type annotation") raise TypeError(f"""Parameter '{param.name}' in function '{getattr(func, "__name__", "")}' lacks a type annotation""")
# Find the parameter's description in the docstring # Find the parameter's description in the docstring
param_doc = next((d for d in docstring.params if d.arg_name == param.name), None) param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
@ -1166,7 +1166,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
# Assert that the parameter has a description # Assert that the parameter has a description
if not param_doc or not param_doc.description: if not param_doc or not param_doc.description:
raise ValueError( raise ValueError(
f"Parameter '{param.name}' in function '{func.__name__}' lacks a description in the docstring") f"""Parameter '{param.name}' in function '{getattr(func, "__name__", "")}' lacks a description in the docstring""")
# Add parameter details to the schema # Add parameter details to the schema
param_docs.append((param.name, param_doc)) param_docs.append((param.name, param_doc))
@ -1177,7 +1177,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
dynamic_fields[param.name] = ( dynamic_fields[param.name] = (
param.annotation if param.annotation != inspect.Parameter.empty else str, default_value) param.annotation if param.annotation != inspect.Parameter.empty else str, default_value)
# Creating the dynamic model # Creating the dynamic model
dynamic_model = create_model(f"{func.__name__}", **dynamic_fields) dynamic_model = create_model(f"{getattr(func, '__name__')}", **dynamic_fields)
for name, param_doc in param_docs: for name, param_doc in param_docs:
dynamic_model.model_fields[name].description = param_doc.description dynamic_model.model_fields[name].description = param_doc.description
@ -1285,7 +1285,7 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name:
if items != {}: if items != {}:
array = {"properties": items} array = {"properties": items}
array_type = convert_dictionary_to_pydantic_model(array, f"{model_name}_{field_name}_items") array_type = convert_dictionary_to_pydantic_model(array, f"{model_name}_{field_name}_items")
fields[field_name] = (List[array_type], ...) fields[field_name] = (list[array_type], ...) # ty: ignore[invalid-type-form]
else: else:
fields[field_name] = (list, ...) fields[field_name] = (list, ...)
elif field_type == "object": elif field_type == "object":

View File

@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
### GGML Version ### GGML Version
set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 9) set(GGML_VERSION_MINOR 9)
set(GGML_VERSION_PATCH 7) set(GGML_VERSION_PATCH 8)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)

View File

@ -733,6 +733,10 @@ extern "C" {
GGML_API size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block GGML_API size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block
GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
GGML_DEPRECATED(
GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
"use ggml_row_size() instead");
GGML_API const char * ggml_type_name(enum ggml_type type); GGML_API const char * ggml_type_name(enum ggml_type type);
GGML_API const char * ggml_op_name (enum ggml_op op); GGML_API const char * ggml_op_name (enum ggml_op op);
GGML_API const char * ggml_op_symbol(enum ggml_op op); GGML_API const char * ggml_op_symbol(enum ggml_op op);

View File

@ -121,6 +121,8 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
bli_thread_set_num_threads(ctx->n_threads); bli_thread_set_num_threads(ctx->n_threads);
#elif defined(GGML_BLAS_USE_NVPL) #elif defined(GGML_BLAS_USE_NVPL)
nvpl_blas_set_num_threads(ctx->n_threads); nvpl_blas_set_num_threads(ctx->n_threads);
#elif defined(GGML_BLAS_USE_MKL)
mkl_set_num_threads(ctx->n_threads);
#endif #endif
for (int64_t i13 = 0; i13 < ne13; i13++) { for (int64_t i13 = 0; i13 < ne13; i13++) {

View File

@ -1544,8 +1544,8 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx,
end = 2 * ((n_head - 1) - n_head_log2) + 1; end = 2 * ((n_head - 1) - n_head_log2) + 1;
step = 2; step = 2;
count = n_head - n_head_log2; count = n_head - n_head_log2;
aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), m1, count, start, end + 1, step, aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * ggml_type_size(dtype), m1, count, start, end + 1,
dtype); step, dtype);
} }
} }
@ -1788,9 +1788,11 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0]; // src ggml_tensor * src0 = dst->src[0]; // src
ggml_tensor * src1 = dst->src[1]; // index ggml_tensor * src1 = dst->src[1]; // index
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16
|| dst->type == GGML_TYPE_BF16);
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_BF16:
case GGML_TYPE_F16: case GGML_TYPE_F16:
case GGML_TYPE_F32: case GGML_TYPE_F32:
if (src0->type == dst->type) { if (src0->type == dst->type) {
@ -1881,6 +1883,7 @@ void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
break; break;
} }
case GGML_TYPE_F16: case GGML_TYPE_F16:
case GGML_TYPE_BF16:
{ {
acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0); acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t)); ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t));
@ -1891,7 +1894,7 @@ void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
} }
acl_tensor_ptr src_trans_tensor = ggml_cann_create_tensor( acl_tensor_ptr src_trans_tensor = ggml_cann_create_tensor(
src_trans_buffer, ACL_FLOAT16, ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type)); aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type));
aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1, aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1,
dst->type); dst->type);
@ -1965,7 +1968,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor *
// Only check env once. // Only check env once.
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
if (weight_to_nz && is_matmul_weight(weight)) { if (weight_to_nz && weight->type != GGML_TYPE_BF16 && is_matmul_weight(weight)) {
acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ); acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
} else { } else {
acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND); acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
@ -2146,6 +2149,9 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
switch (type) { switch (type) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
case GGML_TYPE_F16: case GGML_TYPE_F16:
#ifndef ASCEND_310P
case GGML_TYPE_BF16:
#endif
ggml_cann_mat_mul_fp(ctx, dst); ggml_cann_mat_mul_fp(ctx, dst);
break; break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
@ -2943,6 +2949,27 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
// Rotate full tensor (no tail), using trans tensors // Rotate full tensor (no tail), using trans tensors
GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(), acl_cos_reshape_tensor.get(), GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(), acl_cos_reshape_tensor.get(),
acl_sin_reshape_tensor.get(), acl_mode, acl_dst_trans_tensor.get()); acl_sin_reshape_tensor.get(), acl_mode, acl_dst_trans_tensor.get());
} else if (src0->data == dst->data && !ggml_is_contiguous(src0)) {
// In-place on non-contiguous tensor: RotaryPositionEmbedding cannot safely
// read and write the same non-contiguous buffer. Use contiguous temporaries.
size_t contiguous_nb[GGML_MAX_DIMS];
contiguous_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
contiguous_nb[i] = contiguous_nb[i - 1] * src0->ne[i - 1];
}
int64_t total_elements = ggml_nelements(src0);
ggml_cann_pool_alloc inplace_src_alloc(ctx.pool(), total_elements * sizeof(float));
ggml_cann_pool_alloc inplace_dst_alloc(ctx.pool(), total_elements * sizeof(float));
acl_tensor_ptr acl_src_contig = ggml_cann_create_tensor(inplace_src_alloc.get(), ACL_FLOAT, sizeof(float),
src0->ne, contiguous_nb, GGML_MAX_DIMS);
acl_tensor_ptr acl_dst_contig = ggml_cann_create_tensor(inplace_dst_alloc.get(), ACL_FLOAT, sizeof(float),
dst->ne, contiguous_nb, GGML_MAX_DIMS);
cann_copy(ctx, acl_src.get(), acl_src_contig.get());
GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_contig.get(), acl_cos_reshape_tensor.get(),
acl_sin_reshape_tensor.get(), acl_mode, acl_dst_contig.get());
cann_copy(ctx, acl_dst_contig.get(), acl_dst.get());
} else { } else {
// Rotate full tensor (no tail), using original tensors // Rotate full tensor (no tail), using original tensors
GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(), GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(),
@ -2984,6 +3011,58 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
} }
} }
void ggml_cann_rope_cache_preload(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
int sections[4];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
GGML_TENSOR_UNARY_OP_LOCALS
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int) * 4);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_imrope || mrope_used) {
is_neox = true;
}
int64_t rope_dims = n_dims;
if (is_vision) {
rope_dims = src0->ne[0];
}
// Run the full cache init on the non-captured stream. This performs all
// host-to-device memcpy, aclrtMalloc/Free, and on-device computations
// so that the memory pool is warmed up and cache metadata is populated.
aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections,
mrope_used, is_imrope, is_vision, rope_dims);
// Reset `cached` so that during graph capture the on-device computations
// (sin/cos, position multiply, repeat, etc.) still execute and get recorded
// into the captured graph. The cache metadata (theta_scale_length,
// theta_scale, sections, position_length, etc.) remains set, which causes
// all host-to-device copy and malloc/free branches to be skipped.
ctx.rope_cache.cached = false;
}
void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0]; ggml_tensor * src0 = dst->src[0];
@ -3599,6 +3678,44 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
acl_k_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne, src1_bsnd_nb, GGML_MAX_DIMS); acl_k_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne, src1_bsnd_nb, GGML_MAX_DIMS);
acl_v_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne, src2_bsnd_nb, GGML_MAX_DIMS); acl_v_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne, src2_bsnd_nb, GGML_MAX_DIMS);
// Step 2.5: Pad Q, K, V along head dimension if D is not a multiple of 16
// (required by FusedInferAttentionScoreV2)
const int64_t D = src0->ne[0];
const int64_t D_padded = GGML_PAD(D, 16);
const bool needs_padding = (D != D_padded);
ggml_cann_pool_alloc q_pad_allocator(ctx.pool());
ggml_cann_pool_alloc k_pad_allocator(ctx.pool());
ggml_cann_pool_alloc v_pad_allocator(ctx.pool());
if (needs_padding) {
int64_t paddings[] = { 0, D_padded - D, 0, 0, 0, 0, 0, 0 };
auto pad_fa_tensor = [&](acl_tensor_ptr & tensor, const int64_t * bsnd_ne,
ggml_cann_pool_alloc & allocator) {
int64_t pad_ne[GGML_MAX_DIMS] = { D_padded, bsnd_ne[1], bsnd_ne[2], bsnd_ne[3] };
size_t pad_nb[GGML_MAX_DIMS];
pad_nb[0] = faElemSize;
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
pad_nb[i] = pad_nb[i - 1] * pad_ne[i - 1];
}
int64_t nelements = pad_ne[0] * pad_ne[1] * pad_ne[2] * pad_ne[3];
void * buffer = allocator.alloc(nelements * faElemSize);
acl_tensor_ptr padded =
ggml_cann_create_tensor(buffer, faDataType, faElemSize, pad_ne, pad_nb, GGML_MAX_DIMS);
aclnn_pad(ctx, tensor.get(), padded.get(), paddings);
tensor = std::move(padded);
};
pad_fa_tensor(acl_q_tensor, src0_bsnd_ne, q_pad_allocator);
pad_fa_tensor(acl_k_tensor, src1_bsnd_ne, k_pad_allocator);
pad_fa_tensor(acl_v_tensor, src2_bsnd_ne, v_pad_allocator);
src0_bsnd_ne[0] = D_padded;
src1_bsnd_ne[0] = D_padded;
src2_bsnd_ne[0] = D_padded;
}
// Step 3: create the PSEShift tensor if needed // Step 3: create the PSEShift tensor if needed
// this tensor is considered as mask (f16) in the llama.cpp // this tensor is considered as mask (f16) in the llama.cpp
acl_tensor_ptr bcast_pse_tensor; acl_tensor_ptr bcast_pse_tensor;
@ -3688,17 +3805,16 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
acl_tensor_ptr fa_dst_tensor; acl_tensor_ptr fa_dst_tensor;
acl_tensor_ptr acl_dst_tensor;
ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
if (dst->type == GGML_TYPE_F32) { if (dst->type == GGML_TYPE_F32 || needs_padding) {
void * out_f16_buffer = out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
int64_t * out_f16_ne = src0_bsnd_ne; int64_t * out_f16_ne = src0_bsnd_ne;
size_t out_f16_nb[GGML_MAX_DIMS]; size_t out_f16_nb[GGML_MAX_DIMS];
out_f16_nb[0] = faElemSize; out_f16_nb[0] = faElemSize;
for (int i = 1; i < GGML_MAX_DIMS; ++i) { for (int i = 1; i < GGML_MAX_DIMS; ++i) {
out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
} }
int64_t out_nelements = out_f16_ne[0] * out_f16_ne[1] * out_f16_ne[2] * out_f16_ne[3];
void * out_f16_buffer = out_f16_allocator.alloc(out_nelements * faElemSize);
fa_dst_tensor = fa_dst_tensor =
ggml_cann_create_tensor(out_f16_buffer, faDataType, faElemSize, out_f16_ne, out_f16_nb, GGML_MAX_DIMS); ggml_cann_create_tensor(out_f16_buffer, faDataType, faElemSize, out_f16_ne, out_f16_nb, GGML_MAX_DIMS);
@ -3730,8 +3846,33 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
nullptr // softmaxLse nullptr // softmaxLse
); );
if (dst->type == GGML_TYPE_F32) { // Step 6: post-processing — slice padded output and/or cast to f32
// Step 6: post-processing, permute and cast to f32 if (needs_padding) {
ggml_cann_pool_alloc sliced_f16_allocator(ctx.pool());
if (dst->type == GGML_TYPE_F32) {
int64_t sliced_ne[GGML_MAX_DIMS] = { D, src0_bsnd_ne[1], src0_bsnd_ne[2], src0_bsnd_ne[3] };
size_t sliced_nb[GGML_MAX_DIMS];
sliced_nb[0] = faElemSize;
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
sliced_nb[i] = sliced_nb[i - 1] * sliced_ne[i - 1];
}
int64_t sliced_nelements = sliced_ne[0] * sliced_ne[1] * sliced_ne[2] * sliced_ne[3];
void * sliced_buffer = sliced_f16_allocator.alloc(sliced_nelements * faElemSize);
acl_tensor_ptr sliced_f16_tensor = ggml_cann_create_tensor(sliced_buffer, faDataType, faElemSize,
sliced_ne, sliced_nb, GGML_MAX_DIMS);
GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(),
(int64_t) -1, (int64_t) 0, D, (int64_t) 1, sliced_f16_tensor.get());
acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
aclnn_cast(ctx, sliced_f16_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type));
} else {
acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(),
(int64_t) -1, (int64_t) 0, D, (int64_t) 1, acl_dst_tensor.get());
}
} else if (dst->type == GGML_TYPE_F32) {
acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst); acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
aclnn_cast(ctx, fa_dst_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type)); aclnn_cast(ctx, fa_dst_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type));
} }

View File

@ -543,6 +543,21 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst);
*/ */
void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst); void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Pre-load the RoPE cache before ACL graph capture.
*
* This function must be called outside of graph capture to perform
* host-to-device memory copies and device memory allocations that are
* not allowed on a captured stream. After pre-loading, the rope cache
* metadata is updated so that the subsequent call to
* aclnn_rope_cache_init (inside graph capture) skips these operations
* and only records the on-device computations into the captured graph.
*
* @param ctx CANN backend context.
* @param dst A ROPE destination tensor from the computation graph.
*/
void ggml_cann_rope_cache_preload(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/** /**
* @brief Computes the index of the maximum value along the specified dimension * @brief Computes the index of the maximum value along the specified dimension
* of a ggml tensor using the CANN backend. * of a ggml tensor using the CANN backend.

View File

@ -277,7 +277,7 @@ struct ggml_graph_node_properties {
} }
} }
if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) { if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU || node->op == GGML_OP_ROPE){
return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0; return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
} }
return true; return true;

View File

@ -1234,7 +1234,8 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer,
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
if (!need_transform(tensor->type)) { if (!need_transform(tensor->type)) {
ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) { if (weight_to_nz && tensor->type != GGML_TYPE_BF16
&& is_matmul_weight((const ggml_tensor *) tensor)) {
GGML_ASSERT(tensor->ne[2] == 1); GGML_ASSERT(tensor->ne[2] == 1);
GGML_ASSERT(tensor->ne[3] == 1); GGML_ASSERT(tensor->ne[3] == 1);
weight_format_to_nz(tensor, offset, ctx->device); weight_format_to_nz(tensor, offset, ctx->device);
@ -1443,7 +1444,8 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_t
if (ne0 % MATRIX_ROW_PADDING != 0) { if (ne0 % MATRIX_ROW_PADDING != 0) {
size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
} }
} else if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) { } else if (weight_to_nz && tensor->type != GGML_TYPE_BF16
&& is_matmul_weight((const ggml_tensor *) tensor)) {
// NZ format weight are not support quantized yet. // NZ format weight are not support quantized yet.
// If ND tensor transform to NZ, size may changed. // If ND tensor transform to NZ, size may changed.
int64_t shape[] = { tensor->ne[1], tensor->ne[0] }; int64_t shape[] = { tensor->ne[1], tensor->ne[0] };
@ -2223,6 +2225,19 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
// If no matching graph is found, add a new ACL graph. // If no matching graph is found, add a new ACL graph.
ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph); ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph);
cann_ctx->graph_lru_cache.push(new_graph); cann_ctx->graph_lru_cache.push(new_graph);
// Pre-load rope cache before graph capture. During capture the
// stream cannot perform host-to-device memcpy or device memory
// malloc/free. Running the full cache init now populates the
// cache metadata so these branches are skipped during capture,
// while also warming up the memory pool.
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (node->op == GGML_OP_ROPE) {
ggml_cann_rope_cache_preload(*cann_ctx, node);
break;
}
}
} }
} }
#else #else
@ -2283,6 +2298,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
{ {
switch (op->src[0]->type) { switch (op->src[0]->type) {
#ifndef ASCEND_310P
case GGML_TYPE_BF16:
#endif
case GGML_TYPE_F16: case GGML_TYPE_F16:
case GGML_TYPE_F32: case GGML_TYPE_F32:
return true; return true;
@ -2320,6 +2338,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
switch (op->src[0]->type) { switch (op->src[0]->type) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
case GGML_TYPE_F16: case GGML_TYPE_F16:
#ifndef ASCEND_310P
case GGML_TYPE_BF16:
#endif
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
return true; return true;
default: default:
@ -2332,6 +2353,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
switch (op->type) { switch (op->type) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
case GGML_TYPE_F16: case GGML_TYPE_F16:
#ifndef ASCEND_310P
case GGML_TYPE_BF16:
#endif
return true; return true;
default: default:
return false; return false;
@ -2341,20 +2365,30 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
case GGML_OP_CPY: case GGML_OP_CPY:
{ {
ggml_tensor * src = op->src[0]; ggml_tensor * src = op->src[0];
#ifdef ASCEND_310P
if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) || if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
(src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16)) { (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16)) {
// only support F32 and F16. // only support F32 and F16 on 310P.
return false; return false;
} }
#else
if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_BF16) ||
(src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16 && src->type != GGML_TYPE_BF16)) {
// only support F32, F16 and BF16.
return false;
}
#endif
return true; return true;
} }
break; break;
case GGML_OP_CONT: case GGML_OP_CONT:
{ {
// TODO: support GGML_TYPE_BF16
switch (op->src[0]->type) { switch (op->src[0]->type) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
case GGML_TYPE_F16: case GGML_TYPE_F16:
#ifndef ASCEND_310P
case GGML_TYPE_BF16:
#endif
return true; return true;
default: default:
return false; return false;
@ -2503,10 +2537,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
// different head sizes of K and V are not supported yet // different head sizes of K and V are not supported yet
return false; return false;
} }
if (op->src[0]->ne[0] % 16 != 0) {
// TODO: padding to support
return false;
}
float logitSoftcap = 0.0f; float logitSoftcap = 0.0f;
memcpy(&logitSoftcap, (const float *) (op->op_params) + 2, sizeof(float)); memcpy(&logitSoftcap, (const float *) (op->op_params) + 2, sizeof(float));
if (logitSoftcap != 0.0f) { if (logitSoftcap != 0.0f) {

View File

@ -570,24 +570,36 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
set(KLEIDIAI_ARCHIVE_MD5 "54049037570ab0ee0a0d126b2ba5ece1") set(KLEIDIAI_ARCHIVE_MD5 "54049037570ab0ee0a0d126b2ba5ece1")
if (POLICY CMP0135) set(KLEIDIAI_FETCH_ARGS
cmake_policy(SET CMP0135 NEW) URL ${KLEIDIAI_DOWNLOAD_URL}
URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5}
)
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
list(APPEND KLEIDIAI_FETCH_ARGS DOWNLOAD_EXTRACT_TIMESTAMP NEW)
endif() endif()
# TODO: Use FetchContent_MakeAvailable with EXCLUDE_FROM_ALL after bumping minimum CMake version to 3.28+ if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.28")
# Using FetchContent_Populate instead to avoid EXCLUDE_FROM_ALL which requires CMake 3.28 FetchContent_Declare(KleidiAI_Download
FetchContent_Declare(KleidiAI_Download ${KLEIDIAI_FETCH_ARGS}
URL ${KLEIDIAI_DOWNLOAD_URL} EXCLUDE_FROM_ALL
DOWNLOAD_EXTRACT_TIMESTAMP NEW )
URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})
FetchContent_GetProperties(KleidiAI_Download FetchContent_MakeAvailable(KleidiAI_Download)
SOURCE_DIR KLEIDIAI_SRC
POPULATED KLEIDIAI_POPULATED)
if (NOT KLEIDIAI_POPULATED)
FetchContent_Populate(KleidiAI_Download)
FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC) FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC)
else()
FetchContent_Declare(KleidiAI_Download
${KLEIDIAI_FETCH_ARGS}
)
FetchContent_GetProperties(KleidiAI_Download
SOURCE_DIR KLEIDIAI_SRC
POPULATED KLEIDIAI_POPULATED
)
if (NOT KLEIDIAI_POPULATED)
FetchContent_Populate(KleidiAI_Download)
FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC)
endif()
endif() endif()
add_compile_definitions(GGML_USE_CPU_KLEIDIAI) add_compile_definitions(GGML_USE_CPU_KLEIDIAI)

View File

@ -531,7 +531,6 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
UNUSED(bs); UNUSED(bs);
__m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
__m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
// Permute mask used for easier vector processing at later stages // Permute mask used for easier vector processing at later stages
@ -580,6 +579,7 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
if constexpr ( if constexpr (
std::is_same_v<block_tx8, block_q4_0x8> || std::is_same_v<block_tx8, block_q4_0x8> ||
std::is_same_v<block_tx8, block_iq4_nlx8>) { std::is_same_v<block_tx8, block_iq4_nlx8>) {
const __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask); col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask);
} else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) { } else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {
// Load 8 E8M0 exponents and convert to float via LUT // Load 8 E8M0 exponents and convert to float via LUT

View File

@ -3194,6 +3194,7 @@ class tinyBLAS_PPC {
private: private:
__attribute__((always_inline))
inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) { inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
vec_t vec_C[4]; vec_t vec_C[4];
__builtin_mma_disassemble_acc(vec_C, ACC); __builtin_mma_disassemble_acc(vec_C, ACC);
@ -3204,6 +3205,7 @@ class tinyBLAS_PPC {
} }
} }
__attribute__((always_inline))
inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) { inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
vec_t vec_C[4]; vec_t vec_C[4];
__builtin_mma_disassemble_acc(vec_C, ACC); __builtin_mma_disassemble_acc(vec_C, ACC);

View File

@ -116,12 +116,11 @@ if (CUDAToolkit_FOUND)
list(APPEND GGML_SOURCES_CUDA ${SRCS}) list(APPEND GGML_SOURCES_CUDA ${SRCS})
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
else() else()
file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu") list(APPEND GGML_SOURCES_CUDA
list(APPEND GGML_SOURCES_CUDA ${SRCS}) template-instances/fattn-vec-instance-f16-f16.cu
file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu") template-instances/fattn-vec-instance-q4_0-q4_0.cu
list(APPEND GGML_SOURCES_CUDA ${SRCS}) template-instances/fattn-vec-instance-q8_0-q8_0.cu
file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu") template-instances/fattn-vec-instance-bf16-bf16.cu)
list(APPEND GGML_SOURCES_CUDA ${SRCS})
endif() endif()
ggml_add_backend_library(ggml-cuda ggml_add_backend_library(ggml-cuda

View File

@ -41,6 +41,16 @@ template<typename dst_t, typename src_t>
return __bfloat162float(x); return __bfloat162float(x);
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) { } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
return __float22half2_rn(x); return __float22half2_rn(x);
} else if constexpr(std::is_same_v<src_t, nv_bfloat162> && std::is_same_v<dst_t, float2>) {
#ifdef GGML_USE_HIP
return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x)));
#else
#if __CUDA_ARCH__ >= 800
return __bfloat1622float2(x);
#else
return make_float2(__bfloat162float(x.x), __bfloat162float(x.y));
#endif // __CUDA_ARCH__ >= 800
#endif // GGML_USE_HIP
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) { } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
// bypass compile error on cuda 12.0.1 // bypass compile error on cuda 12.0.1
#ifdef GGML_USE_HIP #ifdef GGML_USE_HIP

View File

@ -74,6 +74,37 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
return sum; return sum;
} }
template <int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_bf16(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
const nv_bfloat162 * K_bf16 = (const nv_bfloat162 *) K_c;
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
float sum = 0.0f;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
__align__(16) nv_bfloat162 tmp[cpy_ne];
ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_bf16 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
#ifdef V_DOT2_F32_F16_AVAILABLE
// FIXME replace macros in vector FA kernel with templating and use FP32 for BF16
ggml_cuda_mad(sum, ggml_cuda_cast<float2>(tmp[k_KQ_1]), __half22float2(((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]));
#else
ggml_cuda_mad(sum, ggml_cuda_cast<float2>(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
#endif // V_DOT2_F32_F16_AVAILABLE
}
}
return sum;
}
template<int D, int nthreads> template<int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0( static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
@ -321,6 +352,19 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_
} }
} }
template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_bf16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
static_assert(std::is_same_v<T, float>, "BF16 V dequantization only supports float output");
static_assert(ne % 2 == 0, "bad ne");
__align__(16) nv_bfloat162 tmp[ne/2];
ggml_cuda_memcpy_1<ne*sizeof(nv_bfloat16)>(tmp, (const nv_bfloat16 *) vx + i0);
float2 * dst_f2 = (float2 *) dst;
#pragma unroll
for (int l = 0; l < ne/2; ++l) {
dst_f2[l] = ggml_cuda_cast<float2>(tmp[l]);
}
}
template <typename T, int ne> template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
const block_q4_0 * x = (const block_q4_0 *) vx; const block_q4_0 * x = (const block_q4_0 *) vx;
@ -547,6 +591,8 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>; return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>;
} else if constexpr (type_K == GGML_TYPE_Q8_0) { } else if constexpr (type_K == GGML_TYPE_Q8_0) {
return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>; return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
} else if constexpr (type_K == GGML_TYPE_BF16) {
return vec_dot_fattn_vec_KQ_bf16<D, nthreads>;
} else { } else {
static_assert(type_K == -1, "bad type"); static_assert(type_K == -1, "bad type");
return nullptr; return nullptr;
@ -567,6 +613,8 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
return dequantize_V_q5_1<T, ne>; return dequantize_V_q5_1<T, ne>;
} else if constexpr (type_V == GGML_TYPE_Q8_0) { } else if constexpr (type_V == GGML_TYPE_Q8_0) {
return dequantize_V_q8_0<T, ne>; return dequantize_V_q8_0<T, ne>;
} else if constexpr (type_V == GGML_TYPE_BF16) {
return dequantize_V_bf16<float, ne>;
} else { } else {
static_assert(type_V == -1, "bad type"); static_assert(type_V == -1, "bad type");
return nullptr; return nullptr;

View File

@ -75,17 +75,17 @@ static __global__ void flash_attn_ext_vec(
#endif // GGML_USE_HIP #endif // GGML_USE_HIP
constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device(); constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device();
constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q; constexpr int nthreads_KQ = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_KQ_q;
constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q; constexpr int nthreads_V = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_V_q;
static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K"); static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K");
static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V"); static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V");
constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4; constexpr int V_rows_per_thread = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 2*cpy_ne : 4;
constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V; constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V;
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>(); constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16 && type_K != GGML_TYPE_BF16;
#ifdef V_DOT2_F32_F16_AVAILABLE #ifdef V_DOT2_F32_F16_AVAILABLE
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>(); constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
#else #else
@ -323,8 +323,18 @@ static __global__ void flash_attn_ext_vec(
#pragma unroll #pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
half2 tmp[V_rows_per_thread/2]; half2 tmp[V_rows_per_thread/2];
dequantize_V(V + k*nb21, tmp, if constexpr (type_V == GGML_TYPE_BF16) {
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); float2 tmp_f[V_rows_per_thread/2];
dequantize_V(V + k*nb21, tmp_f,
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
#pragma unroll
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
tmp[i_VKQ_1] = __float22half2_rn(tmp_f[i_VKQ_1]);
}
} else {
dequantize_V(V + k*nb21, tmp,
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
}
#pragma unroll #pragma unroll
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
#pragma unroll #pragma unroll
@ -563,6 +573,7 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_BF16); \
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
@ -570,6 +581,7 @@ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_BF16)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
@ -577,6 +589,7 @@ EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_BF16)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
@ -584,3 +597,4 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_BF16)

View File

@ -224,6 +224,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
@ -231,6 +232,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
@ -238,6 +240,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
@ -245,6 +248,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
@ -252,6 +256,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
@ -259,10 +264,20 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
#else #else
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
#endif // GGML_CUDA_FA_ALL_QUANTS #endif // GGML_CUDA_FA_ALL_QUANTS
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
@ -355,6 +370,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
#endif // GGML_CUDA_FA_ALL_QUANTS #endif // GGML_CUDA_FA_ALL_QUANTS
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
case GGML_TYPE_BF16:
break; break;
default: default:
return BEST_FATTN_KERNEL_NONE; return BEST_FATTN_KERNEL_NONE;

View File

@ -126,7 +126,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
if (err == hipSuccess) { if (err == hipSuccess) {
// hipMemAdviseSetCoarseGrain is an optional performance hint; // hipMemAdviseSetCoarseGrain is an optional performance hint;
// ignore errors (e.g. hipErrorInvalidValue on some APU/iGPU configs). // ignore errors (e.g. hipErrorInvalidValue on some APU/iGPU configs).
cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device); (void)cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device);
(void)hipGetLastError(); // clear any error (void)hipGetLastError(); // clear any error
} }

View File

@ -33,7 +33,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
} }
} }
static constexpr __device__ int get_vdr_mmvq(ggml_type type) { static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
switch (type) { switch (type) {
case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ; case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ; case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
@ -173,11 +173,11 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d
return 1; return 1;
} }
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) { static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) {
if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
switch (ncols_dst) { switch (ncols_dst) {
case 1: case 1:
return 1; return small_k ? nwarps : 1;
case 2: case 2:
case 3: case 3:
case 4: case 4:
@ -193,7 +193,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
return 1; return 1;
} }
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false> template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false, bool small_k = false>
__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) __launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q( static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
@ -208,7 +208,7 @@ static __global__ void mul_mat_vec_q(
constexpr int vdr = get_vdr_mmvq(type); constexpr int vdr = get_vdr_mmvq(type);
constexpr mmvq_parameter_table_id table_id = get_device_table_id(); constexpr mmvq_parameter_table_id table_id = get_device_table_id();
constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id); constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id);
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id); constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
@ -414,14 +414,16 @@ static __global__ void mul_mat_vec_q(
template<ggml_type type> template<ggml_type type>
static std::pair<dim3, dim3> calc_launch_params( static std::pair<dim3, dim3> calc_launch_params(
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
const int warp_size, const mmvq_parameter_table_id table_id) { const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false) {
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); const int nwarps = calc_nwarps(type, ncols_dst, table_id);
const int rpb = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
const int64_t nblocks = (nrows_x + rpb - 1) / rpb;
const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens); const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1); const dim3 block_dims(warp_size, nwarps, 1);
return {block_nums, block_dims}; return {block_nums, block_dims};
} }
template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false> template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false, bool small_k = false>
static void mul_mat_vec_q_switch_fusion( static void mul_mat_vec_q_switch_fusion(
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
@ -434,7 +436,7 @@ static void mul_mat_vec_q_switch_fusion(
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (c_ncols_dst == 1) { if constexpr (c_ncols_dst == 1) {
if (has_fusion) { if (has_fusion) {
mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
@ -444,7 +446,7 @@ static void mul_mat_vec_q_switch_fusion(
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
@ -488,11 +490,33 @@ static void mul_mat_vec_q_switch_ncols_dst(
switch (ncols_dst) { switch (ncols_dst) {
case 1: { case 1: {
constexpr int c_ncols_dst = 1; constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, // When K is small, increase rows_per_block to match nwarps so each warp has more work to do
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle.
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, constexpr int qk = ggml_cuda_type_traits<type>::qk;
dims.first, dims.second, 0, ids_stride, stream); constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_iter_1warp = vdr * warp_size / qi;
const int nwarps = calc_nwarps(type, c_ncols_dst, table_id);
const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
if (use_small_k) {
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
warp_size, table_id, true);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, false, true>(
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
} else {
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
}
} break; } break;
case 2: { case 2: {
constexpr int c_ncols_dst = 2; constexpr int c_ncols_dst = 2;

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_BF16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_F16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_0);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_1);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_0);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_1);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q8_0);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_BF16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_BF16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_BF16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_BF16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_BF16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_BF16);

View File

@ -5,7 +5,7 @@ import os
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576] HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"] TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]
SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.

View File

@ -45,6 +45,7 @@ static int opt_verbose = 0;
static int opt_profile = 0; static int opt_profile = 0;
static int opt_hostbuf = 1; // hostbuf ON by default static int opt_hostbuf = 1; // hostbuf ON by default
static int opt_experimental = 0; static int opt_experimental = 0;
static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only
// Enable all stages by default // Enable all stages by default
static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE; static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE;
@ -460,7 +461,7 @@ static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) {
d[7] = x[i * 8 + 7].d; d[7] = x[i * 8 + 7].d;
} }
if (opt_verbose > 1) { if (opt_verbose > 2) {
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
dump_packed_block_q4x4x2(y, i, k); dump_packed_block_q4x4x2(y, i, k);
} }
@ -479,7 +480,7 @@ static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) {
const uint8_t * y_q = y + 0; // quants first const uint8_t * y_q = y + 0; // quants first
const uint8_t * y_d = y + qrow_size; // then scales const uint8_t * y_d = y + qrow_size; // then scales
if (opt_verbose > 1) { if (opt_verbose > 2) {
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
dump_packed_block_q4x4x2(y, i, k); dump_packed_block_q4x4x2(y, i, k);
} }
@ -795,7 +796,7 @@ static void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) {
d[7] = x[i * 8 + 7].d; d[7] = x[i * 8 + 7].d;
} }
if (opt_verbose > 1) { if (opt_verbose > 2) {
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
dump_packed_block_q8x4x2(y, i, k); dump_packed_block_q8x4x2(y, i, k);
} }
@ -813,7 +814,7 @@ static void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) {
const uint8_t * y_q = y + 0; // quants first const uint8_t * y_q = y + 0; // quants first
const uint8_t * y_d = y + qrow_size; // then scales const uint8_t * y_d = y + qrow_size; // then scales
if (opt_verbose > 1) { if (opt_verbose > 2) {
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
dump_packed_block_q8x4x2(y, i, k); dump_packed_block_q8x4x2(y, i, k);
} }
@ -1148,7 +1149,7 @@ static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k)
e[7] = x[i * 8 + 7].e; e[7] = x[i * 8 + 7].e;
} }
if (opt_verbose > 1) { if (opt_verbose > 2) {
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
dump_packed_block_mxfp4x4x2(y, i, k); dump_packed_block_mxfp4x4x2(y, i, k);
} }
@ -1167,7 +1168,7 @@ static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k)
const uint8_t * y_q = y + 0; // quants first const uint8_t * y_q = y + 0; // quants first
const uint8_t * y_e = y + qrow_size; // then scales const uint8_t * y_e = y + qrow_size; // then scales
if (opt_verbose > 1) { if (opt_verbose > 2) {
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
dump_packed_block_mxfp4x4x2(y, i, k); dump_packed_block_mxfp4x4x2(y, i, k);
} }
@ -1693,7 +1694,7 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
// Start the DSP-side service. We need to pass the queue ID to the // Start the DSP-side service. We need to pass the queue ID to the
// DSP in a FastRPC call; the DSP side will import the queue and start // DSP in a FastRPC call; the DSP side will import the queue and start
// listening for packets in a callback. // listening for packets in a callback.
err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx); err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx);
if (err != 0) { if (err != 0) {
GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err); GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err);
throw std::runtime_error("ggml-hex: iface start failed (see log for details)"); throw std::runtime_error("ggml-hex: iface start failed (see log for details)");
@ -2362,6 +2363,27 @@ static inline size_t init_cpy_req(htp_general_req * req, dspqueue_buffer * bufs,
return n_bufs; return n_bufs;
} }
static inline size_t init_cont_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
// CONT is just a contiguous copy — reuse CPY op
req->op = HTP_OP_CPY;
size_t n_bufs = 0;
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
return n_bufs;
}
static inline size_t init_repeat_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
req->op = HTP_OP_REPEAT;
size_t n_bufs = 0;
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
return n_bufs;
}
static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
req->op = HTP_OP_GET_ROWS; req->op = HTP_OP_GET_ROWS;
@ -2449,12 +2471,33 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
break; break;
case GGML_OP_UNARY: case GGML_OP_UNARY:
if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) { switch (ggml_get_unary_op(t)) {
case GGML_UNARY_OP_SILU:
req->op = HTP_OP_UNARY_SILU; req->op = HTP_OP_UNARY_SILU;
supported = true; supported = true;
} else if (ggml_get_unary_op(t) == GGML_UNARY_OP_GELU) { break;
case GGML_UNARY_OP_GELU:
req->op = HTP_OP_UNARY_GELU; req->op = HTP_OP_UNARY_GELU;
supported = true; supported = true;
break;
case GGML_UNARY_OP_SIGMOID:
req->op = HTP_OP_UNARY_SIGMOID;
supported = true;
break;
case GGML_UNARY_OP_NEG:
req->op = HTP_OP_UNARY_NEG;
supported = true;
break;
case GGML_UNARY_OP_EXP:
req->op = HTP_OP_UNARY_EXP;
supported = true;
break;
case GGML_UNARY_OP_SOFTPLUS:
req->op = HTP_OP_UNARY_SOFTPLUS;
supported = true;
break;
default:
break;
} }
break; break;
@ -2640,16 +2683,28 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
ggml_hexagon_dispatch_op<init_sum_rows_req>(sess, node, flags); ggml_hexagon_dispatch_op<init_sum_rows_req>(sess, node, flags);
break; break;
case GGML_OP_UNARY: case GGML_OP_UNARY:
if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) || switch (ggml_get_unary_op(node)) {
(ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) { case GGML_UNARY_OP_NEG:
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags); case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_GELU:
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
break;
default:
break;
} }
break; break;
case GGML_OP_GLU: case GGML_OP_GLU:
if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) || switch (ggml_get_glu_op(node)) {
(ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) || case GGML_GLU_OP_SWIGLU:
(ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) { case GGML_GLU_OP_SWIGLU_OAI:
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags); case GGML_GLU_OP_GEGLU:
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
break;
default:
break;
} }
break; break;
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
@ -2676,6 +2731,14 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags); ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags);
break; break;
case GGML_OP_CONT:
ggml_hexagon_dispatch_op<init_cont_req>(sess, node, flags);
break;
case GGML_OP_REPEAT:
ggml_hexagon_dispatch_op<init_repeat_req>(sess, node, flags);
break;
case GGML_OP_ARGSORT: case GGML_OP_ARGSORT:
ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags); ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
break; break;
@ -3006,6 +3069,39 @@ static bool ggml_hexagon_supported_cpy(const struct ggml_hexagon_session * sess,
return true; return true;
} }
static bool ggml_hexagon_supported_cont(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
GGML_UNUSED(sess);
const struct ggml_tensor * src0 = op->src[0];
// CONT is same-type only, supports f32 and f16
if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false;
return true;
}
static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
GGML_UNUSED(sess);
const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * dst = op;
// Support f32 and f16
if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false;
// src and dst must be the same type
if (src0->type != dst->type) return false;
// dst dims must be multiples of src dims
if (dst->ne[0] % src0->ne[0] != 0) return false;
if (dst->ne[1] % src0->ne[1] != 0) return false;
if (dst->ne[2] % src0->ne[2] != 0) return false;
if (dst->ne[3] % src0->ne[3] != 0) return false;
// require contiguous tensors (no transposition)
if (ggml_is_transposed(src0) || ggml_is_transposed(dst)) return false;
return true;
}
static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
auto sess = static_cast<ggml_hexagon_session *>(dev->context); auto sess = static_cast<ggml_hexagon_session *>(dev->context);
@ -3063,21 +3159,32 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
break; break;
case GGML_OP_UNARY: case GGML_OP_UNARY:
{ switch (ggml_get_unary_op(op)) {
const auto unary_op = ggml_get_unary_op(op); case GGML_UNARY_OP_NEG:
if (unary_op == GGML_UNARY_OP_SILU || unary_op == GGML_UNARY_OP_GELU) { case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_SOFTPLUS:
supp = ggml_hexagon_supported_unary(sess, op);
break;
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_GELU:
supp = ggml_hexagon_supported_activations(sess, op); supp = ggml_hexagon_supported_activations(sess, op);
} break;
break; default:
break;
} }
break;
case GGML_OP_GLU: case GGML_OP_GLU:
{ switch (ggml_get_glu_op(op)) {
const auto glu_op = ggml_get_glu_op(op); case GGML_GLU_OP_SWIGLU:
if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) { case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU:
supp = ggml_hexagon_supported_activations(sess, op); supp = ggml_hexagon_supported_activations(sess, op);
} break;
break; default:
break;
} }
break;
case GGML_OP_ROPE: case GGML_OP_ROPE:
supp = ggml_hexagon_supported_rope(sess, op); supp = ggml_hexagon_supported_rope(sess, op);
break; break;
@ -3098,6 +3205,14 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
supp = ggml_hexagon_supported_cpy(sess, op); supp = ggml_hexagon_supported_cpy(sess, op);
break; break;
case GGML_OP_CONT:
supp = ggml_hexagon_supported_cont(sess, op);
break;
case GGML_OP_REPEAT:
supp = ggml_hexagon_supported_repeat(sess, op);
break;
case GGML_OP_ARGSORT: case GGML_OP_ARGSORT:
supp = ggml_hexagon_supported_argsort(sess, op); supp = ggml_hexagon_supported_argsort(sess, op);
break; break;
@ -3258,6 +3373,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
const char * str_etm = getenv("GGML_HEXAGON_ETM"); const char * str_etm = getenv("GGML_HEXAGON_ETM");
const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); const char * str_nhvx = getenv("GGML_HEXAGON_NHVX");
const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX");
const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
const char * str_arch = getenv("GGML_HEXAGON_ARCH"); const char * str_arch = getenv("GGML_HEXAGON_ARCH");
@ -3267,8 +3383,9 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask; opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask;
opt_opsync = str_opsync ? atoi(str_opsync) : 0; opt_opsync = str_opsync ? atoi(str_opsync) : 0;
opt_profile = str_profile ? atoi(str_profile) : 0; opt_profile = str_profile ? atoi(str_profile) : 0;
opt_etm = str_etm ? atoi(str_etm) : 0; opt_etm = str_etm ? atoi(str_etm) : 0;
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx;
opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev;
if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) {

View File

@ -30,6 +30,7 @@ add_library(${HTP_LIB} SHARED
set-rows-ops.c set-rows-ops.c
get-rows-ops.c get-rows-ops.c
cpy-ops.c cpy-ops.c
repeat-ops.c
argsort-ops.c argsort-ops.c
ssm-conv.c ssm-conv.c
) )
@ -39,6 +40,24 @@ target_compile_definitions(${HTP_LIB} PRIVATE
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,> $<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
# HMX acceleration: available on v73+ architectures
set(HTP_HMX_VERSIONS v73 v75 v79 v81)
list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
if (_hmx_idx GREATER_EQUAL 0)
target_sources(${HTP_LIB} PRIVATE
hmx-matmul-ops.c
)
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
set_source_files_properties(
hmx-matmul-ops.c
PROPERTIES COMPILE_OPTIONS "-mhmx"
)
target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1)
endif()
build_idl(htp_iface.idl ${HTP_LIB}) build_idl(htp_iface.idl ${HTP_LIB})
set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON) set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON)

View File

@ -24,28 +24,26 @@
// Context for binary operations // Context for binary operations
struct htp_binary_context { struct htp_binary_context {
struct htp_ops_context * octx; struct htp_ops_context * octx;
struct fastdiv_values dim1_div;
struct fastdiv_values dim2_div; struct fastdiv_values src0_dim1_div; // ne01
struct fastdiv_values dim12_div; struct fastdiv_values src0_dim2_div; // ne02
struct fastdiv_values src0_dim12_div;// ne03
struct fastdiv_values src1_dim1_div; // ne11 struct fastdiv_values src1_dim1_div; // ne11
struct fastdiv_values src1_dim2_div; // ne12 struct fastdiv_values src1_dim2_div; // ne12
struct fastdiv_values src1_dim3_div; // ne13 struct fastdiv_values src1_dim3_div; // ne13
uint32_t nrows_per_thread;
bool split_at_ne01;
bool split_at_ne02;
// Precomputed values
uint32_t block_max; uint32_t block_max;
uint32_t nrows_per_thread;
size_t src0_row_size_aligned; size_t src0_row_size_aligned;
size_t src1_row_size_aligned; size_t src1_row_size_aligned;
size_t dst_row_size_aligned; size_t dst_row_size_aligned;
uint32_t src1_fetch_rows; // 1 or block_max
uint32_t src1_dma_stride; // 0 or stride bool split_at_ne01;
bool split_at_ne02;
}; };
#define htp_binary_preamble \ #define htp_binary_preamble \
const struct htp_tensor * src0 = &octx->src0; \ const struct htp_tensor * src0 = &octx->src0; \
const struct htp_tensor * src1 = &octx->src1; \ const struct htp_tensor * src1 = &octx->src1; \
struct htp_tensor * dst = &octx->dst; \ struct htp_tensor * dst = &octx->dst; \
@ -72,12 +70,11 @@ struct htp_binary_context {
const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3]; const uint32_t nb3 = dst->nb[3];
static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row, static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row, uint32_t ne01, uint32_t ne02) {
uint32_t ne01, uint32_t ne02) {
uint32_t i03, i02, i01, rem; uint32_t i03, i02, i01, rem;
i03 = fastdiv(ir, &bctx->dim12_div); i03 = fastdiv(ir, &bctx->src0_dim12_div);
rem = ir - i03 * (ne02 * ne01); rem = ir - i03 * (ne02 * ne01);
i02 = fastdiv(rem, &bctx->dim1_div); i02 = fastdiv(rem, &bctx->src0_dim1_div);
i01 = rem - i02 * ne01; i01 = rem - i02 * ne01;
uint32_t rows_left = end_row - ir; uint32_t rows_left = end_row - ir;
@ -191,6 +188,8 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
if (start_row >= end_row) return; if (start_row >= end_row) return;
FARF(HIGH, "binary-scalar: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned);
uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
@ -204,9 +203,9 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t i03, i02, i01, rem; uint32_t i03, i02, i01, rem;
i03 = fastdiv(ir_prefetch, &bctx->dim12_div); i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
rem = ir_prefetch - i03 * (ne02 * ne01); rem = ir_prefetch - i03 * (ne02 * ne01);
i02 = fastdiv(rem, &bctx->dim1_div); i02 = fastdiv(rem, &bctx->src0_dim1_div);
i01 = rem - i02 * ne01; i01 = rem - i02 * ne01;
uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
@ -215,7 +214,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
ir_prefetch += current_block_size; ir_prefetch += current_block_size;
spad_idx ^= 1; spad_idx ^= 1;
@ -229,9 +228,9 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
uint32_t i03, i02, i01, rem; uint32_t i03, i02, i01, rem;
i03 = fastdiv(ir, &bctx->dim12_div); i03 = fastdiv(ir, &bctx->src0_dim12_div);
rem = ir - i03 * (ne02 * ne01); rem = ir - i03 * (ne02 * ne01);
i02 = fastdiv(rem, &bctx->dim1_div); i02 = fastdiv(rem, &bctx->src0_dim1_div);
i01 = rem - i02 * ne01; i01 = rem - i02 * ne01;
// src1 indices (broadcast/repeat) // src1 indices (broadcast/repeat)
@ -255,9 +254,9 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
if (ir_prefetch < end_row) { if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t p03, p02, p01, prem; uint32_t p03, p02, p01, prem;
p03 = fastdiv(ir_prefetch, &bctx->dim12_div); p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
prem = ir_prefetch - p03 * (ne02 * ne01); prem = ir_prefetch - p03 * (ne02 * ne01);
p02 = fastdiv(prem, &bctx->dim1_div); p02 = fastdiv(prem, &bctx->src0_dim1_div);
p01 = prem - p02 * ne01; p01 = prem - p02 * ne01;
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
@ -282,6 +281,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
if (start_row >= end_row) return; if (start_row >= end_row) return;
FARF(HIGH, "binary-same-shape: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned);
uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
@ -297,9 +298,9 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t i03, i02, i01, rem; uint32_t i03, i02, i01, rem;
i03 = fastdiv(ir_prefetch, &bctx->dim12_div); i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
rem = ir_prefetch - i03 * (ne02 * ne01); rem = ir_prefetch - i03 * (ne02 * ne01);
i02 = fastdiv(rem, &bctx->dim1_div); i02 = fastdiv(rem, &bctx->src0_dim1_div);
i01 = rem - i02 * ne01; i01 = rem - i02 * ne01;
uint32_t i13 = (ne13 == 1) ? 0 : i03; uint32_t i13 = (ne13 == 1) ? 0 : i03;
@ -307,23 +308,23 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
uint32_t i11 = (ne11 == 1) ? 0 : i01; uint32_t i11 = (ne11 == 1) ? 0 : i01;
uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
uint8_t * src1_base = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; uint8_t * src1_curr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half; uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half;
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, current_block_size); dma_queue_push(q, dma_make_ptr(s1_spad, src1_curr), bctx->src1_row_size_aligned, nb11, row_size_bytes, current_block_size);
ir_prefetch += current_block_size; ir_prefetch += current_block_size;
spad_idx ^= 1; spad_idx ^= 1;
} }
for (uint32_t ir = start_row; ir < end_row; ) { for (uint32_t ir = start_row; ir < end_row; ) {
uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst; uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst;
@ -335,9 +336,9 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
} }
uint32_t i03, i02, i01, rem; uint32_t i03, i02, i01, rem;
i03 = fastdiv(ir, &bctx->dim12_div); i03 = fastdiv(ir, &bctx->src0_dim12_div);
rem = ir - i03 * (ne02 * ne01); rem = ir - i03 * (ne02 * ne01);
i02 = fastdiv(rem, &bctx->dim1_div); i02 = fastdiv(rem, &bctx->src0_dim1_div);
i01 = rem - i02 * ne01; i01 = rem - i02 * ne01;
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
@ -345,9 +346,9 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
if (ir_prefetch < end_row) { if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t p03, p02, p01, prem; uint32_t p03, p02, p01, prem;
p03 = fastdiv(ir_prefetch, &bctx->dim12_div); p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
prem = ir_prefetch - p03 * (ne02 * ne01); prem = ir_prefetch - p03 * (ne02 * ne01);
p02 = fastdiv(prem, &bctx->dim1_div); p02 = fastdiv(prem, &bctx->src0_dim1_div);
p01 = prem - p02 * ne01; p01 = prem - p02 * ne01;
uint32_t p13 = (ne13 == 1) ? 0 : p03; uint32_t p13 = (ne13 == 1) ? 0 : p03;
@ -358,7 +359,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11; uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11;
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, next_block_size); dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, nb11, row_size_bytes, next_block_size);
ir_prefetch += next_block_size; ir_prefetch += next_block_size;
} }
@ -373,15 +374,17 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
struct htp_ops_context * octx = bctx->octx; struct htp_ops_context * octx = bctx->octx;
htp_binary_preamble; htp_binary_preamble;
const uint32_t src0_type = octx->src0.type; const uint32_t src0_type = octx->src0.type;
const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t total_rows = ne01 * ne02 * ne03;
const uint32_t start_row = bctx->nrows_per_thread * ith; const uint32_t start_row = bctx->nrows_per_thread * ith;
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
if (start_row >= end_row) return; if (start_row >= end_row) return;
FARF(HIGH, "binary-row-bcast: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned);
uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
uint8_t * src1_spad = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
@ -391,15 +394,14 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
uint32_t ir_prefetch = start_row; uint32_t ir_prefetch = start_row;
int spad_idx = 0; int spad_idx = 0;
void * s1_ptr = (void *) src1_spad; void * s1_ptr = (void *) src1_spad_base;
for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t i03, i02, i01, rem; uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
i03 = fastdiv(ir_prefetch, &bctx->dim12_div); uint32_t rem = ir_prefetch - i03 * (ne02 * ne01);
rem = ir_prefetch - i03 * (ne02 * ne01); uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
i02 = fastdiv(rem, &bctx->dim1_div); uint32_t i01 = rem - i02 * ne01;
i01 = rem - i02 * ne01;
uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
@ -407,7 +409,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
ir_prefetch += current_block_size; ir_prefetch += current_block_size;
spad_idx ^= 1; spad_idx ^= 1;
@ -415,7 +417,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
for (uint32_t ir = start_row; ir < end_row; ) { for (uint32_t ir = start_row; ir < end_row; ) {
uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
for (uint32_t r = 0; r < current_block_size; r++) { for (uint32_t r = 0; r < current_block_size; r++) {
@ -425,21 +427,19 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00); COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00);
} }
uint32_t i03, i02, i01, rem; uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div);
i03 = fastdiv(ir, &bctx->dim12_div); uint32_t rem = ir - i03 * (ne02 * ne01);
rem = ir - i03 * (ne02 * ne01); uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
i02 = fastdiv(rem, &bctx->dim1_div); uint32_t i01 = rem - i02 * ne01;
i01 = rem - i02 * ne01;
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
if (ir_prefetch < end_row) { if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t p03, p02, p01, prem; uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
p03 = fastdiv(ir_prefetch, &bctx->dim12_div); uint32_t prem = ir_prefetch - p03 * (ne02 * ne01);
prem = ir_prefetch - p03 * (ne02 * ne01); uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div);
p02 = fastdiv(prem, &bctx->dim1_div); uint32_t p01 = prem - p02 * ne01;
p01 = prem - p02 * ne01;
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
ir_prefetch += next_block_size; ir_prefetch += next_block_size;
@ -458,14 +458,16 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
const uint32_t src0_type = octx->src0.type; const uint32_t src0_type = octx->src0.type;
const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t total_rows = ne01 * ne02 * ne03;
const uint32_t start_row = bctx->nrows_per_thread * ith; const uint32_t start_row = bctx->nrows_per_thread * ith;
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
if (start_row >= end_row) return; if (start_row >= end_row) return;
FARF(HIGH, "binary-complex: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned);
uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
dma_queue * q = octx->ctx->dma[ith]; dma_queue * q = octx->ctx->dma[ith];
uint32_t ir_prefetch = start_row; uint32_t ir_prefetch = start_row;
@ -473,11 +475,10 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t i03, i02, i01, rem; uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
i03 = fastdiv(ir_prefetch, &bctx->dim12_div); uint32_t rem = ir_prefetch - i03 * (ne02 * ne01);
rem = ir_prefetch - i03 * (ne02 * ne01); uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
i02 = fastdiv(rem, &bctx->dim1_div); uint32_t i01 = rem - i02 * ne01;
i01 = rem - i02 * ne01;
uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
@ -485,7 +486,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
ir_prefetch += current_block_size; ir_prefetch += current_block_size;
spad_idx ^= 1; spad_idx ^= 1;
@ -496,11 +497,10 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
uint32_t i03, i02, i01, rem; uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div);
i03 = fastdiv(ir, &bctx->dim12_div); uint32_t rem = ir - i03 * (ne02 * ne01);
rem = ir - i03 * (ne02 * ne01); uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
i02 = fastdiv(rem, &bctx->dim1_div); uint32_t i01 = rem - i02 * ne01;
i01 = rem - i02 * ne01;
for (uint32_t r = 0; r < current_block_size; r++) { for (uint32_t r = 0; r < current_block_size; r++) {
uint32_t r_i01 = i01 + r; uint32_t r_i01 = i01 + r;
@ -521,11 +521,10 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
if (ir_prefetch < end_row) { if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t p03, p02, p01, prem; uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
p03 = fastdiv(ir_prefetch, &bctx->dim12_div); uint32_t prem = ir_prefetch - p03 * (ne02 * ne01);
prem = ir_prefetch - p03 * (ne02 * ne01); uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div);
p02 = fastdiv(prem, &bctx->dim1_div); uint32_t p01 = prem - p02 * ne01;
p01 = prem - p02 * ne01;
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
ir_prefetch += next_block_size; ir_prefetch += next_block_size;
@ -545,14 +544,16 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16); const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
const uint32_t row_size_bytes = ne00 * elem_size_bytes;; const uint32_t row_size_bytes = ne00 * elem_size_bytes;;
const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t total_rows = ne01 * ne02 * ne03;
const uint32_t start_row = bctx->nrows_per_thread * ith; const uint32_t start_row = bctx->nrows_per_thread * ith;
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
if (start_row >= end_row) return; if (start_row >= end_row) return;
uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
FARF(HIGH, "binary-repeat: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned);
dma_queue * q = octx->ctx->dma[ith]; dma_queue * q = octx->ctx->dma[ith];
uint32_t ir_prefetch = start_row; uint32_t ir_prefetch = start_row;
@ -560,11 +561,10 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t i03, i02, i01, rem; uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
i03 = fastdiv(ir_prefetch, &bctx->dim12_div); uint32_t rem = ir_prefetch - i03 * (ne02 * ne01);
rem = ir_prefetch - i03 * (ne02 * ne01); uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
i02 = fastdiv(rem, &bctx->dim1_div); uint32_t i01 = rem - i02 * ne01;
i01 = rem - i02 * ne01;
uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
@ -572,7 +572,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
ir_prefetch += current_block_size; ir_prefetch += current_block_size;
spad_idx ^= 1; spad_idx ^= 1;
@ -583,11 +583,10 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
uint32_t i03, i02, i01, rem; uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div);
i03 = fastdiv(ir, &bctx->dim12_div); uint32_t rem = ir - i03 * (ne02 * ne01);
rem = ir - i03 * (ne02 * ne01); uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
i02 = fastdiv(rem, &bctx->dim1_div); uint32_t i01 = rem - i02 * ne01;
i01 = rem - i02 * ne01;
for (uint32_t r = 0; r < current_block_size; r++) { for (uint32_t r = 0; r < current_block_size; r++) {
uint32_t r_i01 = i01 + r; uint32_t r_i01 = i01 + r;
@ -612,11 +611,10 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
if (ir_prefetch < end_row) { if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t p03, p02, p01, prem; uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
p03 = fastdiv(ir_prefetch, &bctx->dim12_div); uint32_t prem = ir_prefetch - p03 * (ne02 * ne01);
prem = ir_prefetch - p03 * (ne02 * ne01); uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div);
p02 = fastdiv(prem, &bctx->dim1_div); uint32_t p01 = prem - p02 * ne01;
p01 = prem - p02 * ne01;
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
ir_prefetch += next_block_size; ir_prefetch += next_block_size;
@ -646,6 +644,7 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
const uint32_t nb02 = src0->nb[2]; const uint32_t nb02 = src0->nb[2];
const uint32_t nb03 = src0->nb[3]; const uint32_t nb03 = src0->nb[3];
const uint32_t nb11 = src1->nb[1]; // src1 row stride const uint32_t nb11 = src1->nb[1]; // src1 row stride
const uint32_t nb1 = dst->nb[1]; const uint32_t nb1 = dst->nb[1];
const uint32_t nb2 = dst->nb[2]; const uint32_t nb2 = dst->nb[2];
const uint32_t nb3 = dst->nb[3]; const uint32_t nb3 = dst->nb[3];
@ -657,8 +656,8 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
dma_queue * q = octx->ctx->dma[ith]; dma_queue * q = octx->ctx->dma[ith];
uint32_t ir_prefetch = start_row; uint32_t ir_prefetch = start_row;
@ -666,11 +665,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t i03, i02, i01, rem; uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
i03 = fastdiv(ir_prefetch, &bctx->dim12_div); uint32_t rem = ir_prefetch - i03 * (ne02 * ne01);
rem = ir_prefetch - i03 * (ne02 * ne01); uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
i02 = fastdiv(rem, &bctx->dim1_div); uint32_t i01 = rem - i02 * ne01;
i01 = rem - i02 * ne01;
uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
@ -678,7 +676,7 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), 0);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
ir_prefetch += current_block_size; ir_prefetch += current_block_size;
spad_idx ^= 1; spad_idx ^= 1;
@ -689,11 +687,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
uint32_t i03, i02, i01, rem; uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div);
i03 = fastdiv(ir, &bctx->dim12_div); uint32_t rem = ir - i03 * (ne02 * ne01);
rem = ir - i03 * (ne02 * ne01); uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
i02 = fastdiv(rem, &bctx->dim1_div); uint32_t i01 = rem - i02 * ne01;
i01 = rem - i02 * ne01;
for (uint32_t r = 0; r < current_block_size; r++) { for (uint32_t r = 0; r < current_block_size; r++) {
uint32_t r_i01 = i01 + r; // linear within block since we split at ne01 uint32_t r_i01 = i01 + r; // linear within block since we split at ne01
@ -712,11 +709,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
if (ir_prefetch < end_row) { if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t p03, p02, p01, prem; uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
p03 = fastdiv(ir_prefetch, &bctx->dim12_div); uint32_t prem = ir_prefetch - p03 * (ne02 * ne01);
prem = ir_prefetch - p03 * (ne02 * ne01); uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div);
p02 = fastdiv(prem, &bctx->dim1_div); uint32_t p01 = prem - p02 * ne01;
p01 = prem - p02 * ne01;
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
ir_prefetch += next_block_size; ir_prefetch += next_block_size;
@ -739,40 +735,36 @@ static int execute_op_binary(struct htp_ops_context * octx) {
const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16); const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
const size_t src0_row_size = src0->ne[0] * elem_size; const size_t src0_row_size = src0->ne[0] * elem_size;
const size_t src1_row_size = src1->ne[0] * elem_size; const size_t src1_row_size = src1->ne[0] * elem_size;
const size_t dst_row_size = dst->ne[0] * elem_size; const size_t dst_row_size = dst->ne[0] * elem_size;
// Align to VLEN size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
bool is_add_id = (octx->op == HTP_OP_ADD_ID); bool is_add_id = (octx->op == HTP_OP_ADD_ID);
bool is_scalar = !is_add_id && (src1->ne[0] == 1); bool is_scalar = !is_add_id && (src1->ne[0] == 1);
// Determine which kernel we will use to alloc memory and dispatch bool is_transposed = (src0->nb[1] < src0_row_size || src1->nb[1] < src1_row_size || dst->nb[1] < dst_row_size);
bool use_vector_same = !is_add_id && !is_scalar && ((src0->nb[1] % VLEN) == 0) && (src1->ne[0] == src0->ne[0]) &&
bool is_same_shape = !is_add_id && !is_scalar && !is_transposed &&
(src1->ne[0] == src0->ne[0] && src0->ne[0] % VLEN == 0) &&
(src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) && (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) &&
(src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) && (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) &&
(src1->ne[3] == src0->ne[3] || src1->ne[3] == 1); (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1);
bool is_row_bcast = use_vector_same && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1); bool is_row_bcast = is_same_shape && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1);
bool use_complex = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] == src0->ne[0]); bool is_complex = !is_add_id && !is_scalar && !is_same_shape && (src1->ne[0] == src0->ne[0]);
bool use_repeat = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] != src0->ne[0]); bool is_repeat = !is_add_id && !is_scalar && !is_same_shape && (src1->ne[0] != src0->ne[0]);
size_t spad_row_total; size_t spad_row_total;
if (is_scalar) { if (is_same_shape) {
spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
} else if (is_row_bcast) {
spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
} else if (use_vector_same) {
spad_row_total = 2 * (src0_row_size_aligned + src1_row_size_aligned + dst_row_size_aligned); spad_row_total = 2 * (src0_row_size_aligned + src1_row_size_aligned + dst_row_size_aligned);
} else if (is_add_id) {
spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); // src1 read directly
} else { } else {
spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
} }
size_t rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total); size_t rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total);
// Adjust for static src1 in row_bcast case // Adjust for static src1 in row_bcast case
if (is_row_bcast) { if (is_row_bcast) {
size_t needed_static = src1_row_size_aligned; size_t needed_static = src1_row_size_aligned;
@ -782,28 +774,26 @@ static int execute_op_binary(struct htp_ops_context * octx) {
} }
if (rows_per_buffer < 1) { if (rows_per_buffer < 1) {
FARF(ERROR, "binary: VTCM too small\n"); FARF(ERROR, "binary: VTCM too small\n");
return HTP_STATUS_VTCM_TOO_SMALL; return HTP_STATUS_VTCM_TOO_SMALL;
} }
octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned; octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned;
octx->dst_spad.size_per_thread = rows_per_buffer * 2 * dst_row_size_aligned; octx->dst_spad.size_per_thread = rows_per_buffer * 2 * dst_row_size_aligned;
if (is_scalar || use_complex || use_repeat || is_add_id) { if (is_add_id || is_scalar || is_complex || is_repeat || is_row_bcast) {
octx->src1_spad.size_per_thread = 0;
} else if (is_row_bcast) {
octx->src1_spad.size_per_thread = 0; octx->src1_spad.size_per_thread = 0;
} else { } else {
octx->src1_spad.size_per_thread = rows_per_buffer * 2 * src1_row_size_aligned; octx->src1_spad.size_per_thread = rows_per_buffer * 2 * src1_row_size_aligned;
} }
octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
if (is_row_bcast) { if (is_row_bcast) {
octx->src1_spad.size = src1_row_size_aligned; octx->src1_spad.size = src1_row_size_aligned;
} else { } else {
octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread; octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread;
} }
octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) { if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) {
return HTP_STATUS_VTCM_TOO_SMALL; return HTP_STATUS_VTCM_TOO_SMALL;
@ -823,46 +813,37 @@ static int execute_op_binary(struct htp_ops_context * octx) {
} }
struct htp_binary_context bctx; struct htp_binary_context bctx;
bctx.octx = octx; bctx.octx = octx;
bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
bctx.block_max = rows_per_buffer; bctx.block_max = rows_per_buffer;
bctx.src0_row_size_aligned = src0_row_size_aligned; bctx.src0_row_size_aligned = src0_row_size_aligned;
bctx.src1_row_size_aligned = src1_row_size_aligned; bctx.src1_row_size_aligned = src1_row_size_aligned;
bctx.dst_row_size_aligned = dst_row_size_aligned; bctx.dst_row_size_aligned = dst_row_size_aligned;
bctx.dim1_div = init_fastdiv_values(src0->ne[1]); bctx.src0_dim1_div = init_fastdiv_values(src0->ne[1]);
bctx.dim2_div = init_fastdiv_values(src0->ne[2]); bctx.src0_dim2_div = init_fastdiv_values(src0->ne[2]);
bctx.dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]); bctx.src0_dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]);
bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]); bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]);
bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]); bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]);
bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]); bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]);
bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]); bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]);
bool dst_contig_dim1 = (dst->nb[2] == src0->ne[1] * dst->nb[1]); bool dst_contig_dim1 = (dst->nb[2] == src0->ne[1] * dst->nb[1]);
bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]); bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]);
bool dst_contig_dim2 = (dst->nb[3] == src0->ne[2] * dst->nb[2]); bool dst_contig_dim2 = (dst->nb[3] == src0->ne[2] * dst->nb[2]);
bctx.split_at_ne01 = (src0->ne[2] > 1) && bctx.split_at_ne01 = (src0->ne[2] > 1) && ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1);
((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1); bctx.split_at_ne02 = (src0->ne[3] > 1) && ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2);
bctx.split_at_ne02 = (src0->ne[3] > 1) &&
((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2);
// Precompute specific kernel parameters
if (use_vector_same) {
bctx.src1_dma_stride = (src1->ne[1] == 1) ? 0 : src1->nb[1];
bctx.src1_fetch_rows = (src1->ne[1] == 1) ? 1 : rows_per_buffer;
}
worker_callback_t worker_func; worker_callback_t worker_func;
if (is_add_id) worker_func = binary_job_add_id; if (is_add_id) worker_func = binary_job_add_id;
else if (is_scalar) worker_func = binary_job_scalar; else if (is_scalar) worker_func = binary_job_scalar;
else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast; else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast;
else if (use_vector_same) worker_func = binary_job_vector_same_shape; else if (is_same_shape) worker_func = binary_job_vector_same_shape;
else if (use_complex) worker_func = binary_job_vector_complex; else if (is_complex) worker_func = binary_job_vector_complex;
else worker_func = binary_job_element_repeat; else worker_func = binary_job_element_repeat;
if (is_row_bcast) { if (is_row_bcast) {
dma_queue_pop(q); dma_queue_pop(q);

View File

@ -31,8 +31,8 @@ dma_queue * dma_queue_create(size_t capacity) {
q->capacity = capacity; q->capacity = capacity;
q->idx_mask = capacity - 1; q->idx_mask = capacity - 1;
q->desc = (hexagon_udma_descriptor_type1_t *) memalign(64, capacity * sizeof(hexagon_udma_descriptor_type1_t)); q->desc = (dma_descriptor_2d *) memalign(64, capacity * sizeof(dma_descriptor_2d));
memset(q->desc, 0, capacity * sizeof(hexagon_udma_descriptor_type1_t)); memset(q->desc, 0, capacity * sizeof(dma_descriptor_2d));
q->dptr = (dma_ptr *) memalign(4, capacity * sizeof(dma_ptr)); q->dptr = (dma_ptr *) memalign(4, capacity * sizeof(dma_ptr));
memset(q->dptr, 0, capacity * sizeof(dma_ptr)); memset(q->dptr, 0, capacity * sizeof(dma_ptr));

View File

@ -10,19 +10,84 @@
extern "C" { extern "C" {
#endif #endif
// Define the HW descriptor structs here since the ones in HexSDK are a bit out of date
typedef struct dma_descriptor_1d_s {
void * next;
uint32_t size:24;
uint32_t desc_size:2;
uint32_t dst_comp:1;
uint32_t src_comp:1;
uint32_t dst_bypass:1;
uint32_t src_bypass:1;
uint32_t order:1;
uint32_t done:1;
void * src;
void * dst;
} dma_descriptor_1d;
#if __HVX_ARCH__ < 75
typedef struct dma_descriptor_2d_s {
void * next;
uint32_t reserved0:24;
uint32_t desc_size:2;
uint32_t dst_comp:1;
uint32_t src_comp:1;
uint32_t dst_bypass:1;
uint32_t src_bypass:1;
uint32_t order:1;
uint32_t done:1;
void * src;
void * dst;
uint32_t desc_type:8;
uint32_t reserved1:24;
uint32_t row_size:16;
uint32_t nrows:16;
uint32_t src_stride:16;
uint32_t dst_stride:16;
uint32_t src_offset:16;
uint32_t dst_offset:16;
} dma_descriptor_2d;
#else
typedef struct dma_descriptor_2d_s {
void * next;
uint32_t dst_stride:24;
uint32_t desc_size:2;
uint32_t dst_comp:1;
uint32_t src_comp:1;
uint32_t dst_bypass:1;
uint32_t src_bypass:1;
uint32_t order:1;
uint32_t done:1;
void * src;
void * dst;
uint32_t desc_type:8;
uint32_t reserved0:24;
uint32_t row_size:24;
uint32_t nrows_lo:8;
uint32_t nrows_hi:8;
uint32_t src_stride:24;
uint32_t offset:24;
uint32_t reserved1:8;
} dma_descriptor_2d;
#endif
typedef struct { typedef struct {
void *dst; void *dst;
const void *src; const void *src;
} dma_ptr; } dma_ptr;
typedef struct { typedef struct {
hexagon_udma_descriptor_type1_t * desc; // descriptor pointers dma_descriptor_2d * desc; // descriptor pointers
hexagon_udma_descriptor_type1_t * tail; // tail pointer dma_descriptor_2d * tail; // tail pointer
dma_ptr * dptr; // dst/src pointers dma_ptr * dptr; // dst/src pointers
uint32_t push_idx; uint32_t push_idx;
uint32_t pop_idx; uint32_t pop_idx;
uint32_t capacity; uint32_t capacity;
uint32_t idx_mask; uint32_t idx_mask;
} dma_queue; } dma_queue;
dma_queue * dma_queue_create(size_t capacity); dma_queue * dma_queue_create(size_t capacity);
@ -59,71 +124,87 @@ static inline dma_ptr dma_make_ptr(void *dst, const void *src)
return p; return p;
} }
static inline bool dma_queue_push(dma_queue * q, #if __HVX_ARCH__ < 73
dma_ptr dptr, static const uint32_t dma_src_l2_bypass_on = 1;
size_t dst_row_size, static const uint32_t dma_dst_l2_bypass_on = 0;
size_t src_row_size, #else
size_t width, // width in bytes. number of bytes to transfer per row static const uint32_t dma_src_l2_bypass_on = 1;
size_t nrows) { static const uint32_t dma_dst_l2_bypass_on = 1;
#endif
static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t size) {
if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) { if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) {
FARF(ERROR, "dma-push: queue full\n"); FARF(HIGH, "dma-push: queue full\n");
return false; return false;
} }
hexagon_udma_descriptor_type1_t * desc = &q->desc[q->push_idx]; dma_descriptor_1d * desc = (dma_descriptor_1d *) &q->desc[q->push_idx];
desc->next = NULL;
desc->desc_size = 0; // 1D mode
desc->src_bypass = dma_src_l2_bypass_on;
desc->dst_bypass = dma_dst_l2_bypass_on;
desc->order = 1;
desc->done = 0;
desc->src = (void *) dptr.src;
desc->dst = (void *) dptr.dst;
desc->size = size;
q->dptr[q->push_idx] = dptr;
dmlink(q->tail, desc);
q->tail = (dma_descriptor_2d *) desc;
// FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
q->push_idx = (q->push_idx + 1) & q->idx_mask;
return true;
}
static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) {
if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) {
FARF(HIGH, "dma-push: queue full\n");
return false;
}
dma_descriptor_2d * desc = &q->desc[q->push_idx];
desc->next = NULL; desc->next = NULL;
desc->length = 0; desc->reserved0 = 0;
desc->desctype = HEXAGON_UDMA_DESC_DESCTYPE_TYPE1; desc->reserved1 = 0;
desc->dstbypass = 1; desc->desc_size = 1; // 2d mode
desc->srcbypass = 1; desc->src_bypass = dma_src_l2_bypass_on;
#if __HVX_ARCH__ >= 73 desc->dst_bypass = dma_dst_l2_bypass_on;
desc->dstbypass = 1; desc->src_comp = 0;
desc->srcbypass = 1; desc->dst_comp = 0;
#else desc->order = 1;
desc->dstbypass = 0; desc->done = 0;
desc->srcbypass = 1; desc->src_stride = src_stride;
#endif desc->dst_stride = dst_stride;
desc->order = 0;
desc->dstate = HEXAGON_UDMA_DESC_DSTATE_INCOMPLETE;
desc->src = (void *) dptr.src; desc->src = (void *) dptr.src;
desc->dst = (void *) dptr.dst; desc->dst = (void *) dptr.dst;
desc->allocation = 0; desc->row_size = row_size;
desc->padding = 0;
desc->roiwidth = width; #if __HVX_ARCH__ < 75
desc->roiheight = nrows; desc->desc_type = 0; // 2d (16-bit) mode
desc->srcstride = src_row_size; desc->nrows = nrows;
desc->dststride = dst_row_size; desc->src_offset = 0;
desc->srcwidthoffset = 0; desc->dst_offset = 0;
desc->dstwidthoffset = 0; #else
desc->desc_type = 9; // 2d (24-bit) mode
desc->nrows_lo = (nrows & 0xff);
desc->nrows_hi = (nrows >> 8);
desc->offset = 0;
#endif
q->dptr[q->push_idx] = dptr; q->dptr[q->push_idx] = dptr;
dmlink(q->tail, desc); dmlink(q->tail, desc);
q->tail = desc; q->tail = desc;
// FARF(ERROR, "dma-push: i %u width %u nrows %d dst %p src %p\n", q->push_idx, width, nrows, dptr.dst, dptr.src); // FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
q->push_idx = (q->push_idx + 1) & q->idx_mask; q->push_idx = (q->push_idx + 1) & q->idx_mask;
return true; return true;
} }
static inline bool dma_queue_push_ddr_to_vtcm(dma_queue * q,
dma_ptr dptr,
size_t dst_row_size,
size_t src_row_size,
size_t nrows) {
return dma_queue_push(q, dptr, dst_row_size, src_row_size, src_row_size, nrows);
}
static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q,
dma_ptr dptr,
size_t dst_row_size,
size_t src_row_size,
size_t nrows) {
return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);
}
static inline dma_ptr dma_queue_pop(dma_queue * q) { static inline dma_ptr dma_queue_pop(dma_queue * q) {
dma_ptr dptr = { NULL }; dma_ptr dptr = { NULL };
@ -131,12 +212,12 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) {
return dptr; return dptr;
} }
hexagon_udma_descriptor_type1_t * desc = &q->desc[q->pop_idx]; dma_descriptor_2d * desc = &q->desc[q->pop_idx];
// Wait for desc to complete // Wait for desc to complete
while (1) { while (1) {
dmpoll(); dmpoll();
if (desc->dstate == HEXAGON_UDMA_DESC_DSTATE_COMPLETE) { if (desc->done) {
break; break;
} }
// FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx); // FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx);
@ -175,6 +256,62 @@ static inline uint32_t dma_queue_capacity(dma_queue * q) {
return q->capacity; return q->capacity;
} }
#if __HVX_ARCH__ < 75
// Overflow-safe DMA push: all 2d descriptor fields (row_size, nrows, src_stride, dst_stride) are 16-bit, max 65535.
// This version transparently handles values that exceed the 16-bit limit and submits chained DMA transtions.
#define DMA_MAX_FIELD_VAL 65535u
static inline bool dma_queue_push(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) {
// Fast path: everything fits in 16 bits
if (nrows == 0 || __builtin_expect(
row_size <= DMA_MAX_FIELD_VAL &&
nrows <= DMA_MAX_FIELD_VAL &&
src_stride <= DMA_MAX_FIELD_VAL &&
dst_stride <= DMA_MAX_FIELD_VAL, 1)) {
return dma_queue_push_single_2d(q, dptr, dst_stride, src_stride, row_size, nrows);
}
// Contiguous block
// Use 1d DMA mode which supports sizes up to 24-bits (16MB)
if (nrows == 1 || (row_size == src_stride && row_size == dst_stride)) {
size_t total = row_size * nrows;
return dma_queue_push_single_1d(q, dptr, total);
}
// Stride overflow — fall back to row-by-row.
{
const uint8_t *src = (const uint8_t *) dptr.src;
uint8_t *dst = (uint8_t *) dptr.dst;
for (size_t r = 0; r < nrows; ++r) {
dma_ptr p = dma_make_ptr(dst + r * dst_stride, src + r * src_stride);
if (!dma_queue_push_single_1d(q, p, row_size))
return false;
if (r + 1 < nrows)
dma_queue_pop(q);
}
return true;
}
}
#else // HVX_ARCH >= 75
static inline bool dma_queue_push(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) {
// On v75 and up we always use 2d 24-bit mode
return dma_queue_push_single_2d(q, dptr, dst_stride, src_stride, row_size, nrows);
}
#endif
static inline bool dma_queue_push_ddr_to_vtcm(dma_queue * q, dma_ptr dptr, size_t dst_row_size, size_t src_row_size, size_t nrows) {
return dma_queue_push(q, dptr, dst_row_size, src_row_size, src_row_size, nrows);
}
static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_t dst_row_size, size_t src_row_size, size_t nrows) {
return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);
}
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif

View File

@ -21,6 +21,15 @@ static inline void hex_dump_uint8_line(char * pref, const uint8_t * x, uint32_t
FARF(HIGH, "%s\n", str); FARF(HIGH, "%s\n", str);
} }
static inline void hex_dump_uint32_line(char * pref, const uint32_t * x, uint32_t n) {
char str[1024], *p = str, *p_end = str + sizeof(str);
p += snprintf(p, p_end - p, "%s: ", pref);
for (int i = 0; i < n; i++) {
p += snprintf(p, p_end - p, "%u, ", (unsigned int) x[i]);
}
FARF(HIGH, "%s\n", str);
}
static inline void hex_dump_int32_line(char * pref, const int32_t * x, uint32_t n) { static inline void hex_dump_int32_line(char * pref, const int32_t * x, uint32_t n) {
char str[1024], *p = str, *p_end = str + sizeof(str); char str[1024], *p = str, *p_end = str + sizeof(str);
p += snprintf(p, p_end - p, "%s: ", pref); p += snprintf(p, p_end - p, "%s: ", pref);

View File

@ -29,10 +29,22 @@ static inline uint64_t hex_get_pktcnt() {
return pktcnt; return pktcnt;
} }
static inline int32_t hex_is_aligned(void * addr, uint32_t align) { static inline size_t hmx_ceil_div(size_t num, size_t den) {
return (num + den - 1) / den;
}
static inline int32_t hex_is_aligned(const void * addr, uint32_t align) {
return ((size_t) addr & (align - 1)) == 0; return ((size_t) addr & (align - 1)) == 0;
} }
static inline size_t hex_align_up(size_t v, size_t align) {
return hmx_ceil_div(v, align) * align;
}
static inline size_t hex_align_down(size_t v, size_t align) {
return (v / align) * align;
}
static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) { static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
uint32_t left_off = (size_t) addr & (chunk_size - 1); uint32_t left_off = (size_t) addr & (chunk_size - 1);
uint32_t right_off = left_off + n; uint32_t right_off = left_off + n;
@ -43,6 +55,14 @@ static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
return m * ((n + m - 1) / m); return m * ((n + m - 1) / m);
} }
static inline size_t hex_smin(size_t a, size_t b) {
return a < b ? a : b;
}
static inline size_t hex_smax(size_t a, size_t b) {
return a > b ? a : b;
}
static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) { static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height)); const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
Q6_l2fetch_AP((void *) p, control); Q6_l2fetch_AP((void *) p, control);

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,72 @@
// HMX operation entry-point declarations.
// Ported from htp-ops-lib/include/dsp/ops.h (renamed, benchmark kernels removed). (https://github.com/haozixu/htp-ops-lib)
#ifndef HMX_OPS_H
#define HMX_OPS_H
#include <stddef.h>
#include <stdint.h>
#ifndef restrict
# define restrict __restrict
#endif
#ifdef __cplusplus
extern "C" {
#endif
struct htp_context; // forward declaration
typedef struct {
float *dst;
const float *activation;
const __fp16 *permuted_weight;
int m;
int k;
int n;
int act_stride;
int weight_stride;
int dst_stride;
int ne02;
int ne03;
int ne12;
int ne13;
size_t src0_nb2;
size_t src0_nb3;
size_t src1_nb2;
size_t src1_nb3;
size_t dst_nb2;
size_t dst_nb3;
} hmx_matmul_w16a32_batched_params_t;
// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output
// act_stride: activation row stride in elements (= k for contiguous, or
// nb[1]/sizeof(float) for permuted tensors like attention Q).
// weight_stride: weight row stride in elements (= k for compact weights, or
// nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK).
int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx,
float *restrict dst,
const float *activation,
const __fp16 *permuted_weight,
int m, int k, int n,
int act_stride,
int weight_stride);
// Batched F16 wrapper over hmx_mat_mul_permuted_w16a32.
// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3.
int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx,
const hmx_matmul_w16a32_batched_params_t *params);
// HMX matrix multiplication — tile-permuted quantised weights (Q4_0/Q8_0/IQ4_NL)
int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx,
float *restrict dst,
const float *activation,
const uint8_t *permuted_weight,
int m, int k, int n,
int weight_type);
#ifdef __cplusplus
}
#endif
#endif // HMX_OPS_H

View File

@ -0,0 +1,34 @@
// Conditional fine-grained profiling macros for HMX operations.
//
// Define ENABLE_PROFILE_TIMERS (via compiler flag or before including this
// header) to instrument sub-operation latencies with HAP qtimer. When the
// macro is not defined the TIMER_* helpers expand to nothing so there is zero
// overhead.
//
// Usage:
// TIMER_DEFINE(my_phase); // declare accumulator variable
// TIMER_START(my_phase); // snapshot start time
// ... work ...
// TIMER_STOP(my_phase); // accumulate elapsed ticks
// FARF(ALWAYS, "my_phase: %lld us", TIMER_US(my_phase));
#ifndef HMX_PROFILE_H
#define HMX_PROFILE_H
#include <HAP_perf.h>
// #define ENABLE_PROFILE_TIMERS
#if defined(ENABLE_PROFILE_TIMERS)
# define TIMER_DEFINE(name) int64_t name##_ticks = 0
# define TIMER_START(name) int64_t name##_t0 = HAP_perf_get_qtimer_count()
# define TIMER_STOP(name) name##_ticks += HAP_perf_get_qtimer_count() - name##_t0
# define TIMER_US(name) HAP_perf_qtimer_count_to_us(name##_ticks)
#else
# define TIMER_DEFINE(name)
# define TIMER_START(name)
# define TIMER_STOP(name)
# define TIMER_US(name) 0LL
#endif
#endif // HMX_PROFILE_H

View File

@ -0,0 +1,88 @@
// HMX tile-level inline helpers (FP16 32x32 tile operations).
// Ported from htp-ops-lib/include/dsp/hmx_utils.h. (https://github.com/haozixu/htp-ops-lib)
#ifndef HMX_UTILS_H
#define HMX_UTILS_H
#include <hexagon_types.h>
#include <stddef.h>
#define HMX_FP16_TILE_N_ROWS 32
#define HMX_FP16_TILE_N_COLS 32
#define HMX_FP16_TILE_N_ELMS 1024
#define HMX_FP16_TILE_SIZE 2048
#define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline))
static HMX_INLINE_ALWAYS void hmx_set_output_scales(const void *scales) {
asm volatile("bias = mxmem2(%0)" :: "r"(scales));
}
// Initialise aligned 256-byte area with scale vector + zero padding.
static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) {
HVX_Vector *pv = (HVX_Vector *)out_scales;
*pv++ = v_scale;
*pv = Q6_V_vzero();
}
// Load multiple contiguous tiles with :deep streaming.
// Rt = total region size - 1; the hardware streams through [Rs, Rs + Rt].
// IMPORTANT: the tile region [Rs, Rs + Rt] must NOT cross a VTCM 4 MB bank
// boundary, otherwise the mxmem instruction will raise a precise bus error.
// Callers must ensure their VTCM layout satisfies this constraint.
static HMX_INLINE_ALWAYS void hmx_load_tiles_fp16(const __fp16 *row_tiles,
const __fp16 *col_tiles,
size_t n_tiles) {
size_t limit = n_tiles * HMX_FP16_TILE_SIZE - 1;
asm volatile(
"{ activation.hf = mxmem(%0, %1):deep\n"
"weight.hf = mxmem(%2, %3) }\n"
:: "r"(row_tiles), "r"(limit), "r"(col_tiles), "r"(limit)
: "memory");
}
// Load a single activation+weight tile pair (no :deep streaming).
// Rt defines the accessible region [Rs, Rs+Rt]. Following the reference formula
// (limit = n_tiles * HMX_FP16_TILE_SIZE - 1), for a single tile Rt = 2047.
// The original code used Rt=0x7FFF (32 KB region); when dynamic VTCM allocation
// places a tile near a 4 MB bank boundary, the oversized region crosses it and
// triggers a precise bus error (0x2601). Rt=2047 confines accesses to exactly
// one 2048-byte tile while covering all 16 HVX vectors (offsets 0..2047).
static HMX_INLINE_ALWAYS void hmx_load_tile_pair_fp16(const __fp16 *act_tile,
const __fp16 *wt_tile) {
asm volatile(
"{ activation.hf = mxmem(%0, %1)\n"
"weight.hf = mxmem(%2, %3) }\n"
:: "r"(act_tile), "r"(2047),
"r"(wt_tile), "r"(2047)
: "memory");
}
static HMX_INLINE_ALWAYS void hmx_consume_accumulator_fp16(__fp16 *out) {
// Use the combined convert-and-store instruction (matches the reference
// Q6_mxmem_AR_after_hf intrinsic). The previous two-instruction sequence
// "cvt.hf = acc(2); mxmem = cvt" used an undocumented Rs=2 parameter.
asm volatile(
"mxmem(%0, %1):after.hf = acc\n"
:: "r"(out), "r"(0)
: "memory");
}
// Compute inner product of two vectors of tiles and store result.
static HMX_INLINE_ALWAYS void hmx_dot_fp16(__fp16 *out,
const __fp16 *row_tiles,
const __fp16 *col_tiles,
size_t n_tiles) {
hmx_load_tiles_fp16(row_tiles, col_tiles, n_tiles);
hmx_consume_accumulator_fp16(out);
}
// --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) ---
static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {
uint8_t *p = *vtcm_ptr;
*vtcm_ptr += size;
return p;
}
#endif // HMX_UTILS_H

View File

@ -30,6 +30,12 @@ struct htp_context {
atomic_bool vtcm_needs_release; atomic_bool vtcm_needs_release;
uint32_t opmask; uint32_t opmask;
// HMX acceleration fields (v73+, enabled by compile-time HTP_HAS_HMX)
#ifdef HTP_HAS_HMX
int hmx_enabled; // Runtime flag: HMX initialisation succeeded
size_t vtcm_scratch_size; // Usable dynamic scratch (vtcm_size minus tail reservation)
#endif
}; };
#endif /* HTP_CTX_H */ #endif /* HTP_CTX_H */

View File

@ -32,13 +32,14 @@ enum htp_status {
// Duplicated here because we can't include full ggml.h in the htp build. // Duplicated here because we can't include full ggml.h in the htp build.
// We have some static_asserts in the cpp code to ensure things are in sync. // We have some static_asserts in the cpp code to ensure things are in sync.
enum htp_data_type { enum htp_data_type {
HTP_TYPE_F32 = 0, HTP_TYPE_F32 = 0,
HTP_TYPE_F16 = 1, HTP_TYPE_F16 = 1,
HTP_TYPE_Q4_0 = 2, HTP_TYPE_Q4_0 = 2,
HTP_TYPE_Q8_0 = 8, HTP_TYPE_Q8_0 = 8,
HTP_TYPE_I32 = 26, HTP_TYPE_IQ4_NL = 20,
HTP_TYPE_I64 = 27, HTP_TYPE_I32 = 26,
HTP_TYPE_MXFP4 = 39, HTP_TYPE_I64 = 27,
HTP_TYPE_MXFP4 = 39,
HTP_TYPE_COUNT HTP_TYPE_COUNT
}; };
@ -53,6 +54,10 @@ enum htp_op {
HTP_OP_RMS_NORM, HTP_OP_RMS_NORM,
HTP_OP_UNARY_SILU, HTP_OP_UNARY_SILU,
HTP_OP_UNARY_GELU, HTP_OP_UNARY_GELU,
HTP_OP_UNARY_SIGMOID,
HTP_OP_UNARY_EXP,
HTP_OP_UNARY_NEG,
HTP_OP_UNARY_SOFTPLUS,
HTP_OP_GLU_SWIGLU, HTP_OP_GLU_SWIGLU,
HTP_OP_GLU_SWIGLU_OAI, HTP_OP_GLU_SWIGLU_OAI,
HTP_OP_GLU_GEGLU, HTP_OP_GLU_GEGLU,
@ -69,6 +74,7 @@ enum htp_op {
HTP_OP_SQRT, HTP_OP_SQRT,
HTP_OP_SUM_ROWS, HTP_OP_SUM_ROWS,
HTP_OP_SSM_CONV, HTP_OP_SSM_CONV,
HTP_OP_REPEAT,
INVALID INVALID
}; };
@ -82,6 +88,8 @@ static inline size_t htp_t_block_size(uint32_t t) {
return QK4_0; return QK4_0;
case HTP_TYPE_Q8_0: case HTP_TYPE_Q8_0:
return QK8_0; return QK8_0;
case HTP_TYPE_IQ4_NL:
return QK4_NL;
case HTP_TYPE_MXFP4: case HTP_TYPE_MXFP4:
return QK_MXFP4; return QK_MXFP4;
default: default:
@ -100,6 +108,8 @@ static inline size_t htp_type_nbytes(uint32_t t) {
return sizeof(block_q4_0); return sizeof(block_q4_0);
case HTP_TYPE_Q8_0: case HTP_TYPE_Q8_0:
return sizeof(block_q8_0); return sizeof(block_q8_0);
case HTP_TYPE_IQ4_NL:
return sizeof(block_iq4_nl);
case HTP_TYPE_MXFP4: case HTP_TYPE_MXFP4:
return sizeof(block_mxfp4); return sizeof(block_mxfp4);
default: default:

View File

@ -57,6 +57,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx);
int op_set_rows(struct htp_ops_context * octx); int op_set_rows(struct htp_ops_context * octx);
int op_get_rows(struct htp_ops_context * octx); int op_get_rows(struct htp_ops_context * octx);
int op_cpy(struct htp_ops_context * octx); int op_cpy(struct htp_ops_context * octx);
int op_repeat(struct htp_ops_context * octx);
int op_argsort(struct htp_ops_context * octx); int op_argsort(struct htp_ops_context * octx);
int op_ssm_conv(struct htp_ops_context * octx); int op_ssm_conv(struct htp_ops_context * octx);

View File

@ -7,7 +7,7 @@
#include "remote.idl" #include "remote.idl"
interface htp_iface : remote_handle64 { interface htp_iface : remote_handle64 {
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx); AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx);
AEEResult stop(); AEEResult stop();
AEEResult enable_etm(); AEEResult enable_etm();
AEEResult disable_etm(); AEEResult disable_etm();

View File

@ -3,10 +3,15 @@
#include <stdbool.h> #include <stdbool.h>
#include <stdint.h> #include <stdint.h>
#include <math.h>
#include <assert.h>
#include "hex-utils.h" #include "hex-utils.h"
#include "hvx-types.h" #include "hvx-types.h"
#define hvx_vmem(A) *((HVX_Vector *)(A))
#define hvx_vmemu(A) *((HVX_UVector *)(A))
static inline void hvx_vec_store_u(void * restrict dst, uint32_t n, HVX_Vector v) { static inline void hvx_vec_store_u(void * restrict dst, uint32_t n, HVX_Vector v) {
// Rotate as needed. // Rotate as needed.
v = Q6_V_vlalign_VVR(v, v, (size_t) dst); v = Q6_V_vlalign_VVR(v, v, (size_t) dst);
@ -110,11 +115,15 @@ static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) {
return Q6_Q_and_QQ(p_exp, p_frac); return Q6_Q_and_QQ(p_exp, p_frac);
} }
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) { static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1) {
const HVX_Vector zero = Q6_V_vsplat_R(0); const HVX_Vector zero = Q6_V_vzero();
HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero); HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero);
HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero); HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero);
HVX_Vector v = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0))); return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0));
}
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
HVX_Vector v = Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1));
#if __HVX_ARCH__ < 79 #if __HVX_ARCH__ < 79
// replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0) // replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0)
@ -126,6 +135,30 @@ static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
return v; return v;
} }
#if __HVX_ARCH__ >= 79
static inline HVX_VectorPair hvx_vec_f16_to_f32_shuff(HVX_Vector v) {
const HVX_Vector one = hvx_vec_splat_f16(1.0);
HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(v, one);
return Q6_W_vcombine_VV(Q6_V_hi_W(p), Q6_V_lo_W(p));
}
static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
const HVX_Vector one = hvx_vec_splat_f16(1.0);
HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(v), one);
return Q6_W_vcombine_VV(Q6_V_hi_W(p), Q6_V_lo_W(p));
}
#else
static inline HVX_VectorPair hvx_vec_f16_to_f32_shuff(HVX_Vector v) {
const HVX_Vector one = hvx_vec_splat_f16(1.0);
HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(v, one);
return Q6_W_vcombine_VV(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)));
}
static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
const HVX_Vector one = hvx_vec_splat_f16(1.0);
HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(v), one);
return Q6_W_vcombine_VV(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)));
}
#endif
/* Q6_Vsf_equals_Vw is only available on v73+.*/ /* Q6_Vsf_equals_Vw is only available on v73+.*/
#if __HVX_ARCH__ < 73 #if __HVX_ARCH__ < 73
static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in) static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in)

View File

@ -3,6 +3,7 @@
#include <stdbool.h> #include <stdbool.h>
#include <stdint.h> #include <stdint.h>
#include <math.h>
#include "hvx-base.h" #include "hvx-base.h"
#include "hvx-floor.h" #include "hvx-floor.h"
@ -16,8 +17,8 @@
#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805 #define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408 #define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
#define EXP_ONE (0x3f800000) // 1.0 #define EXP_ONE (0x3f800000) // 1.0
#define EXP_RANGE_R (0x41a00000) // 20.0 #define EXP_RANGE_R (0x42B16666) // 88.7
#define EXP_RANGE_L (0xc1a00000) // -20.0 #define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN))
static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) { static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
HVX_Vector z_qf32_v; HVX_Vector z_qf32_v;
@ -47,12 +48,12 @@ static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
HVX_Vector temp_v = in_vec; HVX_Vector temp_v = in_vec;
// Clamp inputs to (-20.0, 20.0) // Clamp inputs to (-88.0, 88.0) to avoid overflow/underflow
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R)); HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R));
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec); HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v); in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v);
in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), temp_v); in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec); epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec);
epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v); epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v);
@ -69,12 +70,12 @@ static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
// normalize before every QFloat's vmpy // normalize before every QFloat's vmpy
x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v); x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v);
x_v = Q6_Vsf_equals_Vqf32(x_qf32_v);
// z = x * x; // z = x * x;
z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v); z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v);
z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v); z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v);
x_v = Q6_Vsf_equals_Vqf32(x_qf32_v);
// y = E4 + E5 * x; // y = E4 + E5 * x;
E_const = Q6_V_vsplat_R(EXP_COEFF_5); E_const = Q6_V_vsplat_R(EXP_COEFF_5);
y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v); y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v);
@ -145,7 +146,7 @@ static inline HVX_Vector hvx_vec_exp_f32_guard(HVX_Vector in_vec, HVX_Vector max
return Q6_V_vmux_QVV(pred0, inf, out); return Q6_V_vmux_QVV(pred0, inf, out);
} }
static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) { static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems, bool negate) {
int left_over = num_elems & (VLEN_FP32 - 1); int left_over = num_elems & (VLEN_FP32 - 1);
int num_elems_whole = num_elems - left_over; int num_elems_whole = num_elems - left_over;
@ -162,7 +163,7 @@ static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict
HVX_Vector vec_out = Q6_V_vzero(); HVX_Vector vec_out = Q6_V_vzero();
static const float kInf = INFINITY; static const float kInf = INFINITY;
static const float kMaxExp = 88.02f; // log(INF) static const float kMaxExp = 88.7f;
const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp);
const HVX_Vector inf = hvx_vec_splat_f32(kInf); const HVX_Vector inf = hvx_vec_splat_f32(kInf);

View File

@ -2,6 +2,7 @@
#define HVX_SIGMOID_H #define HVX_SIGMOID_H
#include "hvx-base.h" #include "hvx-base.h"
#include "hvx-inverse.h"
#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022 #define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022
#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777 #define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777

View File

@ -15,12 +15,4 @@
#include "hvx-div.h" #include "hvx-div.h"
#include "hvx-base.h" #include "hvx-base.h"
#ifndef GATHER_TYPE
# if defined(__hexagon__)
# define GATHER_TYPE(_a) (intptr_t) _a
# else
# define GATHER_TYPE(_a) (HVX_Vector *) _a
# endif
#endif
#endif /* HVX_UTILS_H */ #endif /* HVX_UTILS_H */

View File

@ -25,6 +25,10 @@
#include "htp-ops.h" #include "htp-ops.h"
#include "worker-pool.h" #include "worker-pool.h"
#ifdef HTP_HAS_HMX
#include "hmx-ops.h"
#endif // HTP_HAS_HMX
AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
struct htp_context * ctx; struct htp_context * ctx;
int err = 0; int err = 0;
@ -163,6 +167,9 @@ static int vtcm_acquire(struct htp_context * ctx) {
} }
ctx->vtcm_inuse = true; ctx->vtcm_inuse = true;
return 0; return 0;
} }
@ -207,7 +214,7 @@ static int vtcm_alloc(struct htp_context * ctx) {
HAP_compute_res_attr_init(&attr); HAP_compute_res_attr_init(&attr);
HAP_compute_res_attr_set_serialize(&attr, 0); HAP_compute_res_attr_set_serialize(&attr, 0);
HAP_compute_res_attr_set_cache_mode(&attr, 1); HAP_compute_res_attr_set_cache_mode(&attr, 1);
HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, 0, vtcm_size); HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, vtcm_size, vtcm_size); // single page
HAP_compute_res_attr_set_release_callback(&attr, vtcm_release_callback, (void *) ctx); HAP_compute_res_attr_set_release_callback(&attr, vtcm_release_callback, (void *) ctx);
HAP_compute_res_attr_set_hmx_param(&attr, 1); HAP_compute_res_attr_set_hmx_param(&attr, 1);
@ -246,7 +253,7 @@ static void vtcm_free(struct htp_context * ctx) {
static void htp_packet_callback(dspqueue_t queue, int error, void * context); static void htp_packet_callback(dspqueue_t queue, int error, void * context);
static void htp_error_callback(dspqueue_t queue, int error, void * context); static void htp_error_callback(dspqueue_t queue, int error, void * context);
AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx) { AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx) {
struct htp_context * ctx = (struct htp_context *) handle; struct htp_context * ctx = (struct htp_context *) handle;
if (!ctx) { if (!ctx) {
@ -280,6 +287,21 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
return AEE_ENOMEMORY; return AEE_ENOMEMORY;
} }
#ifdef HTP_HAS_HMX
if (use_hmx) {
ctx->vtcm_scratch_size = ctx->vtcm_size;
ctx->hmx_enabled = 1;
FARF(HIGH, "HMX enabled: vtcm-scratch %zu", ctx->vtcm_scratch_size);
} else {
// HMX disabled: skip HMX initialisation so the
// dispatch loop falls through to the HVX compute paths.
ctx->hmx_enabled = 0;
ctx->vtcm_scratch_size = ctx->vtcm_size;
FARF(HIGH, "HMX disabled (use_hmx=0): vtcm-scratch %zu", ctx->vtcm_scratch_size);
}
#endif
qurt_sysenv_max_hthreads_t hw_threads; qurt_sysenv_max_hthreads_t hw_threads;
qurt_sysenv_get_max_hw_threads(&hw_threads); qurt_sysenv_get_max_hw_threads(&hw_threads);
uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF; uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF;
@ -297,7 +319,7 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
ctx->n_threads = n_hvx; ctx->n_threads = n_hvx;
for (int i = 0; i < ctx->n_threads; i++) { for (int i = 0; i < ctx->n_threads; i++) {
// see discussion https://github.com/ggml-org/llama.cpp/pull/18151#discussion_r2632388541 // see discussion https://github.com/ggml-org/llama.cpp/pull/18151#discussion_r2632388541
ctx->dma[i] = dma_queue_create(64); ctx->dma[i] = dma_queue_create(128);
} }
// init worker pool // init worker pool
@ -340,6 +362,12 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
for (int i = 0; i < ctx->n_threads; i++) { for (int i = 0; i < ctx->n_threads; i++) {
dma_queue_delete(ctx->dma[i]); dma_queue_delete(ctx->dma[i]);
} }
#ifdef HTP_HAS_HMX
if (ctx->hmx_enabled) {
ctx->hmx_enabled = 0;
}
#endif
vtcm_free(ctx); vtcm_free(ctx);
@ -375,8 +403,9 @@ static int send_htp_rsp(struct htp_context * c,
struct dspqueue_buffer * bufs, struct dspqueue_buffer * bufs,
size_t n_bufs, size_t n_bufs,
struct profile_data * prof) { struct profile_data * prof) {
// Prep response struct // Prep response struct (zero-init to clear cmp/unused union)
struct htp_general_rsp rsp; struct htp_general_rsp rsp;
memset(&rsp, 0, sizeof(rsp));
rsp.op = op; rsp.op = op;
rsp.status = status; rsp.status = status;
rsp.prof_usecs = prof->usecs; rsp.prof_usecs = prof->usecs;
@ -516,6 +545,39 @@ static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req,
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
} }
static void proc_repeat_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[1];
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[0].fd = bufs[1].fd;
rsp_bufs[0].ptr = bufs[1].ptr;
rsp_bufs[0].offset = bufs[1].offset;
rsp_bufs[0].size = bufs[1].size;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
struct htp_ops_context octx = { 0 };
octx.ctx = ctx;
octx.src0 = req->src0;
octx.dst = req->dst;
octx.flags = req->flags;
octx.op = req->op;
// Update data pointers
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.dst.data = (uint32_t) bufs[1].ptr;
octx.n_threads = ctx->n_threads;
struct profile_data prof;
profile_start(&prof);
uint32_t rsp_status = op_repeat(&octx);
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[1]; struct dspqueue_buffer rsp_bufs[1];
@ -1004,6 +1066,210 @@ static void proc_flash_attn_ext_req(struct htp_context * ctx,
send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof); send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof);
} }
#ifdef HTP_HAS_HMX
// ---------------------------------------------------------------------------
// HMX operation wrappers — self-contained, bypass htp_ops_context / htp_spad.
// VTCM, DMA and thread dispatch are managed inside the HMX kernels.
// ---------------------------------------------------------------------------
static void proc_hmx_matmul_req(struct htp_context * ctx,
struct htp_general_req * req,
struct dspqueue_buffer * bufs,
size_t n_bufs) {
// HMX weight tile requires N to be 32-aligned.
if (req->src0.ne[1] % 32 != 0) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
const bool is_batched = (req->src0.ne[2] * req->src0.ne[3] > 1 ||
req->src1.ne[2] * req->src1.ne[3] > 1);
// Quantised HMX kernels only handle flat 2D matmul (host already rejects
// batched quantised, but guard here too). F16 batched matmul is handled
// by the dedicated wrapper in hmx-matmul-ops.c.
if (is_batched &&
req->src0.type != HTP_TYPE_F16) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
// HMX assumes contiguous row-major layout. Fall back for permuted
// tensors where strides are non-monotonic (e.g. transposed KV cache).
if (req->src0.nb[0] > req->src0.nb[1] ||
req->src1.nb[0] > req->src1.nb[1]) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
// M alignment: when M > 32 but not 32-aligned, we split into
// HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows).
// When M <= 32 and not 32-aligned, fall back entirely to HVX.
const int m_total = (int) req->src1.ne[1];
const int m_tail = m_total % 32;
const int m_hmx = m_total - m_tail;
if (m_hmx == 0) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
// HMX only supports F16, Q4_0, Q8_0, IQ4_NL weights.
// Other types (e.g. MXFP4) fall back to HVX.
{
uint32_t wtype = req->src0.type;
if (wtype != HTP_TYPE_F16 &&
wtype != HTP_TYPE_Q4_0 &&
wtype != HTP_TYPE_Q8_0 &&
wtype != HTP_TYPE_IQ4_NL) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
// Quantised HMX path requires K aligned to 256 (x4x2 super-block).
// F16 HMX path requires K aligned to 32 (tile width).
if (wtype != HTP_TYPE_F16 && req->src0.ne[0] % 256 != 0) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
if (wtype == HTP_TYPE_F16 && req->src0.ne[0] % 32 != 0) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
}
(void) n_bufs;
struct dspqueue_buffer rsp_bufs[1];
rsp_bufs[0].fd = bufs[2].fd;
rsp_bufs[0].ptr = bufs[2].ptr;
rsp_bufs[0].size = bufs[2].size;
rsp_bufs[0].offset = bufs[2].offset;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);
// src0 = weights, src1 = activation, dst = output
void * wgt = (void *) bufs[0].ptr;
float * act = (float *) bufs[1].ptr;
float * dst = (float *) bufs[2].ptr;
int k = (int) req->src0.ne[0]; // inner dimension
int n = (int) req->src0.ne[1]; // weight columns
struct profile_data prof;
profile_start(&prof);
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
// --- Phase 1: HMX on the first m_hmx (32-aligned) rows ---
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
int ret = -1;
const int ne02 = (int) req->src0.ne[2];
const int ne03 = (int) req->src0.ne[3];
const int ne12 = (int) req->src1.ne[2];
const int ne13 = (int) req->src1.ne[3];
// Row strides in elements. For compact tensors these equal k; for
// permuted attention views they can be larger, so pass the real stride.
const int act_stride = (int)(req->src1.nb[1] / sizeof(float));
const int weight_stride = (int)(req->src0.nb[1] / sizeof(__fp16));
switch (req->src0.type) {
case HTP_TYPE_F16:
if (is_batched) {
hmx_matmul_w16a32_batched_params_t batch_params = {
.dst = dst,
.activation = act,
.permuted_weight = (const __fp16 *) wgt,
.m = m_hmx,
.k = k,
.n = n,
.act_stride = act_stride,
.weight_stride = weight_stride,
.dst_stride = (int)(req->dst.nb[1] / sizeof(float)),
.ne02 = ne02,
.ne03 = ne03,
.ne12 = ne12,
.ne13 = ne13,
.src0_nb2 = req->src0.nb[2],
.src0_nb3 = req->src0.nb[3],
.src1_nb2 = req->src1.nb[2],
.src1_nb3 = req->src1.nb[3],
.dst_nb2 = req->dst.nb[2],
.dst_nb3 = req->dst.nb[3],
};
ret = hmx_mat_mul_permuted_w16a32_batched(ctx, &batch_params);
} else {
ret = hmx_mat_mul_permuted_w16a32(ctx, dst, act,
(const __fp16 *) wgt,
m_hmx, k, n,
act_stride,
weight_stride);
}
break;
default:
ret = hmx_mat_mul_permuted_qk_0_d16a32(ctx, dst, act,
(const uint8_t *) wgt,
m_hmx, k, n, (int) req->src0.type);
break;
}
if (ret == 0) {
rsp_status = HTP_STATUS_OK;
} else {
FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret);
vtcm_release(ctx);
req->flags &= ~HTP_OPFLAGS_SKIP_QUANTIZE;
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
vtcm_release(ctx);
}
// --- Phase 2: HVX on the remaining m_tail rows ---
if (m_tail > 0 && rsp_status == HTP_STATUS_OK) {
struct htp_ops_context octx = { 0 };
octx.ctx = ctx;
octx.src0 = req->src0; // weights: unchanged
octx.src1 = req->src1;
octx.src1.ne[1] = m_tail; // only tail rows
octx.dst = req->dst;
octx.dst.ne[1] = m_tail; // only tail rows
// Always re-quantize tail src1: HMX Phase 1 overwrites VTCM,
// so any previously cached quantized data (SKIP_QUANTIZE pipeline)
// is invalid.
octx.flags = req->flags & ~HTP_OPFLAGS_SKIP_QUANTIZE;
octx.op = req->op;
octx.n_threads = ctx->n_threads;
// Offset activation and dst pointers past the HMX-processed rows.
// Use nb[1] (row stride in bytes) to compute the byte offset.
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.src1.data = (uint32_t)((uint8_t *) bufs[1].ptr + (size_t) m_hmx * req->src1.nb[1]);
octx.dst.data = (uint32_t)((uint8_t *) bufs[2].ptr + (size_t) m_hmx * req->dst.nb[1]);
FARF(HIGH, "proc_hmx_matmul: HVX tail m_tail=%d act=%p dst=%p",
m_tail, (void *)(uintptr_t) octx.src1.data, (void *)(uintptr_t) octx.dst.data);
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
uint32_t hvx_ret = op_matmul(&octx);
vtcm_release(ctx);
if (hvx_ret != HTP_STATUS_OK) {
FARF(ERROR, "HVX tail matmul failed (ret=%u)", hvx_ret);
rsp_status = HTP_STATUS_INTERNAL_ERR;
}
} else {
rsp_status = HTP_STATUS_INTERNAL_ERR;
}
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
#endif // HTP_HAS_HMX
static void htp_packet_callback(dspqueue_t queue, int error, void * context) { static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
struct htp_context * ctx = (struct htp_context *) context; struct htp_context * ctx = (struct htp_context *) context;
@ -1056,7 +1322,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
FARF(ERROR, "Bad matmul-req buffer list"); FARF(ERROR, "Bad matmul-req buffer list");
continue; continue;
} }
proc_matmul_req(ctx, &req, bufs, n_bufs); #ifdef HTP_HAS_HMX
if (ctx->hmx_enabled) {
proc_hmx_matmul_req(ctx, &req, bufs, n_bufs);
} else
#endif
{
proc_matmul_req(ctx, &req, bufs, n_bufs);
}
break; break;
case HTP_OP_MUL_MAT_ID: case HTP_OP_MUL_MAT_ID:
@ -1090,6 +1363,10 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
case HTP_OP_SQR: case HTP_OP_SQR:
case HTP_OP_SQRT: case HTP_OP_SQRT:
case HTP_OP_UNARY_NEG:
case HTP_OP_UNARY_EXP:
case HTP_OP_UNARY_SIGMOID:
case HTP_OP_UNARY_SOFTPLUS:
if (n_bufs != 2) { if (n_bufs != 2) {
FARF(ERROR, "Bad unary-req buffer list"); FARF(ERROR, "Bad unary-req buffer list");
continue; continue;
@ -1175,6 +1452,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
proc_cpy_req(ctx, &req, bufs); proc_cpy_req(ctx, &req, bufs);
break; break;
case HTP_OP_REPEAT:
if (n_bufs != 2) {
FARF(ERROR, "Bad repeat-req buffer list");
continue;
}
proc_repeat_req(ctx, &req, bufs);
break;
case HTP_OP_ARGSORT: case HTP_OP_ARGSORT:
if (n_bufs != 2) { if (n_bufs != 2) {
FARF(ERROR, "Bad argsort-req buffer list"); FARF(ERROR, "Bad argsort-req buffer list");

View File

@ -0,0 +1,148 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#include <HAP_farf.h>
#include <HAP_perf.h>
#include <string.h>
#include "hvx-utils.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"
struct htp_repeat_context {
struct htp_ops_context * octx;
uint32_t nr0;
uint32_t nr1;
uint32_t nr2;
uint32_t nr3;
uint32_t nrows_per_thread;
uint32_t total_dst_rows; // ne1 * ne2 * ne3
size_t type_size;
};
static void repeat_job_per_thread(unsigned int nth, unsigned int ith, void * data) {
const struct htp_repeat_context * rctx = (const struct htp_repeat_context *) data;
struct htp_ops_context * octx = rctx->octx;
const struct htp_tensor * src = &octx->src0;
const struct htp_tensor * dst = &octx->dst;
const uint32_t ne00 = src->ne[0];
const uint32_t ne01 = src->ne[1];
const uint32_t ne02 = src->ne[2];
const uint32_t ne03 = src->ne[3];
const uint32_t nb00 = src->nb[0];
const uint32_t nb01 = src->nb[1];
const uint32_t nb02 = src->nb[2];
const uint32_t nb03 = src->nb[3];
const uint32_t ne0 = dst->ne[0];
const uint32_t ne1 = dst->ne[1];
const uint32_t ne2 = dst->ne[2];
const uint32_t ne3 = dst->ne[3];
const uint32_t nb0 = dst->nb[0];
const uint32_t nb1 = dst->nb[1];
const uint32_t nb2 = dst->nb[2];
const uint32_t nb3 = dst->nb[3];
const uint32_t nr0 = rctx->nr0;
const uint32_t nr1 = rctx->nr1;
const uint32_t nr2 = rctx->nr2;
const uint32_t nr3 = rctx->nr3;
const size_t row_bytes = ne00 * rctx->type_size;
const uint32_t row_start = rctx->nrows_per_thread * ith;
const uint32_t row_end = MIN(row_start + rctx->nrows_per_thread, rctx->total_dst_rows);
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) {
// Decompose flat dst row index into (i1, i2, i3)
const uint32_t i1 = dst_row % ne1;
const uint32_t i2 = (dst_row / ne1) % ne2;
const uint32_t i3 = dst_row / (ne1 * ne2);
// Map to source indices (tiling)
const uint32_t k1 = i1 % ne01;
const uint32_t k2 = i2 % ne02;
const uint32_t k3 = i3 % ne03;
const uint8_t * src_row = (const uint8_t *) src->data + k1 * nb01 + k2 * nb02 + k3 * nb03;
uint8_t * dst_base = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3;
// Tile along dimension 0
for (uint32_t i0 = 0; i0 < nr0; i0++) {
uint8_t * dst_ptr = dst_base + i0 * ne00 * nb0;
memcpy(dst_ptr, src_row, row_bytes);
}
}
t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "repeat %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n",
ith, nth, src->ne[0], src->ne[1], src->ne[2], src->ne[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
row_start, row_end, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
int op_repeat(struct htp_ops_context * octx) {
const struct htp_tensor * src0 = &octx->src0;
struct htp_tensor * dst = &octx->dst;
// Validate that dst dims are multiples of src dims
if (dst->ne[0] % src0->ne[0] != 0 ||
dst->ne[1] % src0->ne[1] != 0 ||
dst->ne[2] % src0->ne[2] != 0 ||
dst->ne[3] % src0->ne[3] != 0) {
FARF(ERROR, "repeat: dst dims must be multiples of src dims\n");
return HTP_STATUS_INVAL_PARAMS;
}
size_t type_size;
switch (src0->type) {
case HTP_TYPE_F32: type_size = 4; break;
case HTP_TYPE_F16: type_size = 2; break;
default:
FARF(ERROR, "repeat: unsupported type %u\n", src0->type);
return HTP_STATUS_NO_SUPPORT;
}
const uint32_t total_dst_rows = dst->ne[1] * dst->ne[2] * dst->ne[3];
const uint32_t n_threads = MIN(octx->n_threads, total_dst_rows);
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
return HTP_STATUS_OK;
}
struct htp_repeat_context rctx = {
.octx = octx,
.nr0 = dst->ne[0] / src0->ne[0],
.nr1 = dst->ne[1] / src0->ne[1],
.nr2 = dst->ne[2] / src0->ne[2],
.nr3 = dst->ne[3] / src0->ne[3],
.nrows_per_thread = (total_dst_rows + n_threads - 1) / n_threads,
.total_dst_rows = total_dst_rows,
.type_size = type_size,
};
FARF(HIGH, "repeat: (%ux%ux%ux%u) -> (%ux%ux%ux%u) nr=(%u,%u,%u,%u)\n",
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
rctx.nr0, rctx.nr1, rctx.nr2, rctx.nr3);
worker_pool_run_func(octx->ctx->worker_pool, repeat_job_per_thread, &rctx, n_threads);
return HTP_STATUS_OK;
}

View File

@ -195,7 +195,7 @@ static float hvx_softmax_f32(const uint8_t * restrict src,
const float max) { const float max) {
hvx_sub_scalar_f32(spad, src, max, num_elems); hvx_sub_scalar_f32(spad, src, max, num_elems);
hvx_exp_f32(spad, dst, num_elems, false); hvx_exp_f32(dst, spad, num_elems, false);
float sum = hvx_reduce_sum_f32(dst, num_elems); float sum = hvx_reduce_sum_f32(dst, num_elems);

View File

@ -151,7 +151,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
const int dr = scctx->nrows_per_thread; const int dr = scctx->nrows_per_thread;
const uint32_t ir0 = dr * ith; const uint32_t ir0 = dr * ith;
const uint32_t ir1 = MIN(ir0 + dr, d_inner); const uint32_t ir1 = MIN(ir0 + dr, d_inner);
const int ir = ir1 - ir0; const uint32_t ir = ir1 - ir0;
if (ir0 >= ir1) { if (ir0 >= ir1) {
return; // No work for this thread return; // No work for this thread
@ -205,10 +205,10 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
HVX_Vector acc_vec = Q6_V_vsplat_R(0); HVX_Vector acc_vec = Q6_V_vsplat_R(0);
for (uint32_t i0 = 0; i0 < d_conv; ++i0) { for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])), uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]);
src0_gather_len, (*(const HVX_Vector *) src0_offsets)); uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float);
Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)), Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets));
src1_gather_len, (*(const HVX_Vector *) src1_offsets)); Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets));
HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
@ -222,10 +222,10 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
HVX_Vector acc_vec = Q6_V_vsplat_R(0); HVX_Vector acc_vec = Q6_V_vsplat_R(0);
for (uint32_t i0 = 0; i0 < d_conv; ++i0) { for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])), uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]);
src0_gather_len, (*(const HVX_Vector *) src0_offsets)); uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float);
Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)), Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets));
src1_gather_len, (*(const HVX_Vector *) src1_offsets)); Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets));
HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);

Some files were not shown because too many files have changed in this diff Show More