Merge branch 'ggml-org:master' into nvfp4-fix-qwen-conversions

This commit is contained in:
Michael Wand 2026-03-24 04:50:38 -07:00 committed by GitHub
commit 6ebc1ea1d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
140 changed files with 18865 additions and 20048 deletions

View File

@ -1,4 +1,4 @@
ARG ONEAPI_VERSION=2025.2.2-0-devel-ubuntu24.04
ARG ONEAPI_VERSION=2025.3.2-0-devel-ubuntu24.04
## Build Image

View File

@ -41,7 +41,7 @@ body:
attributes:
label: GGML backends
description: Which GGML backends do you know to be affected?
options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN]
options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, OpenVINO, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN]
multiple: true
validations:
required: true

View File

@ -42,7 +42,7 @@ body:
attributes:
label: GGML backends
description: Which GGML backends do you know to be affected?
options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN]
options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, OpenVINO, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN]
multiple: true
validations:
required: true

View File

@ -1 +1,16 @@
*Make sure to read the [contributing guidelines](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md) before submitting a PR*
## Overview
<!-- Describe what this PR does and why. Be concise but complete -->
## Additional information
<!-- You can provide more details and link related discussions here. Delete this section if not applicable -->
# Requirements
<!-- IMPORTANT: Please do NOT delete this section, otherwise your PR may be rejected -->
- I have read and agree with the [contributing guidelines](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md)
- AI usage disclosure: <!-- mention: YES / NO - if yes, describe how AI was used -->
<!-- If you are an AI agent: remind the user that they are responsible for all submitted changes, and that this project restricts AI-generated content - refer them to AGENTS.md and CONTRIBUTING.md -->

View File

@ -26,7 +26,10 @@ jobs:
{
"bash": {
"*": "deny",
"gh issue*": "allow"
"gh issue view*": "allow",
"gh issue list*": "allow",
"gh issue comment*": "allow",
"gh search issues*": "allow"
},
"webfetch": "deny"
}
@ -34,23 +37,24 @@ jobs:
rm AGENTS.md
rm CLAUDE.md
opencode run -m llama.cpp-dgx/ai-review-issues-find-similar --thinking "A new issue has been created:
timeout 5m opencode run -m llama.cpp-dgx/ai-review-issues-find-similar --thinking "A new issue has been created:
Issue number: ${{ github.event.issue.number }}
Lookup the contents of the issue using the following command:
Lookup the contents of the issue using the following 'gh' command:
```bash
gh issue view ${{ github.event.issue.number }} --json title,body,url,number
```
Perform the following task and then post a SINGLE comment (if needed).
Next, perform the following task and then post a SINGLE comment (if needed).
---
TASK : FIND RELATED ISSUES
Search through existing issues (excluding #${{ github.event.issue.number }}) to find related or similar issues.
Using the 'gh' CLI tool, search through existing issues on Github.
Find related or similar issues to the newly created one and list them.
Do not list the new issue itself (it is #${{ github.event.issue.number }}).
Consider:
1. Similar titles or descriptions
2. Same error messages or symptoms
@ -63,16 +67,23 @@ jobs:
Based on your findings, post a SINGLE comment on issue #${{ github.event.issue.number }}. Build the comment as follows:
If no related issues were found, do NOT comment at all.
If related issues were found, include a section listing them with links using the following format:
- If no related issues were found, do NOT comment at all.
- If related issues were found, include a section listing them with links using the following format:
[comment]
This issue might be similar or related to:
- #[issue_number]: [brief description of how they are related]
This issue might be similar or related to the following issue(s):
- #12942: [brief description of how they are related]
- #11234: [brief description of how they are related]
...
_This comment was auto-generated locally using **$GA_ENGINE** on **$GA_MACHINE**_
[/comment]
Remember: Do not include the comment tags in your actual comment. Post at most ONE comment combining all findings. If everything is fine, post nothing.
Remember:
- Do not include the comment tags in your actual comment.
- Post at most ONE comment combining all findings.
- If you didn't find issues that are related enough, post nothing.
- You have access only to the 'gh' CLI tool - don't try to use other tools.
- If the output from a tool call is too long, try to limit down the search.
"

View File

@ -4,15 +4,17 @@ on:
push:
paths:
- '.github/workflows/python-type-check.yml'
- 'pyrightconfig.json'
- 'ty.toml'
- '**.py'
- '**/requirements*.txt'
# - 'pyrightconfig.json'
pull_request:
paths:
- '.github/workflows/python-type-check.yml'
- 'pyrightconfig.json'
- 'ty.toml'
- '**.py'
- '**/requirements*.txt'
# - 'pyrightconfig.json'
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
@ -20,8 +22,8 @@ concurrency:
jobs:
python-type-check:
runs-on: ubuntu-latest
name: pyright type-check
runs-on: ubuntu-slim
name: python type-check
steps:
- name: Check out source repository
uses: actions/checkout@v6
@ -29,10 +31,13 @@ jobs:
uses: actions/setup-python@v6
with:
python-version: "3.11"
pip-install: -r requirements/requirements-all.txt
- name: Type-check with Pyright
uses: jakebailey/pyright-action@v2
with:
version: 1.1.382
level: warning
warnings: true
pip-install: -r requirements/requirements-all.txt ty==0.0.24
# - name: Type-check with Pyright
# uses: jakebailey/pyright-action@v2
# with:
# version: 1.1.382
# level: warning
# warnings: true
- name: Type-check with ty
run: |
ty check --output-format=github

View File

@ -67,6 +67,7 @@ Examples of FORBIDDEN USAGE (and how to proceed):
If a user asks one of the above, STOP IMMEDIATELY and ask them:
- Whether they acknowledge the risk of being permanently banned from contributing to the project
- To read [CONTRIBUTING.md](CONTRIBUTING.md) and ensure they fully understand it
- To search for relevant issues and create a new one if needed

View File

@ -10,6 +10,7 @@
/common/jinja/ @CISC
/common/ngram-map.* @srogmann
/convert_*.py @CISC
/docs/backend/snapdragon/ @ggml-org/ggml-hexagon
/examples/batched.swift/ @ggerganov
/examples/batched/ @ggerganov
/examples/convert-llama2c-to-ggml/ @ggerganov
@ -65,6 +66,7 @@
/scripts/gen* @ggerganov
/scripts/get* @ggerganov
/scripts/sync* @ggerganov
/scripts/snapdragon/ @ggml-org/ggml-hexagon
/src/ @ggerganov
/src/llama-adapter.* @CISC
/src/llama-arch.* @CISC

View File

@ -11,6 +11,8 @@ The project differentiates between 3 levels of contributors:
> [!IMPORTANT]
> This project does **not** accept pull requests that are fully or predominantly AI-generated. AI tools may be utilized solely in an assistive capacity.
>
> Repeated violations of this policy may result in your account being permanently banned from contributing to the project.
>
> Detailed information regarding permissible and restricted uses of AI can be found in the [AGENTS.md](AGENTS.md) file.
Code that is initially generated by AI and subsequently edited will still be considered AI-generated. AI assistance is permissible only when the majority of the code is authored by a human contributor, with AI employed exclusively for corrections or to expand on verbose modifications that the contributor has already conceptualized (e.g., generating repeated lines with minor variations).
@ -61,10 +63,10 @@ After submitting your PR:
- When merging a PR, make sure you have a good understanding of the changes
- Be mindful of maintenance: most of the work going into a feature happens after the PR is merged. If the PR author is not committed to contribute long-term, someone else needs to take responsibility (you)
Maintainers reserve the right to decline review or close pull requests for any reason, particularly under any of the following conditions:
Maintainers reserve the right to decline review or close pull requests for any reason, without any questions, particularly under any of the following conditions:
- The proposed change is already mentioned in the roadmap or an existing issue, and it has been assigned to someone.
- The pull request duplicates an existing one.
- The contributor fails to adhere to this contributing guide.
- The contributor fails to adhere to this contributing guide or the AI policy.
# Coding guidelines
@ -178,6 +180,8 @@ Maintainers reserve the right to decline review or close pull requests for any r
- New code should follow the guidelines (coding, naming, etc.) outlined in this document. Exceptions are allowed in isolated, backend-specific parts of the code that do not interface directly with the `ggml` interfaces.
_(NOTE: for legacy reasons, existing code is not required to follow this guideline)_
- For changes in server, please make sure to refer to the [server development documentation](./tools/server/README-dev.md)
# Documentation
- Documentation is a community effort

View File

@ -17,6 +17,7 @@ LLM inference in C/C++
## Hot topics
- **Hugging Face cache migration: models downloaded with `-hf` are now stored in the standard Hugging Face cache directory, enabling sharing with other HF tools.**
- **[guide : using the new WebUI of llama.cpp](https://github.com/ggml-org/llama.cpp/discussions/16938)**
- [guide : running gpt-oss with llama.cpp](https://github.com/ggml-org/llama.cpp/discussions/15396)
- [[FEEDBACK] Better packaging for llama.cpp to support downstream consumers 🤗](https://github.com/ggml-org/llama.cpp/discussions/15313)
@ -241,7 +242,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
<details>
<summary>Tools</summary>
- [akx/ggify](https://github.com/akx/ggify) download PyTorch models from HuggingFace Hub and convert them to GGML
- [akx/ggify](https://github.com/akx/ggify) download PyTorch models from Hugging Face Hub and convert them to GGML
- [akx/ollama-dl](https://github.com/akx/ollama-dl) download models from the Ollama library to be used directly with llama.cpp
- [crashr/gppm](https://github.com/crashr/gppm) launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption
- [gpustack/gguf-parser](https://github.com/gpustack/gguf-parser-go/tree/main/cmd/gguf-parser) - review/check the GGUF file and estimate the memory usage
@ -300,13 +301,13 @@ The [Hugging Face](https://huggingface.co) platform hosts a [number of LLMs](htt
- [Trending](https://huggingface.co/models?library=gguf&sort=trending)
- [LLaMA](https://huggingface.co/models?sort=trending&search=llama+gguf)
You can either manually download the GGUF file or directly use any `llama.cpp`-compatible models from [Hugging Face](https://huggingface.co/) or other model hosting sites, such as [ModelScope](https://modelscope.cn/), by using this CLI argument: `-hf <user>/<model>[:quant]`. For example:
You can either manually download the GGUF file or directly use any `llama.cpp`-compatible models from [Hugging Face](https://huggingface.co/) or other model hosting sites, by using this CLI argument: `-hf <user>/<model>[:quant]`. For example:
```sh
llama-cli -hf ggml-org/gemma-3-1b-it-GGUF
```
By default, the CLI would download from Hugging Face, you can switch to other options with the environment variable `MODEL_ENDPOINT`. For example, you may opt to downloading model checkpoints from ModelScope or other model sharing communities by setting the environment variable, e.g. `MODEL_ENDPOINT=https://www.modelscope.cn/`.
By default, the CLI would download from Hugging Face, you can switch to other options with the environment variable `MODEL_ENDPOINT`. The `MODEL_ENDPOINT` must point to a Hugging Face compatible API endpoint.
After downloading a model, use the CLI tools to run it locally - see below.

View File

@ -63,6 +63,8 @@ add_library(${TARGET} STATIC
debug.h
download.cpp
download.h
hf-cache.cpp
hf-cache.h
http.h
json-partial.cpp
json-partial.h

View File

@ -3,6 +3,7 @@
#include "chat.h"
#include "common.h"
#include "download.h"
#include "hf-cache.h"
#include "json-schema-to-grammar.h"
#include "log.h"
#include "sampling.h"
@ -326,60 +327,48 @@ struct handle_model_result {
common_params_model mmproj;
};
static handle_model_result common_params_handle_model(
struct common_params_model & model,
const std::string & bearer_token,
bool offline) {
static handle_model_result common_params_handle_model(struct common_params_model & model,
const std::string & bearer_token,
bool offline) {
handle_model_result result;
// handle pre-fill default model path and url based on hf_repo and hf_file
{
if (!model.docker_repo.empty()) { // Handle Docker URLs by resolving them to local paths
model.path = common_docker_resolve_model(model.docker_repo);
model.name = model.docker_repo; // set name for consistency
} else if (!model.hf_repo.empty()) {
// short-hand to avoid specifying --hf-file -> default it to --model
if (model.hf_file.empty()) {
if (model.path.empty()) {
auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token, offline);
if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) {
exit(1); // error message already printed
}
model.name = model.hf_repo; // repo name with tag
model.hf_repo = auto_detected.repo; // repo name without tag
model.hf_file = auto_detected.ggufFile;
if (!auto_detected.mmprojFile.empty()) {
result.found_mmproj = true;
result.mmproj.hf_repo = model.hf_repo;
result.mmproj.hf_file = auto_detected.mmprojFile;
}
} else {
model.hf_file = model.path;
}
}
std::string model_endpoint = get_model_endpoint();
model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file;
// make sure model path is present (for caching purposes)
if (model.path.empty()) {
// this is to avoid different repo having same file name, or same file name in different subdirs
std::string filename = clean_file_name(model.hf_repo + "_" + model.hf_file);
model.path = fs_get_cache_file(filename);
}
} else if (!model.url.empty()) {
if (model.path.empty()) {
auto f = string_split<std::string>(model.url, '#').front();
f = string_split<std::string>(f, '?').front();
model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
}
if (!model.docker_repo.empty()) {
model.path = common_docker_resolve_model(model.docker_repo);
model.name = model.docker_repo;
} else if (!model.hf_repo.empty()) {
// If -m was used with -hf, treat the model "path" as the hf_file to download
if (model.hf_file.empty() && !model.path.empty()) {
model.hf_file = model.path;
model.path = "";
}
}
common_download_model_opts opts;
opts.download_mmproj = true;
opts.offline = offline;
auto download_result = common_download_model(model, bearer_token, opts);
// then, download it if needed
if (!model.url.empty()) {
bool ok = common_download_model(model, bearer_token, offline);
if (!ok) {
if (download_result.model_path.empty()) {
LOG_ERR("error: failed to download model from Hugging Face\n");
exit(1);
}
model.name = model.hf_repo;
model.path = download_result.model_path;
if (!download_result.mmproj_path.empty()) {
result.found_mmproj = true;
result.mmproj.path = download_result.mmproj_path;
}
} else if (!model.url.empty()) {
if (model.path.empty()) {
auto f = string_split<std::string>(model.url, '#').front();
f = string_split<std::string>(f, '?').front();
model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
}
common_download_model_opts opts;
opts.offline = offline;
auto download_result = common_download_model(model, bearer_token, opts);
if (download_result.model_path.empty()) {
LOG_ERR("error: failed to download model from %s\n", model.url.c_str());
exit(1);
}
@ -539,6 +528,13 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
// parse the first time to get -hf option (used for remote preset)
parse_cli_args();
// TODO: Remove later
try {
hf_cache::migrate_old_cache_to_hf_cache(params.hf_token, params.offline);
} catch (const std::exception & e) {
LOG_WRN("HF cache migration failed: %s\n", e.what());
}
// maybe handle remote preset
if (!params.model.hf_repo.empty()) {
std::string cli_hf_repo = params.model.hf_repo;
@ -1061,12 +1057,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"-cl", "--cache-list"},
"show list of models in cache",
[](common_params &) {
printf("model cache directory: %s\n", fs_get_cache_directory().c_str());
auto models = common_list_cached_models();
printf("number of models in cache: %zu\n", models.size());
for (size_t i = 0; i < models.size(); i++) {
auto & model = models[i];
printf("%4d. %s\n", (int) i + 1, model.to_string().c_str());
printf("%4zu. %s\n", i + 1, models[i].to_string().c_str());
}
exit(0);
}
@ -2583,7 +2577,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
"Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"
"mmproj is also downloaded automatically if available. to disable, add --no-mmproj\n"
"example: unsloth/phi-4-GGUF:q4_k_m\n"
"example: ggml-org/GLM-4.7-Flash-GGUF:Q4_K_M\n"
"(default: unused)",
[](common_params & params, const std::string & value) {
params.model.hf_repo = value;

View File

@ -112,8 +112,7 @@ common_peg_arena autoparser::build_parser(const generation_params & inputs) cons
} else {
parser = content.build_parser(ctx);
}
parser = wrap_for_generation_prompt(p, parser, inputs, reasoning.start);
return parser;
return p.prefix(inputs.generation_prompt, reasoning.start) + parser;
});
}

View File

@ -188,6 +188,21 @@ diff_split calculate_diff_split(const std::string & left, const std::string & ri
result.suffix = "";
// pick prefix = all as representation
}
// When left has no unique content (result.left is empty), left is entirely
// shared with right. The simultaneous prefix/suffix segment matching can
// incorrectly consume trailing segments of left as suffix when those same
// segments also appear at the end of right (e.g. "\n" at the end of both
// the shared content and the generation prompt). This rotates the diff.
// Fix: if left is a prefix of right, enforce that directly.
if (result.left.empty() && !result.right.empty() &&
left.size() <= right.size() &&
right.substr(0, left.size()) == left) {
result.prefix = left;
result.suffix = "";
result.right = right.substr(left.size());
}
return result;
}
@ -293,22 +308,6 @@ std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segm
return result;
}
common_peg_parser wrap_for_generation_prompt(common_chat_peg_builder & p,
const common_peg_parser & prs,
const autoparser::generation_params & inputs,
const std::string & reasoning_start) {
auto parser = prs;
if (!inputs.generation_prompt.empty()) {
size_t end_pos = inputs.generation_prompt.size();
if (!reasoning_start.empty() && inputs.generation_prompt.find(reasoning_start) != std::string::npos) {
end_pos = inputs.generation_prompt.find(reasoning_start);
}
std::string cut_genprompt = inputs.generation_prompt.substr(0, end_pos);
parser = p.literal(cut_genprompt) + parser;
}
return parser;
}
namespace autoparser {
std::string apply_template(const common_chat_template & tmpl, const template_params & params) {

View File

@ -58,11 +58,6 @@ std::vector<segment> segmentize_markers(const std::string & text);
// (MARKER, "</function>"), (MARKER, "</tool_call>") ]
std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segments);
// Wrap parser with generation prompt parser
common_peg_parser wrap_for_generation_prompt(common_chat_peg_builder & p,
const common_peg_parser & prs,
const autoparser::generation_params & inputs,
const std::string & reasoning_start = {});
namespace autoparser {
// Apply a template with the given parameters, returning the rendered string (empty on failure)

View File

@ -348,6 +348,34 @@ void analyze_reasoning::compare_thinking_enabled() {
mode = reasoning_mode::TAG_BASED;
}
}
} else if (!left_trimmed.empty() && !right_trimmed.empty()) {
// Full-output diff is noisy (e.g., SmolLM3 changes the system message when enable_thinking flips).
// Try to find reasoning markers by tail-anchoring:
// one output's generation prompt tail may appear in the other with extra reasoning markers appended.
const auto & output_A = comparison->output_A;
const auto & output_B = comparison->output_B;
const size_t anchor_len = 64;
for (int dir = 0; dir < 2; dir++) {
const auto & base = dir == 0 ? output_B : output_A;
const auto & extended = dir == 0 ? output_A : output_B;
size_t len = std::min(base.size(), anchor_len);
std::string anchor = base.substr(base.size() - len);
auto pos = extended.rfind(anchor);
if (pos == std::string::npos || pos + len >= extended.size()) continue;
std::string extra = trim_whitespace(extended.substr(pos + len));
if (extra.empty()) continue;
auto seg = prune_whitespace_segments(segmentize_markers(extra));
if (seg.size() == 2 && seg[0].type == segment_type::MARKER && seg[1].type == segment_type::MARKER) {
if (start.empty()) start = seg[0].value;
if (end.empty()) end = seg[1].value;
mode = reasoning_mode::TAG_BASED;
break;
}
}
}
if (mode == reasoning_mode::NONE && start.empty() && !end.empty()) {
@ -409,7 +437,7 @@ void analyze_reasoning::compare_reasoning_scope() {
if (result.result.success()) {
end = trim_trailing_whitespace(result.tags["post"]);
} else {
LOG_DBG(ANSI_ORANGE "%s: Unable to extracft reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__);
LOG_DBG(ANSI_ORANGE "%s: Unable to extract reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__);
mode = reasoning_mode::NONE;
}
}

View File

@ -802,6 +802,16 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
return tool_choices;
}
common_peg_parser common_chat_peg_builder::prefix(const std::string & s, const std::string & delimiter) {
if (s.empty()) {
return eps();
}
if (delimiter.empty()) {
return literal(s);
}
return literal(s.substr(0, s.rfind(delimiter)));
}
common_peg_parser common_chat_peg_builder::standard_json_tools(
const std::string & section_start,
const std::string & section_end,

View File

@ -82,6 +82,10 @@ class common_chat_peg_builder : public common_peg_parser_builder {
common_peg_parser tool_arg_string_value(const common_peg_parser & p) { return tag(TOOL_ARG_STRING_VALUE, p); }
common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_VALUE, p)); }
// Return a parser that parses the prefix of a string, up to a given delimiter.
common_peg_parser prefix(const std::string & s, const std::string & delimiter = {});
// Legacy-compatible helper for building standard JSON tool calls
// Used by tests and manual parsers
// name_key/args_key: JSON key names for function name and arguments

View File

@ -872,14 +872,14 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
};
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
auto generation_prompt = p.prefix(inputs.generation_prompt, "[THINK]");
auto reasoning =
extract_reasoning ? p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]") : p.eps();
// Response format parser
if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
// Ministral wants to emit json surrounded by code fences
return wrap_for_generation_prompt(p, reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```",
inputs, "[THINK]");
return generation_prompt + (reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```");
}
// Tool call parser
@ -899,13 +899,12 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls));
return wrap_for_generation_prompt(p, reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls,
inputs, "[THINK]");
return generation_prompt + (reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls);
}
// Content only parser
include_grammar = false;
return wrap_for_generation_prompt(p, reasoning << p.content(p.rest()), inputs, "[THINK]");
return generation_prompt + (reasoning << p.content(p.rest()));
});
data.parser = parser.save();
@ -991,8 +990,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
p.literal("<|channel|>final") + constraint + p.literal("<|message|>") +
p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
return wrap_for_generation_prompt(p, response_format | (analysis + p.zero_or_more(start + analysis) + start + response_format),
inputs, "<|channel|>");
return p.zero_or_more(start + analysis) + start + response_format;
}
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
@ -1021,15 +1019,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
auto tool_call = p.trigger_rule("tool-call", tool_choice);
if (inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED) {
return tool_call | ( any + p.zero_or_more(start + any) + start + tool_call);
return p.zero_or_more(start + any) + start + tool_call;
}
return wrap_for_generation_prompt(p, tool_call | final_msg | (any + p.zero_or_more(start + any) + start + (tool_call | final_msg)),
inputs, "<|channel|>");
return p.zero_or_more(start + any) + start + (tool_call | final_msg);
}
return wrap_for_generation_prompt(p, final_msg | (any + p.zero_or_more(start + any) + start + final_msg),
inputs, "<|channel|>");
return p.zero_or_more(start + any) + start + final_msg;
});
data.parser = parser.save();
@ -1080,11 +1076,12 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
// When no tools, content goes until end
auto content_until_tool = p.literal("all\n") + p.content(p.until(">>>"));
auto content_until_end = p.literal("all\n") + p.content(p.rest());
auto generation_prompt = p.literal(inputs.generation_prompt);
// If no tools or tool_choice is NONE, just parse content
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
// When no tools, just match the prefix and capture everything after
return wrap_for_generation_prompt(p, content_until_end + p.end(), inputs);
return generation_prompt + content_until_end + p.end();
}
// Build tool call parsers for each available function
@ -1120,7 +1117,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
auto content_and_tool = content_until_tool + tool_choice;
ret = p.choice({ content_and_tool, content_only, tool_choice }) + p.end();
}
return wrap_for_generation_prompt(p, ret, inputs);
return generation_prompt + ret;
});
data.parser = parser.save();
@ -1201,12 +1198,12 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
auto reasoning = extract_reasoning ? p.optional(THINK_START + p.reasoning(
p.until_one_of({ THINK_END, "<|tool_calls_section_begin|>", "<|tool_call_begin|>" })) +
p.optional(p.literal(THINK_END))) : p.eps();
auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
// Content only parser (no tools)
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return wrap_for_generation_prompt(p, reasoning + p.content(p.rest()) + end,
inputs, THINK_START);
return generation_prompt + reasoning + p.content(p.rest()) + end;
}
// Build tool call parsers for each available function
@ -1242,8 +1239,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
auto content_before_tools = p.content(p.until_one_of({ SECTION_BEGIN, CALL_BEGIN }));
return wrap_for_generation_prompt(p, reasoning + content_before_tools + tool_calls + end,
inputs, THINK_START);
return generation_prompt + reasoning + content_before_tools + tool_calls + end;
});
data.parser = parser.save();
@ -1301,6 +1297,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
data.thinking_end_tag = THINK_END;
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
auto end = p.end();
auto reasoning = p.eps();
@ -1309,8 +1306,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
}
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return wrap_for_generation_prompt(p, reasoning + p.content(p.rest()) + end, inputs,
THINK_START);
return generation_prompt + reasoning + p.content(p.rest()) + end;
}
auto tool_calls = p.rule("tool-calls",
@ -1322,8 +1318,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
auto content = p.content(p.until(TOOL_CALL_START));
return wrap_for_generation_prompt(p, reasoning + content + tool_calls + end, inputs,
THINK_START);
return generation_prompt + reasoning + content + tool_calls + end;
});
data.parser = parser.save();
@ -1396,7 +1391,7 @@ static common_chat_params common_chat_params_init_gigachat_v3(
ret = p.content(p.rest());
}
return wrap_for_generation_prompt(p, ret, inputs);
return p.literal(inputs.generation_prompt) + ret;
});
data.parser = parser.save();
@ -1621,7 +1616,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.generation_prompt = params.generation_prompt;
auto parser = build_chat_peg_parser([&params](common_chat_peg_builder &p) {
return wrap_for_generation_prompt(p, p.content(p.rest()), params);
return p.prefix(params.generation_prompt) + p.content(p.rest());
});
data.parser = parser.save();
return data;

View File

@ -1,9 +1,9 @@
#include "arg.h"
#include "common.h"
#include "gguf.h" // for reading GGUF splits
#include "log.h"
#include "download.h"
#include "hf-cache.h"
#define JSON_ASSERT GGML_ASSERT
#include <nlohmann/json.hpp>
@ -15,6 +15,7 @@
#include <map>
#include <mutex>
#include <regex>
#include <unordered_set>
#include <string>
#include <thread>
#include <vector>
@ -35,8 +36,6 @@
#endif
#endif
#define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
// isatty
#if defined(_WIN32)
#include <io.h>
@ -51,31 +50,6 @@ using json = nlohmann::ordered_json;
//
// validate repo name format: owner/repo
static bool validate_repo_name(const std::string & repo) {
static const std::regex repo_regex(R"(^[A-Za-z0-9_.\-]+\/[A-Za-z0-9_.\-]+$)");
return std::regex_match(repo, repo_regex);
}
static std::string get_manifest_path(const std::string & repo, const std::string & tag) {
// we use "=" to avoid clashing with other component, while still being allowed on windows
std::string fname = "manifest=" + repo + "=" + tag + ".json";
if (!validate_repo_name(repo)) {
throw std::runtime_error("error: repo name must be in the format 'owner/repo'");
}
string_replace_all(fname, "/", "=");
return fs_get_cache_file(fname);
}
static std::string read_file(const std::string & fname) {
std::ifstream file(fname);
if (!file) {
throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
}
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
file.close();
return content;
}
static void write_file(const std::string & fname, const std::string & content) {
const std::string fname_tmp = fname + ".tmp";
std::ofstream file(fname_tmp);
@ -132,7 +106,7 @@ static bool is_http_status_ok(int status) {
std::pair<std::string, std::string> common_download_split_repo_tag(const std::string & hf_repo_with_tag) {
auto parts = string_split<std::string>(hf_repo_with_tag, ':');
std::string tag = parts.size() > 1 ? parts.back() : "latest";
std::string tag = parts.size() > 1 ? parts.back() : "";
std::string hf_repo = parts[0];
if (string_split<std::string>(hf_repo, '/').size() != 2) {
throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
@ -290,7 +264,8 @@ static bool common_pull_file(httplib::Client & cli,
static int common_download_file_single_online(const std::string & url,
const std::string & path,
const std::string & bearer_token,
const common_header_list & custom_headers) {
const common_header_list & custom_headers,
bool skip_etag = false) {
static const int max_attempts = 3;
static const int retry_delay_seconds = 2;
@ -310,6 +285,11 @@ static int common_download_file_single_online(const std::string & url,
const bool file_exists = std::filesystem::exists(path);
if (file_exists && skip_etag) {
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
std::string last_etag;
if (file_exists) {
last_etag = read_etag(path);
@ -361,6 +341,12 @@ static int common_download_file_single_online(const std::string & url,
}
}
{ // silent
std::error_code ec;
std::filesystem::path p(path);
std::filesystem::create_directories(p.parent_path(), ec);
}
const std::string path_temporary = path + ".downloadInProgress";
int delay = retry_delay_seconds;
@ -391,7 +377,7 @@ static int common_download_file_single_online(const std::string & url,
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return -1;
}
if (!etag.empty()) {
if (!etag.empty() && !skip_etag) {
write_etag(path, etag);
}
return head->status;
@ -440,9 +426,10 @@ int common_download_file_single(const std::string & url,
const std::string & path,
const std::string & bearer_token,
bool offline,
const common_header_list & headers) {
const common_header_list & headers,
bool skip_etag) {
if (!offline) {
return common_download_file_single_online(url, path, bearer_token, headers);
return common_download_file_single_online(url, path, bearer_token, headers, skip_etag);
}
if (!std::filesystem::exists(path)) {
@ -454,193 +441,293 @@ int common_download_file_single(const std::string & url,
return 304; // Not Modified - fake cached response
}
// download multiple files from remote URLs to local paths
// the input is a vector of pairs <url, path>
static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls,
const std::string & bearer_token,
bool offline,
const common_header_list & headers) {
// Prepare download in parallel
std::vector<std::future<bool>> futures_download;
futures_download.reserve(urls.size());
struct gguf_split_info {
std::string prefix; // tag included
std::string tag;
int index;
int count;
};
for (auto const & item : urls) {
futures_download.push_back(
std::async(
std::launch::async,
[&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
const int http_status = common_download_file_single(it.first, it.second, bearer_token, offline, headers);
return is_http_status_ok(http_status);
},
item
)
);
static gguf_split_info get_gguf_split_info(const std::string & path) {
static const std::regex re_split("^(.+)-([0-9]{5})-of-([0-9]{5})$", std::regex::icase);
static const std::regex re_tag("[-.]([A-Z0-9_]+)$", std::regex::icase);
std::smatch m;
std::string prefix = path;
string_remove_suffix(prefix, ".gguf");
int index = 1;
int count = 1;
if (std::regex_match(prefix, m, re_split)) {
prefix = m[1].str();
index = std::stoi(m[2].str());
count = std::stoi(m[3].str());
}
// Wait for all downloads to complete
for (auto & f : futures_download) {
if (!f.get()) {
return false;
std::string tag;
if (std::regex_search(prefix, m, re_tag)) {
tag = m[1].str();
for (char & c : tag) {
c = std::toupper((unsigned char)c);
}
}
return true;
return {std::move(prefix), std::move(tag), index, count};
}
bool common_download_model(const common_params_model & model,
const std::string & bearer_token,
bool offline,
const common_header_list & headers) {
// Basic validation of the model.url
if (model.url.empty()) {
LOG_ERR("%s: invalid model url\n", __func__);
return false;
// Q4_0 -> 4, F16 -> 16, NVFP4 -> 4, Q8_K_M -> 8, etc
static int extract_quant_bits(const std::string & filename) {
auto split = get_gguf_split_info(filename);
auto pos = split.tag.find_first_of("0123456789");
if (pos == std::string::npos) {
return 0;
}
const int http_status = common_download_file_single(model.url, model.path, bearer_token, offline, headers);
if (!is_http_status_ok(http_status)) {
return false;
}
// check for additional GGUFs split to download
int n_split = 0;
{
struct gguf_init_params gguf_params = {
/*.no_alloc = */ true,
/*.ctx = */ NULL,
};
auto * ctx_gguf = gguf_init_from_file(model.path.c_str(), gguf_params);
if (!ctx_gguf) {
LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, model.path.c_str());
return false;
}
auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
if (key_n_split >= 0) {
n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
}
gguf_free(ctx_gguf);
}
if (n_split > 1) {
char split_prefix[PATH_MAX] = {0};
char split_url_prefix[LLAMA_MAX_URL_LENGTH] = {0};
// Verify the first split file format
// and extract split URL and PATH prefixes
{
if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.path.c_str(), 0, n_split)) {
LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.path.c_str(), n_split);
return false;
}
if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model.url.c_str(), 0, n_split)) {
LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model.url.c_str(), n_split);
return false;
}
}
std::vector<std::pair<std::string, std::string>> urls;
for (int idx = 1; idx < n_split; idx++) {
char split_path[PATH_MAX] = {0};
llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split);
char split_url[LLAMA_MAX_URL_LENGTH] = {0};
llama_split_path(split_url, sizeof(split_url), split_url_prefix, idx, n_split);
if (std::string(split_path) == model.path) {
continue; // skip the already downloaded file
}
urls.push_back({split_url, split_path});
}
// Download in parallel
common_download_file_multiple(urls, bearer_token, offline, headers);
}
return true;
return std::stoi(split.tag.substr(pos));
}
common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag,
const std::string & bearer_token,
bool offline,
const common_header_list & custom_headers) {
// the returned hf_repo is without tag
auto [hf_repo, tag] = common_download_split_repo_tag(hf_repo_with_tag);
static hf_cache::hf_files get_split_files(const hf_cache::hf_files & files,
const hf_cache::hf_file & file) {
auto split = get_gguf_split_info(file.path);
std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;
// headers
common_header_list headers = custom_headers;
headers.push_back({"Accept", "application/json"});
if (!bearer_token.empty()) {
headers.push_back({"Authorization", "Bearer " + bearer_token});
if (split.count <= 1) {
return {file};
}
// Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
// User-Agent header is already set in common_remote_get_content, no need to set it here
hf_cache::hf_files result;
// make the request
common_remote_params params;
params.headers = headers;
long res_code = 0;
std::string res_str;
bool use_cache = false;
std::string cached_response_path = get_manifest_path(hf_repo, tag);
if (!offline) {
try {
auto res = common_remote_get_content(url, params);
res_code = res.first;
res_str = std::string(res.second.data(), res.second.size());
} catch (const std::exception & e) {
LOG_WRN("error: failed to get manifest at %s: %s\n", url.c_str(), e.what());
for (const auto & f : files) {
auto split_f = get_gguf_split_info(f.path);
if (split_f.count == split.count && split_f.prefix == split.prefix) {
result.push_back(f);
}
}
if (res_code == 0) {
if (std::filesystem::exists(cached_response_path)) {
LOG_WRN("trying to read manifest from cache: %s\n", cached_response_path.c_str());
res_str = read_file(cached_response_path);
res_code = 200;
use_cache = true;
} else {
throw std::runtime_error(
offline ? "error: failed to get manifest (offline mode)"
: "error: failed to get manifest (check your internet connection)");
return result;
}
static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
const std::string & model) {
hf_cache::hf_file best;
size_t best_depth = 0;
int best_diff = 0;
bool found = false;
auto model_bits = extract_quant_bits(model);
auto model_parts = string_split<std::string>(model, '/');
auto model_dir = model_parts.end() - 1;
for (const auto & f : files) {
if (!string_ends_with(f.path, ".gguf") ||
f.path.find("mmproj") == std::string::npos) {
continue;
}
auto mmproj_parts = string_split<std::string>(f.path, '/');
auto mmproj_dir = mmproj_parts.end() - 1;
auto [_, dir] = std::mismatch(model_parts.begin(), model_dir,
mmproj_parts.begin(), mmproj_dir);
if (dir != mmproj_dir) {
continue;
}
size_t depth = dir - mmproj_parts.begin();
auto bits = extract_quant_bits(f.path);
auto diff = std::abs(bits - model_bits);
if (!found || depth > best_depth || (depth == best_depth && diff < best_diff)) {
best = f;
best_depth = depth;
best_diff = diff;
found = true;
}
}
std::string ggufFile;
std::string mmprojFile;
return best;
}
if (res_code == 200 || res_code == 304) {
try {
auto j = json::parse(res_str);
static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
const std::string & tag) {
std::vector<std::string> tags;
if (j.contains("ggufFile") && j["ggufFile"].contains("rfilename")) {
ggufFile = j["ggufFile"]["rfilename"].get<std::string>();
}
if (j.contains("mmprojFile") && j["mmprojFile"].contains("rfilename")) {
mmprojFile = j["mmprojFile"]["rfilename"].get<std::string>();
}
} catch (const std::exception & e) {
throw std::runtime_error(std::string("error parsing manifest JSON: ") + e.what());
}
if (!use_cache) {
// if not using cached response, update the cache file
write_file(cached_response_path, res_str);
}
} else if (res_code == 401) {
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
if (!tag.empty()) {
tags.push_back(tag);
} else {
throw std::runtime_error(string_format("error from HF API (%s), response code: %ld, data: %s", url.c_str(), res_code, res_str.c_str()));
tags = {"Q4_K_M", "Q4_0"};
}
// check response
if (ggufFile.empty()) {
throw std::runtime_error("error: model does not have ggufFile");
for (const auto & t : tags) {
std::regex pattern(t + "[.-]", std::regex::icase);
for (const auto & f : files) {
if (string_ends_with(f.path, ".gguf") &&
f.path.find("mmproj") == std::string::npos &&
std::regex_search(f.path, pattern)) {
return f;
}
}
}
return { hf_repo, ggufFile, mmprojFile };
for (const auto & f : files) {
if (string_ends_with(f.path, ".gguf") &&
f.path.find("mmproj") == std::string::npos) {
return f;
}
}
return {};
}
static void list_available_gguf_files(const hf_cache::hf_files & files) {
LOG_INF("Available GGUF files:\n");
for (const auto & f : files) {
if (string_ends_with(f.path, ".gguf")) {
LOG_INF(" - %s\n", f.path.c_str());
}
}
}
struct hf_plan {
hf_cache::hf_files model_files;
hf_cache::hf_file mmproj;
};
static hf_plan get_hf_plan(const common_params_model & model,
const std::string & token,
const common_download_model_opts & opts) {
hf_plan plan;
hf_cache::hf_files all;
auto [repo, tag] = common_download_split_repo_tag(model.hf_repo);
if (!opts.offline) {
all = hf_cache::get_repo_files(repo, token);
}
if (all.empty()) {
all = hf_cache::get_cached_files(repo);
}
if (all.empty()) {
return plan;
}
hf_cache::hf_file primary;
if (!model.hf_file.empty()) {
for (const auto & f : all) {
if (f.path == model.hf_file) {
primary = f;
break;
}
}
if (primary.path.empty()) {
LOG_ERR("%s: file '%s' not found in repository\n", __func__, model.hf_file.c_str());
list_available_gguf_files(all);
return plan;
}
} else {
primary = find_best_model(all, tag);
if (primary.path.empty()) {
LOG_ERR("%s: no GGUF files found in repository %s\n", __func__, repo.c_str());
list_available_gguf_files(all);
return plan;
}
}
plan.model_files = get_split_files(all, primary);
if (opts.download_mmproj) {
plan.mmproj = find_best_mmproj(all, primary.path);
}
return plan;
}
struct download_task {
std::string url;
std::string path;
};
static std::vector<download_task> get_url_tasks(const common_params_model & model) {
auto split = get_gguf_split_info(model.url);
if (split.count <= 1) {
return {{model.url, model.path}};
}
auto filename = split.prefix;
if (auto pos = split.prefix.rfind('/'); pos != std::string::npos) {
filename = split.prefix.substr(pos + 1);
}
auto parent_path = std::filesystem::path(model.path).parent_path();
auto prefix_path = (parent_path / filename).string();
std::vector<download_task> tasks;
for (int i = 1; i <= split.count; i++) {
auto suffix = string_format("-%05d-of-%05d.gguf", i, split.count);
tasks.push_back({split.prefix + suffix, prefix_path + suffix});
}
return tasks;
}
common_download_model_result common_download_model(const common_params_model & model,
const std::string & bearer_token,
const common_download_model_opts & opts,
const common_header_list & headers) {
common_download_model_result result;
std::vector<download_task> tasks;
hf_plan hf;
bool is_hf = !model.hf_repo.empty();
if (is_hf) {
hf = get_hf_plan(model, bearer_token, opts);
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
if (!hf.mmproj.path.empty()) {
tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
}
} else if (!model.url.empty()) {
tasks = get_url_tasks(model);
} else {
result.model_path = model.path;
return result;
}
if (tasks.empty()) {
return result;
}
std::vector<std::future<bool>> futures;
for (const auto & task : tasks) {
futures.push_back(std::async(std::launch::async,
[&task, &bearer_token, offline = opts.offline, &headers, is_hf]() {
int status = common_download_file_single(task.url, task.path, bearer_token, offline, headers, is_hf);
return is_http_status_ok(status);
}
));
}
for (auto & f : futures) {
if (!f.get()) {
return {};
}
}
if (is_hf) {
for (const auto & f : hf.model_files) {
hf_cache::finalize_file(f);
}
result.model_path = hf.model_files[0].final_path;
if (!hf.mmproj.path.empty()) {
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
}
} else {
result.model_path = model.path;
}
return result;
}
//
@ -765,28 +852,21 @@ std::string common_docker_resolve_model(const std::string & docker) {
}
std::vector<common_cached_model_info> common_list_cached_models() {
std::vector<common_cached_model_info> models;
const std::string cache_dir = fs_get_cache_directory();
const std::vector<common_file_info> files = fs_list(cache_dir, false);
for (const auto & file : files) {
if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) {
common_cached_model_info model_info;
model_info.manifest_path = file.path;
std::string fname = file.name;
string_replace_all(fname, ".json", ""); // remove extension
auto parts = string_split<std::string>(fname, '=');
if (parts.size() == 4) {
// expect format: manifest=<user>=<model>=<tag>=<other>
model_info.user = parts[1];
model_info.model = parts[2];
model_info.tag = parts[3];
} else {
// invalid format
continue;
}
model_info.size = 0; // TODO: get GGUF size, not manifest size
models.push_back(model_info);
std::unordered_set<std::string> seen;
std::vector<common_cached_model_info> result;
auto files = hf_cache::get_cached_files();
for (const auto & f : files) {
auto split = get_gguf_split_info(f.path);
if (split.index != 1 || split.tag.empty() ||
split.prefix.find("mmproj") != std::string::npos) {
continue;
}
if (seen.insert(f.repo_id + ":" + split.tag).second) {
result.push_back({f.repo_id, split.tag});
}
}
return models;
return result;
}

View File

@ -17,54 +17,60 @@ struct common_remote_params {
// get remote file content, returns <http_code, raw_response_body>
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params);
// split HF repo with tag into <repo, tag>
// for example: "user/model:tag" -> <"user/model", "tag">
// if tag is not present, default to "latest"
// example: "user/model" -> <"user/model", "latest">
// split HF repo with tag into <repo, tag>, for example:
// - "ggml-org/models:F16" -> <"ggml-org/models", "F16">
// tag is optional and can be empty
std::pair<std::string, std::string> common_download_split_repo_tag(const std::string & hf_repo_with_tag);
// Result of common_list_cached_models
struct common_cached_model_info {
std::string manifest_path;
std::string user;
std::string model;
std::string repo;
std::string tag;
size_t size = 0; // GGUF size in bytes
// return string representation like "user/model:tag"
// if tag is "latest", it will be omitted
std::string to_string() const {
return user + "/" + model + (tag == "latest" ? "" : ":" + tag);
return repo + ":" + tag;
}
};
struct common_hf_file_res {
std::string repo; // repo name with ":tag" removed
std::string ggufFile;
std::string mmprojFile;
// Options for common_download_model
struct common_download_model_opts {
bool download_mmproj = false;
bool offline = false;
};
/**
* Allow getting the HF file from the HF repo with tag (like ollama), for example:
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
* - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
* Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
*
* Return pair of <repo, file> (with "repo" already having tag removed)
*
* Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
*/
common_hf_file_res common_get_hf_file(
const std::string & hf_repo_with_tag,
const std::string & bearer_token,
bool offline,
const common_header_list & headers = {}
);
// Result of common_download_model
struct common_download_model_result {
std::string model_path;
std::string mmproj_path;
};
// returns true if download succeeded
bool common_download_model(
// Download model from HuggingFace repo or URL
//
// input (via model struct):
// - model.hf_repo: HF repo with optional tag, see common_download_split_repo_tag
// - model.hf_file: specific file in the repo (requires hf_repo)
// - model.url: simple download (used if hf_repo is empty)
// - model.path: local file path
//
// tag matching (for HF repos without model.hf_file):
// - if tag is specified, searches for GGUF matching that quantization
// - if no tag, searches for Q4_K_M, then Q4_0, then first available GGUF
//
// split GGUF: multi-part files like "model-00001-of-00003.gguf" are automatically
// detected and all parts are downloaded
//
// caching:
// - HF repos: uses HuggingFace cache
// - URLs: uses ETag-based caching
//
// when opts.offline=true, no network requests are made
// when download_mmproj=true, searches for mmproj in same directory as model or any parent directory
// then with the closest quantization bits
//
// returns result with model_path and mmproj_path (empty on failure)
common_download_model_result common_download_model(
const common_params_model & model,
const std::string & bearer_token,
bool offline,
const common_download_model_opts & opts = {},
const common_header_list & headers = {}
);
@ -73,11 +79,13 @@ std::vector<common_cached_model_info> common_list_cached_models();
// download single file from url to local path
// returns status code or -1 on error
// skip_etag: if true, don't read/write .etag files (for HF cache where filename is the hash)
int common_download_file_single(const std::string & url,
const std::string & path,
const std::string & bearer_token,
bool offline,
const common_header_list & headers = {});
const common_header_list & headers = {},
bool skip_etag = false);
// resolve and download model from Docker registry
// return local path to downloaded model file

644
common/hf-cache.cpp Normal file
View File

@ -0,0 +1,644 @@
#include "hf-cache.h"
#include "common.h"
#include "log.h"
#include "http.h"
#define JSON_ASSERT GGML_ASSERT
#include <nlohmann/json.hpp>
#include <filesystem>
#include <fstream>
#include <atomic>
#include <regex> // migration only
#include <string>
#include <string_view>
#include <stdexcept>
namespace nl = nlohmann;
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#define HOME_DIR "USERPROFILE"
#include <windows.h>
#else
#define HOME_DIR "HOME"
#endif
namespace hf_cache {
namespace fs = std::filesystem;
static fs::path get_cache_directory() {
static const fs::path cache = []() {
struct {
const char * var;
fs::path path;
} entries[] = {
{"HF_HUB_CACHE", fs::path()},
{"HUGGINGFACE_HUB_CACHE", fs::path()},
{"HF_HOME", fs::path("hub")},
{"XDG_CACHE_HOME", fs::path("huggingface") / "hub"},
{HOME_DIR, fs::path(".cache") / "huggingface" / "hub"}
};
for (const auto & entry : entries) {
if (auto * p = std::getenv(entry.var); p && *p) {
fs::path base(p);
return entry.path.empty() ? base : base / entry.path;
}
}
throw std::runtime_error("Failed to determine HF cache directory");
}();
return cache;
}
static std::string folder_name_to_repo(const std::string & folder) {
constexpr std::string_view prefix = "models--";
if (folder.rfind(prefix, 0)) {
return {};
}
std::string result = folder.substr(prefix.length());
string_replace_all(result, "--", "/");
return result;
}
static std::string repo_to_folder_name(const std::string & repo_id) {
constexpr std::string_view prefix = "models--";
std::string result = std::string(prefix) + repo_id;
string_replace_all(result, "/", "--");
return result;
}
static fs::path get_repo_path(const std::string & repo_id) {
return get_cache_directory() / repo_to_folder_name(repo_id);
}
static bool is_hex_char(const char c) {
return (c >= 'A' && c <= 'F') ||
(c >= 'a' && c <= 'f') ||
(c >= '0' && c <= '9');
}
static bool is_hex_string(const std::string & s, size_t expected_len) {
if (s.length() != expected_len) {
return false;
}
for (const char c : s) {
if (!is_hex_char(c)) {
return false;
}
}
return true;
}
static bool is_alphanum(const char c) {
return (c >= 'A' && c <= 'Z') ||
(c >= 'a' && c <= 'z') ||
(c >= '0' && c <= '9');
}
static bool is_special_char(char c) {
return c == '/' || c == '.' || c == '-';
}
// base chars [A-Za-z0-9_] are always valid
// special chars [/.-] must be surrounded by base chars
// exactly one '/' required
static bool is_valid_repo_id(const std::string & repo_id) {
if (repo_id.empty() || repo_id.length() > 256) {
return false;
}
int slash = 0;
bool special = true;
for (const char c : repo_id) {
if (is_alphanum(c) || c == '_') {
special = false;
} else if (is_special_char(c)) {
if (special) {
return false;
}
slash += (c == '/');
special = true;
} else {
return false;
}
}
return !special && slash == 1;
}
static bool is_valid_hf_token(const std::string & token) {
if (token.length() < 37 || token.length() > 256 ||
!string_starts_with(token, "hf_")) {
return false;
}
for (size_t i = 3; i < token.length(); ++i) {
if (!is_alphanum(token[i])) {
return false;
}
}
return true;
}
static bool is_valid_commit(const std::string & hash) {
return is_hex_string(hash, 40);
}
static bool is_valid_oid(const std::string & oid) {
return is_hex_string(oid, 40) || is_hex_string(oid, 64);
}
static bool is_valid_subpath(const fs::path & path, const fs::path & subpath) {
if (subpath.is_absolute()) {
return false; // never do a / b with b absolute
}
auto b = fs::absolute(path).lexically_normal();
auto t = (b / subpath).lexically_normal();
auto [b_end, _] = std::mismatch(b.begin(), b.end(), t.begin(), t.end());
return b_end == b.end();
}
static void safe_write_file(const fs::path & path, const std::string & data) {
fs::path path_tmp = path.string() + ".tmp";
if (path.has_parent_path()) {
fs::create_directories(path.parent_path());
}
std::ofstream file(path_tmp);
file << data;
file.close();
std::error_code ec;
if (!file.fail()) {
fs::rename(path_tmp, path, ec);
}
if (file.fail() || ec) {
fs::remove(path_tmp, ec);
throw std::runtime_error("failed to write file: " + path.string());
}
}
static nl::json api_get(const std::string & url,
const std::string & token) {
auto [cli, parts] = common_http_client(url);
httplib::Headers headers = {
{"User-Agent", "llama-cpp/" + build_info},
{"Accept", "application/json"}
};
if (is_valid_hf_token(token)) {
headers.emplace("Authorization", "Bearer " + token);
} else if (!token.empty()) {
LOG_WRN("%s: invalid token, authentication disabled\n", __func__);
}
if (auto res = cli.Get(parts.path, headers)) {
auto body = res->body;
if (res->status == 200) {
return nl::json::parse(res->body);
}
try {
body = nl::json::parse(res->body)["error"].get<std::string>();
} catch (...) { }
throw std::runtime_error("GET failed (" + std::to_string(res->status) + "): " + body);
} else {
throw std::runtime_error("HTTPLIB failed: " + httplib::to_string(res.error()));
}
}
static std::string get_repo_commit(const std::string & repo_id,
const std::string & token) {
try {
auto endpoint = get_model_endpoint();
auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", token);
if (!json.is_object() ||
!json.contains("branches") || !json["branches"].is_array()) {
LOG_WRN("%s: missing 'branches' for '%s'\n", __func__, repo_id.c_str());
return {};
}
fs::path refs_path = get_repo_path(repo_id) / "refs";
std::string name;
std::string commit;
for (const auto & branch : json["branches"]) {
if (!branch.is_object() ||
!branch.contains("name") || !branch["name"].is_string() ||
!branch.contains("targetCommit") || !branch["targetCommit"].is_string()) {
continue;
}
std::string _name = branch["name"].get<std::string>();
std::string _commit = branch["targetCommit"].get<std::string>();
if (!is_valid_subpath(refs_path, _name)) {
LOG_WRN("%s: skip invalid branch: %s\n", __func__, _name.c_str());
continue;
}
if (!is_valid_commit(_commit)) {
LOG_WRN("%s: skip invalid commit: %s\n", __func__, _commit.c_str());
continue;
}
if (_name == "main") {
name = _name;
commit = _commit;
break;
}
if (name.empty() || commit.empty()) {
name = _name;
commit = _commit;
}
}
if (name.empty() || commit.empty()) {
LOG_WRN("%s: no valid branch for '%s'\n", __func__, repo_id.c_str());
return {};
}
safe_write_file(refs_path / name, commit);
return commit;
} catch (const nl::json::exception & e) {
LOG_ERR("%s: JSON error: %s\n", __func__, e.what());
} catch (const std::exception & e) {
LOG_ERR("%s: error: %s\n", __func__, e.what());
}
return {};
}
hf_files get_repo_files(const std::string & repo_id,
const std::string & token) {
if (!is_valid_repo_id(repo_id)) {
LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str());
return {};
}
std::string commit = get_repo_commit(repo_id, token);
if (commit.empty()) {
LOG_WRN("%s: failed to resolve commit for %s\n", __func__, repo_id.c_str());
return {};
}
fs::path blobs_path = get_repo_path(repo_id) / "blobs";
fs::path commit_path = get_repo_path(repo_id) / "snapshots" / commit;
hf_files files;
try {
auto endpoint = get_model_endpoint();
auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + commit + "?recursive=true", token);
if (!json.is_array()) {
LOG_WRN("%s: response is not an array for '%s'\n", __func__, repo_id.c_str());
return {};
}
for (const auto & item : json) {
if (!item.is_object() ||
!item.contains("type") || !item["type"].is_string() || item["type"] != "file" ||
!item.contains("path") || !item["path"].is_string()) {
continue;
}
hf_file file;
file.repo_id = repo_id;
file.path = item["path"].get<std::string>();
if (!is_valid_subpath(commit_path, file.path)) {
LOG_WRN("%s: skip invalid path: %s\n", __func__, file.path.c_str());
continue;
}
if (item.contains("lfs") && item["lfs"].is_object()) {
if (item["lfs"].contains("oid") && item["lfs"]["oid"].is_string()) {
file.oid = item["lfs"]["oid"].get<std::string>();
}
} else if (item.contains("oid") && item["oid"].is_string()) {
file.oid = item["oid"].get<std::string>();
}
if (!file.oid.empty() && !is_valid_oid(file.oid)) {
LOG_WRN("%s: skip invalid oid: %s\n", __func__, file.oid.c_str());
continue;
}
file.url = endpoint + repo_id + "/resolve/" + commit + "/" + file.path;
fs::path final_path = commit_path / file.path;
file.final_path = final_path.string();
if (!file.oid.empty() && !fs::exists(final_path)) {
fs::path local_path = blobs_path / file.oid;
file.local_path = local_path.string();
} else {
file.local_path = file.final_path;
}
files.push_back(file);
}
} catch (const nl::json::exception & e) {
LOG_ERR("%s: JSON error: %s\n", __func__, e.what());
} catch (const std::exception & e) {
LOG_ERR("%s: error: %s\n", __func__, e.what());
}
return files;
}
static std::string get_cached_ref(const fs::path & repo_path) {
fs::path refs_path = repo_path / "refs";
if (!fs::is_directory(refs_path)) {
return {};
}
std::string fallback;
for (const auto & entry : fs::directory_iterator(refs_path)) {
if (!entry.is_regular_file()) {
continue;
}
std::ifstream f(entry.path());
std::string commit;
if (!f || !std::getline(f, commit) || commit.empty()) {
continue;
}
if (!is_valid_commit(commit)) {
LOG_WRN("%s: skip invalid commit: %s\n", __func__, commit.c_str());
continue;
}
if (entry.path().filename() == "main") {
return commit;
}
if (fallback.empty()) {
fallback = commit;
}
}
return fallback;
}
hf_files get_cached_files(const std::string & repo_id) {
fs::path cache_dir = get_cache_directory();
if (!fs::exists(cache_dir)) {
return {};
}
if (!repo_id.empty() && !is_valid_repo_id(repo_id)) {
LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str());
return {};
}
hf_files files;
for (const auto & repo : fs::directory_iterator(cache_dir)) {
if (!repo.is_directory()) {
continue;
}
fs::path snapshots_path = repo.path() / "snapshots";
if (!fs::exists(snapshots_path)) {
continue;
}
std::string _repo_id = folder_name_to_repo(repo.path().filename().string());
if (!is_valid_repo_id(_repo_id)) {
continue;
}
if (!repo_id.empty() && _repo_id != repo_id) {
continue;
}
std::string commit = get_cached_ref(repo.path());
fs::path commit_path = snapshots_path / commit;
if (commit.empty() || !fs::is_directory(commit_path)) {
continue;
}
for (const auto & entry : fs::recursive_directory_iterator(commit_path)) {
if (!entry.is_regular_file() && !entry.is_symlink()) {
continue;
}
fs::path path = entry.path().lexically_relative(commit_path);
if (!path.empty()) {
hf_file file;
file.repo_id = _repo_id;
file.path = path.generic_string();
file.local_path = entry.path().string();
file.final_path = file.local_path;
files.push_back(std::move(file));
}
}
}
return files;
}
std::string finalize_file(const hf_file & file) {
static std::atomic<bool> symlinks_disabled{false};
std::error_code ec;
fs::path local_path(file.local_path);
fs::path final_path(file.final_path);
if (local_path == final_path || fs::exists(final_path, ec)) {
return file.final_path;
}
if (!fs::exists(local_path, ec)) {
return file.final_path;
}
fs::create_directories(final_path.parent_path(), ec);
if (!symlinks_disabled) {
fs::path target = fs::relative(local_path, final_path.parent_path(), ec);
if (!ec) {
fs::create_symlink(target, final_path, ec);
}
if (!ec) {
return file.final_path;
}
}
if (!symlinks_disabled.exchange(true)) {
LOG_WRN("%s: failed to create symlink: %s\n", __func__, ec.message().c_str());
LOG_WRN("%s: switching to degraded mode\n", __func__);
}
fs::rename(local_path, final_path, ec);
if (ec) {
LOG_WRN("%s: failed to move file to snapshots: %s\n", __func__, ec.message().c_str());
fs::copy(local_path, final_path, ec);
if (ec) {
LOG_ERR("%s: failed to copy file to snapshots: %s\n", __func__, ec.message().c_str());
}
}
return file.final_path;
}
// delete everything after this line, one day
static std::pair<std::string, std::string> parse_manifest_name(std::string & filename) {
static const std::regex re(R"(^manifest=([^=]+)=([^=]+)=.*\.json$)");
std::smatch match;
if (std::regex_match(filename, match, re)) {
return {match[1].str(), match[2].str()};
}
return {};
}
static std::string make_old_cache_filename(const std::string & owner,
const std::string & repo,
const std::string & filename) {
auto result = owner + "_" + repo + "_" + filename;
string_replace_all(result, "/", "_");
return result;
}
static bool migrate_single_file(const fs::path & old_cache,
const std::string & owner,
const std::string & repo,
const nl::json & node,
const hf_files & files) {
if (!node.contains("rfilename") ||
!node.contains("lfs") ||
!node["lfs"].contains("sha256")) {
return false;
}
std::string path = node["rfilename"];
std::string sha256 = node["lfs"]["sha256"];
const hf_file * file_info = nullptr;
for (const auto & f : files) {
if (f.path == path) {
file_info = &f;
break;
}
}
std::string old_filename = make_old_cache_filename(owner, repo, path);
fs::path old_path = old_cache / old_filename;
fs::path etag_path = old_path.string() + ".etag";
if (!fs::exists(old_path)) {
if (fs::exists(etag_path)) {
LOG_WRN("%s: %s is orphan, deleting...\n", __func__, etag_path.string().c_str());
fs::remove(etag_path);
}
return false;
}
bool delete_old_path = false;
if (!file_info) {
LOG_WRN("%s: %s not found in current repo, deleting...\n", __func__, old_filename.c_str());
delete_old_path = true;
} else if (!sha256.empty() && !file_info->oid.empty() && sha256 != file_info->oid) {
LOG_WRN("%s: %s is not up to date (sha256 mismatch), deleting...\n", __func__, old_filename.c_str());
delete_old_path = true;
}
std::error_code ec;
if (delete_old_path) {
fs::remove(old_path, ec);
fs::remove(etag_path, ec);
return true;
}
fs::path new_path(file_info->local_path);
fs::create_directories(new_path.parent_path(), ec);
if (!fs::exists(new_path, ec)) {
fs::rename(old_path, new_path, ec);
if (ec) {
fs::copy_file(old_path, new_path, ec);
if (ec) {
LOG_WRN("%s: failed to move/copy %s: %s\n", __func__, old_path.string().c_str(), ec.message().c_str());
return false;
}
}
fs::remove(old_path, ec);
}
fs::remove(etag_path, ec);
std::string filename = finalize_file(*file_info);
LOG_INF("%s: migrated %s -> %s\n", __func__, old_filename.c_str(), filename.c_str());
return true;
}
void migrate_old_cache_to_hf_cache(const std::string & token, bool offline) {
fs::path old_cache = fs_get_cache_directory();
if (!fs::exists(old_cache)) {
return;
}
if (offline) {
LOG_WRN("%s: skipping migration in offline mode (will run when online)\n", __func__);
return; // -hf is not going to work
}
bool warned = false;
for (const auto & entry : fs::directory_iterator(old_cache)) {
if (!entry.is_regular_file()) {
continue;
}
auto filename = entry.path().filename().string();
auto [owner, repo] = parse_manifest_name(filename);
if (owner.empty() || repo.empty()) {
continue;
}
if (!warned) {
warned = true;
LOG_WRN("================================================================================\n"
"WARNING: Migrating cache to HuggingFace cache directory\n"
" Old cache: %s\n"
" New cache: %s\n"
"This one-time migration moves models previously downloaded with -hf\n"
"from the legacy llama.cpp cache to the standard HuggingFace cache.\n"
"Models downloaded with --model-url are not affected.\n"
"================================================================================\n",
old_cache.string().c_str(), get_cache_directory().string().c_str());
}
auto repo_id = owner + "/" + repo;
auto files = get_repo_files(repo_id, token);
if (files.empty()) {
LOG_WRN("%s: could not get repo files for %s, skipping\n", __func__, repo_id.c_str());
continue;
}
try {
std::ifstream manifest(entry.path());
auto json = nl::json::parse(manifest);
for (const char * key : {"ggufFile", "mmprojFile"}) {
if (json.contains(key)) {
migrate_single_file(old_cache, owner, repo, json[key], files);
}
}
} catch (const std::exception & e) {
LOG_WRN("%s: failed to parse manifest %s: %s\n", __func__, filename.c_str(), e.what());
continue;
}
fs::remove(entry.path());
}
}
} // namespace hf_cache

35
common/hf-cache.h Normal file
View File

@ -0,0 +1,35 @@
#pragma once
#include <string>
#include <vector>
// Ref: https://huggingface.co/docs/hub/local-cache.md
namespace hf_cache {
struct hf_file {
std::string path;
std::string url;
std::string local_path;
std::string final_path;
std::string oid;
std::string repo_id;
};
using hf_files = std::vector<hf_file>;
// Get files from HF API
hf_files get_repo_files(
const std::string & repo_id,
const std::string & token
);
hf_files get_cached_files(const std::string & repo_id = {});
// Create snapshot path (link or move/copy) and return it
std::string finalize_file(const hf_file & file);
// TODO: Remove later
void migrate_old_cache_to_hf_cache(const std::string & token, bool offline = false);
} // namespace hf_cache

View File

@ -53,6 +53,13 @@ private:
return tokens[current + offset];
}
const token & next() {
if (current >= tokens.size()) {
throw parser_exception("Parser Error: Unexpected EOF", source, tokens.empty() ? 0 : tokens.back().pos);
}
return tokens[current++];
}
token expect(token::type type, const std::string& error) {
const auto & t = peek();
if (t.t != type) {
@ -90,9 +97,9 @@ private:
size_t start_pos = current;
switch (peek().t) {
case token::comment:
return mk_stmt<comment_statement>(start_pos, tokens[current++].value);
return mk_stmt<comment_statement>(start_pos, next().value);
case token::text:
return mk_stmt<string_literal>(start_pos, tokens[current++].value);
return mk_stmt<string_literal>(start_pos, next().value);
case token::open_statement:
return parse_jinja_statement();
case token::open_expression:
@ -119,8 +126,7 @@ private:
}
size_t start_pos = current;
std::string name = peek().value;
current++; // consume identifier
std::string name = next().value;
statement_ptr result;
if (name == "set") {
@ -202,7 +208,7 @@ private:
// Ignore generation blocks (transformers-specific)
// See https://github.com/huggingface/transformers/pull/30650 for more information.
result = mk_stmt<noop_statement>(start_pos);
current++;
++current;
} else {
throw std::runtime_error("Unknown statement: " + name);
@ -217,7 +223,7 @@ private:
statements body;
if (is(token::equals)) {
current++;
++current;
value = parse_expression_sequence();
} else {
// parsing multiline set here
@ -280,7 +286,7 @@ private:
exprs.push_back(primary ? parse_primary_expression() : parse_expression());
bool is_tuple = is(token::comma);
while (is(token::comma)) {
current++; // consume comma
++current; // consume comma
exprs.push_back(primary ? parse_primary_expression() : parse_expression());
}
return is_tuple ? mk_stmt<tuple_literal>(start_pos, std::move(exprs)) : std::move(exprs[0]);
@ -290,7 +296,7 @@ private:
// e.g., `message` in `for message in messages`
auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple
if (!is_identifier("in")) throw std::runtime_error("Expected 'in'");
current++;
++current; // consume 'in'
// `messages` in `for message in messages`
auto iterable = parse_expression();
@ -305,7 +311,8 @@ private:
}
if (is_statement({"else"})) {
current += 2;
++current; // consume {%
++current; // consume 'else'
expect(token::close_statement, "Expected %}");
while (!is_statement({"endfor"})) {
alternate.push_back(parse_any());
@ -347,7 +354,7 @@ private:
auto left = parse_logical_and_expression();
while (is_identifier("or")) {
size_t start_pos = current;
token op = tokens[current++];
token op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_and_expression());
}
return left;
@ -357,7 +364,7 @@ private:
auto left = parse_logical_negation_expression();
while (is_identifier("and")) {
size_t start_pos = current;
auto op = tokens[current++];
auto op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_negation_expression());
}
return left;
@ -367,7 +374,7 @@ private:
// Try parse unary operators
if (is_identifier("not")) {
size_t start_pos = current;
auto op = tokens[current++];
auto op = next();
return mk_stmt<unary_expression>(start_pos, op, parse_logical_negation_expression());
}
return parse_comparison_expression();
@ -382,11 +389,12 @@ private:
size_t start_pos = current;
if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") {
op = {token::identifier, "not in", tokens[current].pos};
current += 2;
++current; // consume 'not'
++current; // consume 'in'
} else if (is_identifier("in")) {
op = tokens[current++];
op = next();
} else if (is(token::comparison_binary_operator)) {
op = tokens[current++];
op = next();
} else break;
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_additive_expression());
}
@ -397,7 +405,7 @@ private:
auto left = parse_multiplicative_expression();
while (is(token::additive_binary_operator)) {
size_t start_pos = current;
auto op = tokens[current++];
auto op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_multiplicative_expression());
}
return left;
@ -407,7 +415,7 @@ private:
auto left = parse_test_expression();
while (is(token::multiplicative_binary_operator)) {
size_t start_pos = current;
auto op = tokens[current++];
auto op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_test_expression());
}
return left;
@ -417,9 +425,9 @@ private:
auto operand = parse_filter_expression();
while (is_identifier("is")) {
size_t start_pos = current;
current++;
++current; // consume 'is'
bool negate = false;
if (is_identifier("not")) { current++; negate = true; }
if (is_identifier("not")) { ++current; negate = true; }
auto test_id = parse_primary_expression();
// FIXME: tests can also be expressed like this: if x is eq 3
if (is(token::open_paren)) test_id = parse_call_expression(std::move(test_id));
@ -432,7 +440,7 @@ private:
auto operand = parse_call_member_expression();
while (is(token::pipe)) {
size_t start_pos = current;
current++;
++current; // consume pipe
auto filter = parse_primary_expression();
if (is(token::open_paren)) filter = parse_call_expression(std::move(filter));
operand = mk_stmt<filter_expression>(start_pos, std::move(operand), std::move(filter));
@ -490,7 +498,7 @@ private:
statement_ptr parse_member_expression(statement_ptr object) {
size_t start_pos = current;
while (is(token::dot) || is(token::open_square_bracket)) {
auto op = tokens[current++];
auto op = next();
bool computed = op.t == token::open_square_bracket;
statement_ptr prop;
if (computed) {
@ -536,7 +544,7 @@ private:
statement_ptr parse_primary_expression() {
size_t start_pos = current;
auto t = tokens[current++];
auto t = next();
switch (t.t) {
case token::numeric_literal:
if (t.value.find('.') != std::string::npos) {
@ -547,7 +555,7 @@ private:
case token::string_literal: {
std::string val = t.value;
while (is(token::string_literal)) {
val += tokens[current++].value;
val += next().value;
}
return mk_stmt<string_literal>(start_pos, val);
}
@ -562,9 +570,9 @@ private:
statements vals;
while (!is(token::close_square_bracket)) {
vals.push_back(parse_expression());
if (is(token::comma)) current++;
if (is(token::comma)) ++current;
}
current++;
++current;
return mk_stmt<array_literal>(start_pos, std::move(vals));
}
case token::open_curly_bracket: {
@ -573,9 +581,9 @@ private:
auto key = parse_expression();
expect(token::colon, "Expected :");
pairs.push_back({std::move(key), parse_expression()});
if (is(token::comma)) current++;
if (is(token::comma)) ++current;
}
current++;
++current;
return mk_stmt<object_literal>(start_pos, std::move(pairs));
}
default:

View File

@ -31,10 +31,10 @@ import gguf
from gguf.vocab import MistralTokenizerType, MistralVocab
try:
from mistral_common.tokens.tokenizers.base import TokenizerVersion # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.base import TokenizerVersion # type: ignore[import-not-found]
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # type: ignore[import-not-found]
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # type: ignore[import-not-found]
from mistral_common.tokens.tokenizers.sentencepiece import ( # type: ignore[import-not-found]
SentencePieceTokenizer,
)
@ -45,9 +45,9 @@ except ImportError:
_MISTRAL_COMMON_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
_mistral_common_installed = False
TokenizerVersion = None
Tekkenizer = None
SentencePieceTokenizer = None
TokenizerVersion: Any = None
Tekkenizer: Any = None
SentencePieceTokenizer: Any = None
_mistral_import_error_msg = (
"Mistral format requires `mistral-common` to be installed. Please run "
"`pip install mistral-common[image,audio]` to install it."
@ -145,6 +145,7 @@ class ModelBase:
self.model_name = model_name
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
self._is_nvfp4 = False
self._is_mxfp4 = False
# Apply heuristics to figure out typical tensor encoding based on first tensor's dtype
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
@ -220,7 +221,7 @@ class ModelBase:
if weight_map is None or not isinstance(weight_map, dict):
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
tensor_names_from_index.update(weight_map.keys())
part_dict: dict[str, None] = dict.fromkeys(weight_map.values(), None)
part_dict: dict[str, None] = dict.fromkeys(weight_map.values(), None) # ty: ignore[invalid-assignment]
part_names = sorted(part_dict.keys())
else:
weight_map = {}
@ -739,6 +740,7 @@ class ModelBase:
def prepare_tensors(self):
# detect NVFP4 quantization (ModelOpt format)
quant_algo = (self.hparams.get("quantization_config") or {}).get("quant_algo")
quant_method = (self.hparams.get("quantization_config") or {}).get("quant_method")
quant_layers = (self.hparams.get("quantization_config") or {}).get("quantized_layers") or {}
quant_config_file = self.dir_model / "hf_quant_config.json"
@ -755,6 +757,7 @@ class ModelBase:
quant_algo = "NVFP4"
self._is_nvfp4 = quant_algo == "NVFP4"
self._is_mxfp4 = quant_method == "mxfp4"
# NVFP4 weights are repacked and written directly to gguf_writer.
# This must run before dequant_model so NVFP4 tensors are removed
@ -903,6 +906,12 @@ class ModelBase:
if self.metadata.name is None:
self.metadata.name = self.dir_model.name
if self.ftype in (gguf.LlamaFileType.ALL_F32, gguf.LlamaFileType.MOSTLY_F16, gguf.LlamaFileType.MOSTLY_BF16):
if self._is_nvfp4:
self.ftype = gguf.LlamaFileType.MOSTLY_NVFP4
elif self._is_mxfp4:
self.ftype = gguf.LlamaFileType.MOSTLY_MXFP4_MOE
# Generate parameter weight class (useful for leader boards) if not yet determined
if self.metadata.size_label is None and total_params > 0:
self.metadata.size_label = gguf.size_label(total_params, shared_params, expert_params, expert_count)
@ -4291,6 +4300,16 @@ class Qwen25OmniModel(Qwen2VLVisionModel):
@ModelBase.register("InternVisionModel")
class InternVisionModel(MmprojModel):
min_dynamic_tiles: int = 0
max_dynamic_tiles: int = 0
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
self.min_dynamic_tiles = self.global_config.get("min_dynamic_patch", 0)
self.max_dynamic_tiles = self.global_config.get("max_dynamic_patch", 0)
def set_gguf_parameters(self):
assert self.hparams_vision is not None
if isinstance(self.hparams_vision['image_size'], list):
@ -4313,6 +4332,11 @@ class InternVisionModel(MmprojModel):
downsample_ratio = self.global_config.get("downsample_ratio")
assert downsample_ratio is not None
self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio))
# older models may not have min/max_dynamic_patch in config
if self.min_dynamic_tiles > 0:
self.gguf_writer.add_vision_preproc_min_tiles(self.min_dynamic_tiles)
if self.max_dynamic_tiles > 0:
self.gguf_writer.add_vision_preproc_max_tiles(self.max_dynamic_tiles)
def tensor_force_quant(self, name, new_name, bid, n_dims):
if ".position_embd." in new_name:
@ -6000,7 +6024,7 @@ class InternLM2Model(TextModel):
logger.error(f'Error: Missing {tokenizer_path}')
sys.exit(1)
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
@ -6321,7 +6345,7 @@ class BertModel(TextModel):
vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size)
else:
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
@ -8998,7 +9022,7 @@ class T5Model(TextModel):
if not tokenizer_path.is_file():
raise FileNotFoundError(f"File not found: {tokenizer_path}")
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
# some models like Pile-T5 family use BPE tokenizer instead of Unigram
@ -9135,7 +9159,7 @@ class T5EncoderModel(TextModel):
if not tokenizer_path.is_file():
raise FileNotFoundError(f"File not found: {tokenizer_path}")
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
# some models like Pile-T5 family use BPE tokenizer instead of Unigram
@ -11243,8 +11267,7 @@ class GptOssModel(TextModel):
# TODO: remove once MXFP4 is supported more generally
def dequant_model(self):
quant_config = self.hparams.get("quantization_config")
if quant_config is not None and quant_config.get("quant_method") == "mxfp4":
if self._is_mxfp4:
return
return super().dequant_model()
@ -12397,6 +12420,7 @@ class LazyTorchTensor(gguf.LazyBase):
kwargs = {}
if func is torch.Tensor.numpy:
assert len(args)
return args[0].numpy()
return cls._wrap_fn(func)(*args, **kwargs)

View File

@ -112,11 +112,11 @@ class Tensor:
(n_dims, name_len, dtype) = struct.unpack('<3I', data[offset:offset + 12])
assert n_dims >= 0 and n_dims <= 4, f'Invalid tensor dimensions {n_dims}'
assert name_len < 4096, 'Absurd tensor name length'
quant = gguf.GGML_QUANT_SIZES.get(dtype)
self.dtype = gguf.GGMLQuantizationType(dtype)
quant = gguf.GGML_QUANT_SIZES.get(self.dtype)
assert quant is not None, 'Unknown tensor type'
(blksize, tysize) = quant
offset += 12
self.dtype= gguf.GGMLQuantizationType(dtype)
self.dims = struct.unpack(f'<{n_dims}I', data[offset:offset + (4 * n_dims)])
offset += 4 * n_dims
self.name = bytes(data[offset:offset + name_len])

View File

@ -199,10 +199,13 @@ class LoraTorchTensor:
kwargs = {}
if func is torch.permute:
assert len(args)
return type(args[0]).permute(*args, **kwargs)
elif func is torch.reshape:
assert len(args)
return type(args[0]).reshape(*args, **kwargs)
elif func is torch.stack:
assert len(args)
assert isinstance(args[0], Sequence)
dim = kwargs.get("dim", 0)
assert dim == 0
@ -211,6 +214,7 @@ class LoraTorchTensor:
torch.stack([b._lora_B for b in args[0]], dim),
)
elif func is torch.cat:
assert len(args)
assert isinstance(args[0], Sequence)
dim = kwargs.get("dim", 0)
assert dim == 0
@ -362,7 +366,7 @@ if __name__ == '__main__':
logger.error(f"Model {hparams['architectures'][0]} is not supported")
sys.exit(1)
class LoraModel(model_class):
class LoraModel(model_class): # ty: ignore[unsupported-base]
model_arch = model_class.model_arch
lora_alpha: float

View File

@ -12,9 +12,9 @@ Legend:
- 🟡 Partially supported by this backend
- ❌ Not supported by this backend
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | WebGPU | ZenDNN | zDNN |
| Operation | BLAS | CANN | CPU | CUDA | MTL | OpenCL | SYCL | Vulkan | WebGPU | ZenDNN | zDNN |
|-----------|------|------|------|------|------|------|------|------|------|------|------|
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| ABS | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
@ -23,63 +23,63 @@ Legend:
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| DIAG | ❌ | ❌ | ✅ | ✅ | | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| DIAG | ❌ | ❌ | ✅ | ✅ | | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| GELU | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
| NEG | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| OUT_PROD | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | 🟡 |
| PAD | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| POOL_1D | ❌ | ❌ | ✅ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| POOL_1D | ❌ | ❌ | ✅ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
@ -91,31 +91,31 @@ Legend:
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SET | ❌ | ❌ | ✅ | ✅ | | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
| SET | ❌ | ❌ | ✅ | ✅ | | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
| SET_ROWS | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SUM | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |

File diff suppressed because it is too large Load Diff

View File

@ -28,9 +28,6 @@ def _build_repetition(item_rule, min_items, max_items, separator_rule=None):
return f'({result})?' if min_items == 0 else result
def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True):
has_min = min_value != None
has_max = max_value != None
def digit_range(from_char: str, to_char: str):
out.append("[")
if from_char == to_char:
@ -106,7 +103,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
out.append(to_str[i])
out.append("]")
if has_min and has_max:
if min_value is not None and max_value is not None:
if min_value < 0 and max_value < 0:
out.append("\"-\" (")
_generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True)
@ -133,7 +130,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
less_decimals = max(decimals_left - 1, 1)
if has_min:
if min_value is not None:
if min_value < 0:
out.append("\"-\" (")
_generate_min_max_int(None, -min_value, out, decimals_left, top_level=False)
@ -177,7 +174,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
more_digits(length - 1, less_decimals)
return
if has_max:
if max_value is not None:
if max_value >= 0:
if top_level:
out.append("\"-\" [1-9] ")

View File

@ -64,7 +64,7 @@ def load_model_and_tokenizer(model_path, use_sentence_transformers=False, device
print("Using SentenceTransformer to apply all numbered layers")
model = SentenceTransformer(model_path)
tokenizer = model.tokenizer
config = model[0].auto_model.config # type: ignore
config = model[0].auto_model.config
else:
tokenizer = AutoTokenizer.from_pretrained(model_path)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
@ -108,8 +108,8 @@ def load_model_and_tokenizer(model_path, use_sentence_transformers=False, device
print(f"Model file: {type(model).__module__}")
# Verify the model is using the correct sliding window
if hasattr(model.config, 'sliding_window'): # type: ignore
print(f"Model's sliding_window: {model.config.sliding_window}") # type: ignore
if hasattr(model.config, 'sliding_window'):
print(f"Model's sliding_window: {model.config.sliding_window}")
else:
print("Model config does not have sliding_window attribute")
@ -152,7 +152,7 @@ def main():
device = next(model.parameters()).device
else:
# For SentenceTransformer, get device from the underlying model
device = next(model[0].auto_model.parameters()).device # type: ignore
device = next(model[0].auto_model.parameters()).device
model_name = os.path.basename(model_path)
@ -177,7 +177,7 @@ def main():
print(f"{token_id:6d} -> '{token_str}'")
print(f"Embeddings shape (after all SentenceTransformer layers): {all_embeddings.shape}")
print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}") # type: ignore
print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}")
else:
# Standard approach: use base model output only
encoded = tokenizer(
@ -205,12 +205,12 @@ def main():
print(f"Embedding dimension: {all_embeddings.shape[1]}")
if len(all_embeddings.shape) == 1:
n_embd = all_embeddings.shape[0] # type: ignore
n_embd = all_embeddings.shape[0]
n_embd_count = 1
all_embeddings = all_embeddings.reshape(1, -1)
else:
n_embd = all_embeddings.shape[1] # type: ignore
n_embd_count = all_embeddings.shape[0] # type: ignore
n_embd = all_embeddings.shape[1]
n_embd_count = all_embeddings.shape[0]
print()

View File

@ -2,7 +2,7 @@
import argparse
import sys
from common import compare_tokens # type: ignore
from common import compare_tokens # type: ignore[import-not-found]
def parse_arguments():

View File

@ -6,7 +6,7 @@ import re
from copy import copy
from enum import Enum
from inspect import getdoc, isclass
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, get_args, get_origin, get_type_hints
from docstring_parser import parse
from pydantic import BaseModel, create_model
@ -1158,7 +1158,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
# Assert that the parameter has a type annotation
if param.annotation == inspect.Parameter.empty:
raise TypeError(f"Parameter '{param.name}' in function '{func.__name__}' lacks a type annotation")
raise TypeError(f"""Parameter '{param.name}' in function '{getattr(func, "__name__", "")}' lacks a type annotation""")
# Find the parameter's description in the docstring
param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
@ -1166,7 +1166,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
# Assert that the parameter has a description
if not param_doc or not param_doc.description:
raise ValueError(
f"Parameter '{param.name}' in function '{func.__name__}' lacks a description in the docstring")
f"""Parameter '{param.name}' in function '{getattr(func, "__name__", "")}' lacks a description in the docstring""")
# Add parameter details to the schema
param_docs.append((param.name, param_doc))
@ -1177,7 +1177,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
dynamic_fields[param.name] = (
param.annotation if param.annotation != inspect.Parameter.empty else str, default_value)
# Creating the dynamic model
dynamic_model = create_model(f"{func.__name__}", **dynamic_fields)
dynamic_model = create_model(f"{getattr(func, '__name__')}", **dynamic_fields)
for name, param_doc in param_docs:
dynamic_model.model_fields[name].description = param_doc.description
@ -1285,7 +1285,7 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name:
if items != {}:
array = {"properties": items}
array_type = convert_dictionary_to_pydantic_model(array, f"{model_name}_{field_name}_items")
fields[field_name] = (List[array_type], ...)
fields[field_name] = (list[array_type], ...) # ty: ignore[invalid-type-form]
else:
fields[field_name] = (list, ...)
elif field_type == "object":

View File

@ -1788,9 +1788,11 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0]; // src
ggml_tensor * src1 = dst->src[1]; // index
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16
|| dst->type == GGML_TYPE_BF16);
switch (src0->type) {
case GGML_TYPE_BF16:
case GGML_TYPE_F16:
case GGML_TYPE_F32:
if (src0->type == dst->type) {
@ -1881,6 +1883,7 @@ void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
break;
}
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
{
acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t));
@ -1891,7 +1894,7 @@ void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
}
acl_tensor_ptr src_trans_tensor = ggml_cann_create_tensor(
src_trans_buffer, ACL_FLOAT16, ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type));
aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1,
dst->type);
@ -1965,7 +1968,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor *
// Only check env once.
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
if (weight_to_nz && is_matmul_weight(weight)) {
if (weight_to_nz && weight->type != GGML_TYPE_BF16 && is_matmul_weight(weight)) {
acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
} else {
acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
@ -2146,6 +2149,9 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
switch (type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
#ifndef ASCEND_310P
case GGML_TYPE_BF16:
#endif
ggml_cann_mat_mul_fp(ctx, dst);
break;
case GGML_TYPE_Q4_0:
@ -3005,6 +3011,58 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
}
}
void ggml_cann_rope_cache_preload(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
int sections[4];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
GGML_TENSOR_UNARY_OP_LOCALS
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int) * 4);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_imrope || mrope_used) {
is_neox = true;
}
int64_t rope_dims = n_dims;
if (is_vision) {
rope_dims = src0->ne[0];
}
// Run the full cache init on the non-captured stream. This performs all
// host-to-device memcpy, aclrtMalloc/Free, and on-device computations
// so that the memory pool is warmed up and cache metadata is populated.
aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections,
mrope_used, is_imrope, is_vision, rope_dims);
// Reset `cached` so that during graph capture the on-device computations
// (sin/cos, position multiply, repeat, etc.) still execute and get recorded
// into the captured graph. The cache metadata (theta_scale_length,
// theta_scale, sections, position_length, etc.) remains set, which causes
// all host-to-device copy and malloc/free branches to be skipped.
ctx.rope_cache.cached = false;
}
void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];

View File

@ -543,6 +543,21 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst);
*/
void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Pre-load the RoPE cache before ACL graph capture.
*
* This function must be called outside of graph capture to perform
* host-to-device memory copies and device memory allocations that are
* not allowed on a captured stream. After pre-loading, the rope cache
* metadata is updated so that the subsequent call to
* aclnn_rope_cache_init (inside graph capture) skips these operations
* and only records the on-device computations into the captured graph.
*
* @param ctx CANN backend context.
* @param dst A ROPE destination tensor from the computation graph.
*/
void ggml_cann_rope_cache_preload(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes the index of the maximum value along the specified dimension
* of a ggml tensor using the CANN backend.

View File

@ -277,7 +277,7 @@ struct ggml_graph_node_properties {
}
}
if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU || node->op == GGML_OP_ROPE){
return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
}
return true;

View File

@ -1234,7 +1234,8 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer,
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
if (!need_transform(tensor->type)) {
ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
if (weight_to_nz && tensor->type != GGML_TYPE_BF16
&& is_matmul_weight((const ggml_tensor *) tensor)) {
GGML_ASSERT(tensor->ne[2] == 1);
GGML_ASSERT(tensor->ne[3] == 1);
weight_format_to_nz(tensor, offset, ctx->device);
@ -1443,7 +1444,8 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_t
if (ne0 % MATRIX_ROW_PADDING != 0) {
size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
}
} else if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
} else if (weight_to_nz && tensor->type != GGML_TYPE_BF16
&& is_matmul_weight((const ggml_tensor *) tensor)) {
// NZ format weight are not support quantized yet.
// If ND tensor transform to NZ, size may changed.
int64_t shape[] = { tensor->ne[1], tensor->ne[0] };
@ -2223,6 +2225,19 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
// If no matching graph is found, add a new ACL graph.
ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph);
cann_ctx->graph_lru_cache.push(new_graph);
// Pre-load rope cache before graph capture. During capture the
// stream cannot perform host-to-device memcpy or device memory
// malloc/free. Running the full cache init now populates the
// cache metadata so these branches are skipped during capture,
// while also warming up the memory pool.
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (node->op == GGML_OP_ROPE) {
ggml_cann_rope_cache_preload(*cann_ctx, node);
break;
}
}
}
}
#else
@ -2283,6 +2298,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
case GGML_OP_MUL_MAT:
{
switch (op->src[0]->type) {
#ifndef ASCEND_310P
case GGML_TYPE_BF16:
#endif
case GGML_TYPE_F16:
case GGML_TYPE_F32:
return true;
@ -2320,6 +2338,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
switch (op->src[0]->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
#ifndef ASCEND_310P
case GGML_TYPE_BF16:
#endif
case GGML_TYPE_Q8_0:
return true;
default:
@ -2332,6 +2353,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
switch (op->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
#ifndef ASCEND_310P
case GGML_TYPE_BF16:
#endif
return true;
default:
return false;
@ -2341,20 +2365,30 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
case GGML_OP_CPY:
{
ggml_tensor * src = op->src[0];
#ifdef ASCEND_310P
if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
(src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16)) {
// only support F32 and F16.
// only support F32 and F16 on 310P.
return false;
}
#else
if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_BF16) ||
(src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16 && src->type != GGML_TYPE_BF16)) {
// only support F32, F16 and BF16.
return false;
}
#endif
return true;
}
break;
case GGML_OP_CONT:
{
// TODO: support GGML_TYPE_BF16
switch (op->src[0]->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
#ifndef ASCEND_310P
case GGML_TYPE_BF16:
#endif
return true;
default:
return false;

View File

@ -3194,6 +3194,7 @@ class tinyBLAS_PPC {
private:
__attribute__((always_inline))
inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
vec_t vec_C[4];
__builtin_mma_disassemble_acc(vec_C, ACC);
@ -3204,6 +3205,7 @@ class tinyBLAS_PPC {
}
}
__attribute__((always_inline))
inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
vec_t vec_C[4];
__builtin_mma_disassemble_acc(vec_C, ACC);

View File

@ -116,12 +116,11 @@ if (CUDAToolkit_FOUND)
list(APPEND GGML_SOURCES_CUDA ${SRCS})
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
else()
file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
list(APPEND GGML_SOURCES_CUDA
template-instances/fattn-vec-instance-f16-f16.cu
template-instances/fattn-vec-instance-q4_0-q4_0.cu
template-instances/fattn-vec-instance-q8_0-q8_0.cu
template-instances/fattn-vec-instance-bf16-bf16.cu)
endif()
ggml_add_backend_library(ggml-cuda

View File

@ -41,6 +41,16 @@ template<typename dst_t, typename src_t>
return __bfloat162float(x);
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
return __float22half2_rn(x);
} else if constexpr(std::is_same_v<src_t, nv_bfloat162> && std::is_same_v<dst_t, float2>) {
#ifdef GGML_USE_HIP
return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x)));
#else
#if __CUDA_ARCH__ >= 800
return __bfloat1622float2(x);
#else
return make_float2(__bfloat162float(x.x), __bfloat162float(x.y));
#endif // __CUDA_ARCH__ >= 800
#endif // GGML_USE_HIP
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
// bypass compile error on cuda 12.0.1
#ifdef GGML_USE_HIP

View File

@ -74,6 +74,37 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
return sum;
}
template <int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_bf16(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
const nv_bfloat162 * K_bf16 = (const nv_bfloat162 *) K_c;
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
float sum = 0.0f;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
__align__(16) nv_bfloat162 tmp[cpy_ne];
ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_bf16 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
#ifdef V_DOT2_F32_F16_AVAILABLE
// FIXME replace macros in vector FA kernel with templating and use FP32 for BF16
ggml_cuda_mad(sum, ggml_cuda_cast<float2>(tmp[k_KQ_1]), __half22float2(((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]));
#else
ggml_cuda_mad(sum, ggml_cuda_cast<float2>(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
#endif // V_DOT2_F32_F16_AVAILABLE
}
}
return sum;
}
template<int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
@ -321,6 +352,19 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_
}
}
template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_bf16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
static_assert(std::is_same_v<T, float>, "BF16 V dequantization only supports float output");
static_assert(ne % 2 == 0, "bad ne");
__align__(16) nv_bfloat162 tmp[ne/2];
ggml_cuda_memcpy_1<ne*sizeof(nv_bfloat16)>(tmp, (const nv_bfloat16 *) vx + i0);
float2 * dst_f2 = (float2 *) dst;
#pragma unroll
for (int l = 0; l < ne/2; ++l) {
dst_f2[l] = ggml_cuda_cast<float2>(tmp[l]);
}
}
template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
const block_q4_0 * x = (const block_q4_0 *) vx;
@ -547,6 +591,8 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>;
} else if constexpr (type_K == GGML_TYPE_Q8_0) {
return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
} else if constexpr (type_K == GGML_TYPE_BF16) {
return vec_dot_fattn_vec_KQ_bf16<D, nthreads>;
} else {
static_assert(type_K == -1, "bad type");
return nullptr;
@ -567,6 +613,8 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
return dequantize_V_q5_1<T, ne>;
} else if constexpr (type_V == GGML_TYPE_Q8_0) {
return dequantize_V_q8_0<T, ne>;
} else if constexpr (type_V == GGML_TYPE_BF16) {
return dequantize_V_bf16<float, ne>;
} else {
static_assert(type_V == -1, "bad type");
return nullptr;

View File

@ -75,17 +75,17 @@ static __global__ void flash_attn_ext_vec(
#endif // GGML_USE_HIP
constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device();
constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q;
constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q;
constexpr int nthreads_KQ = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_KQ_q;
constexpr int nthreads_V = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_V_q;
static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K");
static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V");
constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4;
constexpr int V_rows_per_thread = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 2*cpy_ne : 4;
constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V;
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16 && type_K != GGML_TYPE_BF16;
#ifdef V_DOT2_F32_F16_AVAILABLE
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
#else
@ -323,8 +323,18 @@ static __global__ void flash_attn_ext_vec(
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
half2 tmp[V_rows_per_thread/2];
dequantize_V(V + k*nb21, tmp,
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
if constexpr (type_V == GGML_TYPE_BF16) {
float2 tmp_f[V_rows_per_thread/2];
dequantize_V(V + k*nb21, tmp_f,
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
#pragma unroll
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
tmp[i_VKQ_1] = __float22half2_rn(tmp_f[i_VKQ_1]);
}
} else {
dequantize_V(V + k*nb21, tmp,
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
}
#pragma unroll
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
#pragma unroll
@ -563,6 +573,7 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_BF16); \
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
@ -570,6 +581,7 @@ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_BF16)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
@ -577,6 +589,7 @@ EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_BF16)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
@ -584,3 +597,4 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_BF16)

View File

@ -224,6 +224,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
@ -231,6 +232,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
@ -238,6 +240,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
@ -245,6 +248,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
@ -252,6 +256,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
@ -259,10 +264,20 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
#else
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
#endif // GGML_CUDA_FA_ALL_QUANTS
GGML_ABORT("fatal error");
@ -355,6 +370,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
#endif // GGML_CUDA_FA_ALL_QUANTS
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_BF16:
break;
default:
return BEST_FATTN_KERNEL_NONE;

View File

@ -33,7 +33,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
}
}
static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
@ -173,11 +173,11 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d
return 1;
}
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) {
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) {
if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
switch (ncols_dst) {
case 1:
return 1;
return small_k ? nwarps : 1;
case 2:
case 3:
case 4:
@ -193,7 +193,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
return 1;
}
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false>
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false, bool small_k = false>
__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
@ -208,7 +208,7 @@ static __global__ void mul_mat_vec_q(
constexpr int vdr = get_vdr_mmvq(type);
constexpr mmvq_parameter_table_id table_id = get_device_table_id();
constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id);
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
@ -414,14 +414,16 @@ static __global__ void mul_mat_vec_q(
template<ggml_type type>
static std::pair<dim3, dim3> calc_launch_params(
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
const int warp_size, const mmvq_parameter_table_id table_id) {
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false) {
const int nwarps = calc_nwarps(type, ncols_dst, table_id);
const int rpb = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
const int64_t nblocks = (nrows_x + rpb - 1) / rpb;
const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1);
const dim3 block_dims(warp_size, nwarps, 1);
return {block_nums, block_dims};
}
template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false>
template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false, bool small_k = false>
static void mul_mat_vec_q_switch_fusion(
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
@ -434,7 +436,7 @@ static void mul_mat_vec_q_switch_fusion(
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (c_ncols_dst == 1) {
if (has_fusion) {
mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
@ -444,7 +446,7 @@ static void mul_mat_vec_q_switch_fusion(
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
@ -488,11 +490,33 @@ static void mul_mat_vec_q_switch_ncols_dst(
switch (ncols_dst) {
case 1: {
constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
// When K is small, increase rows_per_block to match nwarps so each warp has more work to do
// Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle.
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_iter_1warp = vdr * warp_size / qi;
const int nwarps = calc_nwarps(type, c_ncols_dst, table_id);
const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
if (use_small_k) {
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
warp_size, table_id, true);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, false, true>(
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
} else {
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
}
} break;
case 2: {
constexpr int c_ncols_dst = 2;

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_BF16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_F16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_0);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_1);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_0);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_1);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q8_0);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_BF16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_BF16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_BF16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_BF16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_BF16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_BF16);

View File

@ -5,7 +5,7 @@ import os
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]
SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.

View File

@ -461,7 +461,7 @@ static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) {
d[7] = x[i * 8 + 7].d;
}
if (opt_verbose > 1) {
if (opt_verbose > 2) {
for (int i = 0; i < nb; i++) {
dump_packed_block_q4x4x2(y, i, k);
}
@ -480,7 +480,7 @@ static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) {
const uint8_t * y_q = y + 0; // quants first
const uint8_t * y_d = y + qrow_size; // then scales
if (opt_verbose > 1) {
if (opt_verbose > 2) {
for (int i = 0; i < nb; i++) {
dump_packed_block_q4x4x2(y, i, k);
}
@ -796,7 +796,7 @@ static void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) {
d[7] = x[i * 8 + 7].d;
}
if (opt_verbose > 1) {
if (opt_verbose > 2) {
for (int i = 0; i < nb; i++) {
dump_packed_block_q8x4x2(y, i, k);
}
@ -814,7 +814,7 @@ static void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) {
const uint8_t * y_q = y + 0; // quants first
const uint8_t * y_d = y + qrow_size; // then scales
if (opt_verbose > 1) {
if (opt_verbose > 2) {
for (int i = 0; i < nb; i++) {
dump_packed_block_q8x4x2(y, i, k);
}
@ -1149,7 +1149,7 @@ static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k)
e[7] = x[i * 8 + 7].e;
}
if (opt_verbose > 1) {
if (opt_verbose > 2) {
for (int i = 0; i < nb; i++) {
dump_packed_block_mxfp4x4x2(y, i, k);
}
@ -1168,7 +1168,7 @@ static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k)
const uint8_t * y_q = y + 0; // quants first
const uint8_t * y_e = y + qrow_size; // then scales
if (opt_verbose > 1) {
if (opt_verbose > 2) {
for (int i = 0; i < nb; i++) {
dump_packed_block_mxfp4x4x2(y, i, k);
}

View File

@ -24,28 +24,26 @@
// Context for binary operations
struct htp_binary_context {
struct htp_ops_context * octx;
struct fastdiv_values dim1_div;
struct fastdiv_values dim2_div;
struct fastdiv_values dim12_div;
struct fastdiv_values src0_dim1_div; // ne01
struct fastdiv_values src0_dim2_div; // ne02
struct fastdiv_values src0_dim12_div;// ne03
struct fastdiv_values src1_dim1_div; // ne11
struct fastdiv_values src1_dim2_div; // ne12
struct fastdiv_values src1_dim3_div; // ne13
uint32_t nrows_per_thread;
bool split_at_ne01;
bool split_at_ne02;
// Precomputed values
uint32_t block_max;
uint32_t nrows_per_thread;
size_t src0_row_size_aligned;
size_t src1_row_size_aligned;
size_t dst_row_size_aligned;
uint32_t src1_fetch_rows; // 1 or block_max
uint32_t src1_dma_stride; // 0 or stride
bool split_at_ne01;
bool split_at_ne02;
};
#define htp_binary_preamble \
#define htp_binary_preamble \
const struct htp_tensor * src0 = &octx->src0; \
const struct htp_tensor * src1 = &octx->src1; \
struct htp_tensor * dst = &octx->dst; \
@ -72,12 +70,11 @@ struct htp_binary_context {
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3];
static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row,
uint32_t ne01, uint32_t ne02) {
static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row, uint32_t ne01, uint32_t ne02) {
uint32_t i03, i02, i01, rem;
i03 = fastdiv(ir, &bctx->dim12_div);
i03 = fastdiv(ir, &bctx->src0_dim12_div);
rem = ir - i03 * (ne02 * ne01);
i02 = fastdiv(rem, &bctx->dim1_div);
i02 = fastdiv(rem, &bctx->src0_dim1_div);
i01 = rem - i02 * ne01;
uint32_t rows_left = end_row - ir;
@ -191,6 +188,8 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
if (start_row >= end_row) return;
FARF(HIGH, "binary-scalar: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned);
uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
@ -204,9 +203,9 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t i03, i02, i01, rem;
i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
rem = ir_prefetch - i03 * (ne02 * ne01);
i02 = fastdiv(rem, &bctx->dim1_div);
i02 = fastdiv(rem, &bctx->src0_dim1_div);
i01 = rem - i02 * ne01;
uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
@ -215,7 +214,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
ir_prefetch += current_block_size;
spad_idx ^= 1;
@ -229,9 +228,9 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
uint32_t i03, i02, i01, rem;
i03 = fastdiv(ir, &bctx->dim12_div);
i03 = fastdiv(ir, &bctx->src0_dim12_div);
rem = ir - i03 * (ne02 * ne01);
i02 = fastdiv(rem, &bctx->dim1_div);
i02 = fastdiv(rem, &bctx->src0_dim1_div);
i01 = rem - i02 * ne01;
// src1 indices (broadcast/repeat)
@ -255,9 +254,9 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t p03, p02, p01, prem;
p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
prem = ir_prefetch - p03 * (ne02 * ne01);
p02 = fastdiv(prem, &bctx->dim1_div);
p02 = fastdiv(prem, &bctx->src0_dim1_div);
p01 = prem - p02 * ne01;
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
@ -282,6 +281,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
if (start_row >= end_row) return;
FARF(HIGH, "binary-same-shape: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned);
uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
@ -297,9 +298,9 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t i03, i02, i01, rem;
i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
rem = ir_prefetch - i03 * (ne02 * ne01);
i02 = fastdiv(rem, &bctx->dim1_div);
i02 = fastdiv(rem, &bctx->src0_dim1_div);
i01 = rem - i02 * ne01;
uint32_t i13 = (ne13 == 1) ? 0 : i03;
@ -307,23 +308,23 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
uint32_t i11 = (ne11 == 1) ? 0 : i01;
uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
uint8_t * src1_base = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
uint8_t * src1_curr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half;
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, current_block_size);
dma_queue_push(q, dma_make_ptr(s1_spad, src1_curr), bctx->src1_row_size_aligned, nb11, row_size_bytes, current_block_size);
ir_prefetch += current_block_size;
spad_idx ^= 1;
}
for (uint32_t ir = start_row; ir < end_row; ) {
uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst;
@ -335,9 +336,9 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
}
uint32_t i03, i02, i01, rem;
i03 = fastdiv(ir, &bctx->dim12_div);
i03 = fastdiv(ir, &bctx->src0_dim12_div);
rem = ir - i03 * (ne02 * ne01);
i02 = fastdiv(rem, &bctx->dim1_div);
i02 = fastdiv(rem, &bctx->src0_dim1_div);
i01 = rem - i02 * ne01;
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
@ -345,9 +346,9 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t p03, p02, p01, prem;
p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
prem = ir_prefetch - p03 * (ne02 * ne01);
p02 = fastdiv(prem, &bctx->dim1_div);
p02 = fastdiv(prem, &bctx->src0_dim1_div);
p01 = prem - p02 * ne01;
uint32_t p13 = (ne13 == 1) ? 0 : p03;
@ -358,7 +359,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11;
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, next_block_size);
dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, nb11, row_size_bytes, next_block_size);
ir_prefetch += next_block_size;
}
@ -373,15 +374,17 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
struct htp_ops_context * octx = bctx->octx;
htp_binary_preamble;
const uint32_t src0_type = octx->src0.type;
const uint32_t src0_type = octx->src0.type;
const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
const uint32_t total_rows = ne01 * ne02 * ne03;
const uint32_t start_row = bctx->nrows_per_thread * ith;
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
const uint32_t start_row = bctx->nrows_per_thread * ith;
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
if (start_row >= end_row) return;
FARF(HIGH, "binary-row-bcast: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned);
uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
uint8_t * src1_spad = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
@ -391,15 +394,14 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
uint32_t ir_prefetch = start_row;
int spad_idx = 0;
void * s1_ptr = (void *) src1_spad;
void * s1_ptr = (void *) src1_spad_base;
for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t i03, i02, i01, rem;
i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
rem = ir_prefetch - i03 * (ne02 * ne01);
i02 = fastdiv(rem, &bctx->dim1_div);
i01 = rem - i02 * ne01;
uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
uint32_t rem = ir_prefetch - i03 * (ne02 * ne01);
uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
uint32_t i01 = rem - i02 * ne01;
uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
@ -407,7 +409,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
ir_prefetch += current_block_size;
spad_idx ^= 1;
@ -415,7 +417,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
for (uint32_t ir = start_row; ir < end_row; ) {
uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
for (uint32_t r = 0; r < current_block_size; r++) {
@ -425,21 +427,19 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00);
}
uint32_t i03, i02, i01, rem;
i03 = fastdiv(ir, &bctx->dim12_div);
rem = ir - i03 * (ne02 * ne01);
i02 = fastdiv(rem, &bctx->dim1_div);
i01 = rem - i02 * ne01;
uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div);
uint32_t rem = ir - i03 * (ne02 * ne01);
uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
uint32_t i01 = rem - i02 * ne01;
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t p03, p02, p01, prem;
p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
prem = ir_prefetch - p03 * (ne02 * ne01);
p02 = fastdiv(prem, &bctx->dim1_div);
p01 = prem - p02 * ne01;
uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
uint32_t prem = ir_prefetch - p03 * (ne02 * ne01);
uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div);
uint32_t p01 = prem - p02 * ne01;
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
ir_prefetch += next_block_size;
@ -458,14 +458,16 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
const uint32_t src0_type = octx->src0.type;
const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
const uint32_t total_rows = ne01 * ne02 * ne03;
const uint32_t start_row = bctx->nrows_per_thread * ith;
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
const uint32_t start_row = bctx->nrows_per_thread * ith;
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
if (start_row >= end_row) return;
FARF(HIGH, "binary-complex: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned);
uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
dma_queue * q = octx->ctx->dma[ith];
uint32_t ir_prefetch = start_row;
@ -473,11 +475,10 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t i03, i02, i01, rem;
i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
rem = ir_prefetch - i03 * (ne02 * ne01);
i02 = fastdiv(rem, &bctx->dim1_div);
i01 = rem - i02 * ne01;
uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
uint32_t rem = ir_prefetch - i03 * (ne02 * ne01);
uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
uint32_t i01 = rem - i02 * ne01;
uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
@ -485,7 +486,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
ir_prefetch += current_block_size;
spad_idx ^= 1;
@ -496,11 +497,10 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
uint32_t i03, i02, i01, rem;
i03 = fastdiv(ir, &bctx->dim12_div);
rem = ir - i03 * (ne02 * ne01);
i02 = fastdiv(rem, &bctx->dim1_div);
i01 = rem - i02 * ne01;
uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div);
uint32_t rem = ir - i03 * (ne02 * ne01);
uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
uint32_t i01 = rem - i02 * ne01;
for (uint32_t r = 0; r < current_block_size; r++) {
uint32_t r_i01 = i01 + r;
@ -521,11 +521,10 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t p03, p02, p01, prem;
p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
prem = ir_prefetch - p03 * (ne02 * ne01);
p02 = fastdiv(prem, &bctx->dim1_div);
p01 = prem - p02 * ne01;
uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
uint32_t prem = ir_prefetch - p03 * (ne02 * ne01);
uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div);
uint32_t p01 = prem - p02 * ne01;
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
ir_prefetch += next_block_size;
@ -545,14 +544,16 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
const uint32_t row_size_bytes = ne00 * elem_size_bytes;;
const uint32_t total_rows = ne01 * ne02 * ne03;
const uint32_t start_row = bctx->nrows_per_thread * ith;
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
const uint32_t start_row = bctx->nrows_per_thread * ith;
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
if (start_row >= end_row) return;
uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
FARF(HIGH, "binary-repeat: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned);
dma_queue * q = octx->ctx->dma[ith];
uint32_t ir_prefetch = start_row;
@ -560,11 +561,10 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t i03, i02, i01, rem;
i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
rem = ir_prefetch - i03 * (ne02 * ne01);
i02 = fastdiv(rem, &bctx->dim1_div);
i01 = rem - i02 * ne01;
uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
uint32_t rem = ir_prefetch - i03 * (ne02 * ne01);
uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
uint32_t i01 = rem - i02 * ne01;
uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
@ -572,7 +572,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
ir_prefetch += current_block_size;
spad_idx ^= 1;
@ -583,11 +583,10 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
uint32_t i03, i02, i01, rem;
i03 = fastdiv(ir, &bctx->dim12_div);
rem = ir - i03 * (ne02 * ne01);
i02 = fastdiv(rem, &bctx->dim1_div);
i01 = rem - i02 * ne01;
uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div);
uint32_t rem = ir - i03 * (ne02 * ne01);
uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
uint32_t i01 = rem - i02 * ne01;
for (uint32_t r = 0; r < current_block_size; r++) {
uint32_t r_i01 = i01 + r;
@ -612,11 +611,10 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t p03, p02, p01, prem;
p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
prem = ir_prefetch - p03 * (ne02 * ne01);
p02 = fastdiv(prem, &bctx->dim1_div);
p01 = prem - p02 * ne01;
uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
uint32_t prem = ir_prefetch - p03 * (ne02 * ne01);
uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div);
uint32_t p01 = prem - p02 * ne01;
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
ir_prefetch += next_block_size;
@ -646,6 +644,7 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
const uint32_t nb02 = src0->nb[2];
const uint32_t nb03 = src0->nb[3];
const uint32_t nb11 = src1->nb[1]; // src1 row stride
const uint32_t nb1 = dst->nb[1];
const uint32_t nb2 = dst->nb[2];
const uint32_t nb3 = dst->nb[3];
@ -657,8 +656,8 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
dma_queue * q = octx->ctx->dma[ith];
uint32_t ir_prefetch = start_row;
@ -666,11 +665,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t i03, i02, i01, rem;
i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
rem = ir_prefetch - i03 * (ne02 * ne01);
i02 = fastdiv(rem, &bctx->dim1_div);
i01 = rem - i02 * ne01;
uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
uint32_t rem = ir_prefetch - i03 * (ne02 * ne01);
uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
uint32_t i01 = rem - i02 * ne01;
uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
@ -678,7 +676,7 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), 0);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
ir_prefetch += current_block_size;
spad_idx ^= 1;
@ -689,11 +687,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
uint32_t i03, i02, i01, rem;
i03 = fastdiv(ir, &bctx->dim12_div);
rem = ir - i03 * (ne02 * ne01);
i02 = fastdiv(rem, &bctx->dim1_div);
i01 = rem - i02 * ne01;
uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div);
uint32_t rem = ir - i03 * (ne02 * ne01);
uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
uint32_t i01 = rem - i02 * ne01;
for (uint32_t r = 0; r < current_block_size; r++) {
uint32_t r_i01 = i01 + r; // linear within block since we split at ne01
@ -712,11 +709,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint32_t p03, p02, p01, prem;
p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
prem = ir_prefetch - p03 * (ne02 * ne01);
p02 = fastdiv(prem, &bctx->dim1_div);
p01 = prem - p02 * ne01;
uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
uint32_t prem = ir_prefetch - p03 * (ne02 * ne01);
uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div);
uint32_t p01 = prem - p02 * ne01;
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
ir_prefetch += next_block_size;
@ -739,40 +735,36 @@ static int execute_op_binary(struct htp_ops_context * octx) {
const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
const size_t src0_row_size = src0->ne[0] * elem_size;
const size_t src1_row_size = src1->ne[0] * elem_size;
const size_t dst_row_size = dst->ne[0] * elem_size;
const size_t dst_row_size = dst->ne[0] * elem_size;
// Align to VLEN
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
bool is_add_id = (octx->op == HTP_OP_ADD_ID);
bool is_scalar = !is_add_id && (src1->ne[0] == 1);
// Determine which kernel we will use to alloc memory and dispatch
bool use_vector_same = !is_add_id && !is_scalar && ((src0->nb[1] % VLEN) == 0) && (src1->ne[0] == src0->ne[0]) &&
bool is_transposed = (src0->nb[1] < src0_row_size || src1->nb[1] < src1_row_size || dst->nb[1] < dst_row_size);
bool is_same_shape = !is_add_id && !is_scalar && !is_transposed &&
(src1->ne[0] == src0->ne[0] && src0->ne[0] % VLEN == 0) &&
(src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) &&
(src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) &&
(src1->ne[3] == src0->ne[3] || src1->ne[3] == 1);
bool is_row_bcast = use_vector_same && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1);
bool use_complex = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] == src0->ne[0]);
bool use_repeat = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] != src0->ne[0]);
bool is_row_bcast = is_same_shape && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1);
bool is_complex = !is_add_id && !is_scalar && !is_same_shape && (src1->ne[0] == src0->ne[0]);
bool is_repeat = !is_add_id && !is_scalar && !is_same_shape && (src1->ne[0] != src0->ne[0]);
size_t spad_row_total;
if (is_scalar) {
spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
} else if (is_row_bcast) {
spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
} else if (use_vector_same) {
if (is_same_shape) {
spad_row_total = 2 * (src0_row_size_aligned + src1_row_size_aligned + dst_row_size_aligned);
} else if (is_add_id) {
spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); // src1 read directly
} else {
spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
}
size_t rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total);
// Adjust for static src1 in row_bcast case
if (is_row_bcast) {
size_t needed_static = src1_row_size_aligned;
@ -782,28 +774,26 @@ static int execute_op_binary(struct htp_ops_context * octx) {
}
if (rows_per_buffer < 1) {
FARF(ERROR, "binary: VTCM too small\n");
return HTP_STATUS_VTCM_TOO_SMALL;
FARF(ERROR, "binary: VTCM too small\n");
return HTP_STATUS_VTCM_TOO_SMALL;
}
octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned;
octx->dst_spad.size_per_thread = rows_per_buffer * 2 * dst_row_size_aligned;
if (is_scalar || use_complex || use_repeat || is_add_id) {
octx->src1_spad.size_per_thread = 0;
} else if (is_row_bcast) {
if (is_add_id || is_scalar || is_complex || is_repeat || is_row_bcast) {
octx->src1_spad.size_per_thread = 0;
} else {
octx->src1_spad.size_per_thread = rows_per_buffer * 2 * src1_row_size_aligned;
}
octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
if (is_row_bcast) {
octx->src1_spad.size = src1_row_size_aligned;
} else {
octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread;
}
octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) {
return HTP_STATUS_VTCM_TOO_SMALL;
@ -823,46 +813,37 @@ static int execute_op_binary(struct htp_ops_context * octx) {
}
struct htp_binary_context bctx;
bctx.octx = octx;
bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
bctx.block_max = rows_per_buffer;
bctx.octx = octx;
bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
bctx.block_max = rows_per_buffer;
bctx.src0_row_size_aligned = src0_row_size_aligned;
bctx.src1_row_size_aligned = src1_row_size_aligned;
bctx.dst_row_size_aligned = dst_row_size_aligned;
bctx.dim1_div = init_fastdiv_values(src0->ne[1]);
bctx.dim2_div = init_fastdiv_values(src0->ne[2]);
bctx.dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]);
bctx.src0_dim1_div = init_fastdiv_values(src0->ne[1]);
bctx.src0_dim2_div = init_fastdiv_values(src0->ne[2]);
bctx.src0_dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]);
bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]);
bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]);
bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]);
bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]);
bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]);
bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]);
bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]);
bool dst_contig_dim1 = (dst->nb[2] == src0->ne[1] * dst->nb[1]);
bool dst_contig_dim1 = (dst->nb[2] == src0->ne[1] * dst->nb[1]);
bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]);
bool dst_contig_dim2 = (dst->nb[3] == src0->ne[2] * dst->nb[2]);
bool dst_contig_dim2 = (dst->nb[3] == src0->ne[2] * dst->nb[2]);
bctx.split_at_ne01 = (src0->ne[2] > 1) &&
((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1);
bctx.split_at_ne02 = (src0->ne[3] > 1) &&
((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2);
// Precompute specific kernel parameters
if (use_vector_same) {
bctx.src1_dma_stride = (src1->ne[1] == 1) ? 0 : src1->nb[1];
bctx.src1_fetch_rows = (src1->ne[1] == 1) ? 1 : rows_per_buffer;
}
bctx.split_at_ne01 = (src0->ne[2] > 1) && ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1);
bctx.split_at_ne02 = (src0->ne[3] > 1) && ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2);
worker_callback_t worker_func;
if (is_add_id) worker_func = binary_job_add_id;
else if (is_scalar) worker_func = binary_job_scalar;
else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast;
else if (use_vector_same) worker_func = binary_job_vector_same_shape;
else if (use_complex) worker_func = binary_job_vector_complex;
else worker_func = binary_job_element_repeat;
if (is_add_id) worker_func = binary_job_add_id;
else if (is_scalar) worker_func = binary_job_scalar;
else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast;
else if (is_same_shape) worker_func = binary_job_vector_same_shape;
else if (is_complex) worker_func = binary_job_vector_complex;
else worker_func = binary_job_element_repeat;
if (is_row_bcast) {
dma_queue_pop(q);

View File

@ -31,8 +31,8 @@ dma_queue * dma_queue_create(size_t capacity) {
q->capacity = capacity;
q->idx_mask = capacity - 1;
q->desc = (hexagon_udma_descriptor_type1_t *) memalign(64, capacity * sizeof(hexagon_udma_descriptor_type1_t));
memset(q->desc, 0, capacity * sizeof(hexagon_udma_descriptor_type1_t));
q->desc = (dma_descriptor_2d *) memalign(64, capacity * sizeof(dma_descriptor_2d));
memset(q->desc, 0, capacity * sizeof(dma_descriptor_2d));
q->dptr = (dma_ptr *) memalign(4, capacity * sizeof(dma_ptr));
memset(q->dptr, 0, capacity * sizeof(dma_ptr));

View File

@ -10,19 +10,84 @@
extern "C" {
#endif
// Define the HW descriptor structs here since the ones in HexSDK are a bit out of date
typedef struct dma_descriptor_1d_s {
void * next;
uint32_t size:24;
uint32_t desc_size:2;
uint32_t dst_comp:1;
uint32_t src_comp:1;
uint32_t dst_bypass:1;
uint32_t src_bypass:1;
uint32_t order:1;
uint32_t done:1;
void * src;
void * dst;
} dma_descriptor_1d;
#if __HVX_ARCH__ < 75
typedef struct dma_descriptor_2d_s {
void * next;
uint32_t reserved0:24;
uint32_t desc_size:2;
uint32_t dst_comp:1;
uint32_t src_comp:1;
uint32_t dst_bypass:1;
uint32_t src_bypass:1;
uint32_t order:1;
uint32_t done:1;
void * src;
void * dst;
uint32_t desc_type:8;
uint32_t reserved1:24;
uint32_t row_size:16;
uint32_t nrows:16;
uint32_t src_stride:16;
uint32_t dst_stride:16;
uint32_t src_offset:16;
uint32_t dst_offset:16;
} dma_descriptor_2d;
#else
typedef struct dma_descriptor_2d_s {
void * next;
uint32_t dst_stride:24;
uint32_t desc_size:2;
uint32_t dst_comp:1;
uint32_t src_comp:1;
uint32_t dst_bypass:1;
uint32_t src_bypass:1;
uint32_t order:1;
uint32_t done:1;
void * src;
void * dst;
uint32_t desc_type:8;
uint32_t reserved0:24;
uint32_t row_size:24;
uint32_t nrows_lo:8;
uint32_t nrows_hi:8;
uint32_t src_stride:24;
uint32_t offset:24;
uint32_t reserved1:8;
} dma_descriptor_2d;
#endif
typedef struct {
void *dst;
void *dst;
const void *src;
} dma_ptr;
typedef struct {
hexagon_udma_descriptor_type1_t * desc; // descriptor pointers
hexagon_udma_descriptor_type1_t * tail; // tail pointer
dma_ptr * dptr; // dst/src pointers
uint32_t push_idx;
uint32_t pop_idx;
uint32_t capacity;
uint32_t idx_mask;
dma_descriptor_2d * desc; // descriptor pointers
dma_descriptor_2d * tail; // tail pointer
dma_ptr * dptr; // dst/src pointers
uint32_t push_idx;
uint32_t pop_idx;
uint32_t capacity;
uint32_t idx_mask;
} dma_queue;
dma_queue * dma_queue_create(size_t capacity);
@ -59,71 +124,87 @@ static inline dma_ptr dma_make_ptr(void *dst, const void *src)
return p;
}
static inline bool dma_queue_push(dma_queue * q,
dma_ptr dptr,
size_t dst_row_size,
size_t src_row_size,
size_t width, // width in bytes. number of bytes to transfer per row
size_t nrows) {
#if __HVX_ARCH__ < 73
static const uint32_t dma_src_l2_bypass_on = 1;
static const uint32_t dma_dst_l2_bypass_on = 0;
#else
static const uint32_t dma_src_l2_bypass_on = 1;
static const uint32_t dma_dst_l2_bypass_on = 1;
#endif
static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t size) {
if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) {
FARF(ERROR, "dma-push: queue full\n");
FARF(HIGH, "dma-push: queue full\n");
return false;
}
hexagon_udma_descriptor_type1_t * desc = &q->desc[q->push_idx];
dma_descriptor_1d * desc = (dma_descriptor_1d *) &q->desc[q->push_idx];
desc->next = NULL;
desc->desc_size = 0; // 1D mode
desc->src_bypass = dma_src_l2_bypass_on;
desc->dst_bypass = dma_dst_l2_bypass_on;
desc->order = 1;
desc->done = 0;
desc->src = (void *) dptr.src;
desc->dst = (void *) dptr.dst;
desc->size = size;
q->dptr[q->push_idx] = dptr;
dmlink(q->tail, desc);
q->tail = (dma_descriptor_2d *) desc;
// FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
q->push_idx = (q->push_idx + 1) & q->idx_mask;
return true;
}
static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) {
if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) {
FARF(HIGH, "dma-push: queue full\n");
return false;
}
dma_descriptor_2d * desc = &q->desc[q->push_idx];
desc->next = NULL;
desc->length = 0;
desc->desctype = HEXAGON_UDMA_DESC_DESCTYPE_TYPE1;
desc->dstbypass = 1;
desc->srcbypass = 1;
#if __HVX_ARCH__ >= 73
desc->dstbypass = 1;
desc->srcbypass = 1;
#else
desc->dstbypass = 0;
desc->srcbypass = 1;
#endif
desc->order = 0;
desc->dstate = HEXAGON_UDMA_DESC_DSTATE_INCOMPLETE;
desc->reserved0 = 0;
desc->reserved1 = 0;
desc->desc_size = 1; // 2d mode
desc->src_bypass = dma_src_l2_bypass_on;
desc->dst_bypass = dma_dst_l2_bypass_on;
desc->src_comp = 0;
desc->dst_comp = 0;
desc->order = 1;
desc->done = 0;
desc->src_stride = src_stride;
desc->dst_stride = dst_stride;
desc->src = (void *) dptr.src;
desc->dst = (void *) dptr.dst;
desc->allocation = 0;
desc->padding = 0;
desc->roiwidth = width;
desc->roiheight = nrows;
desc->srcstride = src_row_size;
desc->dststride = dst_row_size;
desc->srcwidthoffset = 0;
desc->dstwidthoffset = 0;
desc->row_size = row_size;
#if __HVX_ARCH__ < 75
desc->desc_type = 0; // 2d (16-bit) mode
desc->nrows = nrows;
desc->src_offset = 0;
desc->dst_offset = 0;
#else
desc->desc_type = 9; // 2d (24-bit) mode
desc->nrows_lo = (nrows & 0xff);
desc->nrows_hi = (nrows >> 8);
desc->offset = 0;
#endif
q->dptr[q->push_idx] = dptr;
dmlink(q->tail, desc);
q->tail = desc;
// FARF(ERROR, "dma-push: i %u width %u nrows %d dst %p src %p\n", q->push_idx, width, nrows, dptr.dst, dptr.src);
// FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
q->push_idx = (q->push_idx + 1) & q->idx_mask;
return true;
}
static inline bool dma_queue_push_ddr_to_vtcm(dma_queue * q,
dma_ptr dptr,
size_t dst_row_size,
size_t src_row_size,
size_t nrows) {
return dma_queue_push(q, dptr, dst_row_size, src_row_size, src_row_size, nrows);
}
static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q,
dma_ptr dptr,
size_t dst_row_size,
size_t src_row_size,
size_t nrows) {
return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);
}
static inline dma_ptr dma_queue_pop(dma_queue * q) {
dma_ptr dptr = { NULL };
@ -131,12 +212,12 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) {
return dptr;
}
hexagon_udma_descriptor_type1_t * desc = &q->desc[q->pop_idx];
dma_descriptor_2d * desc = &q->desc[q->pop_idx];
// Wait for desc to complete
while (1) {
dmpoll();
if (desc->dstate == HEXAGON_UDMA_DESC_DSTATE_COMPLETE) {
if (desc->done) {
break;
}
// FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx);
@ -175,86 +256,62 @@ static inline uint32_t dma_queue_capacity(dma_queue * q) {
return q->capacity;
}
// ---------------------------------------------------------------------------
// Overflow-safe DMA push: all UDMA type1 descriptor fields (roiwidth,
// roiheight, srcstride, dststride) are 16-bit, max 65535. This helper
// transparently handles values that exceed the 16-bit limit and submits
// chained DMA transtions.
//
// Case 1 (fast path): all params fit in 16 bits -> direct dma_queue_push.
// Case 2 (contiguous block): width == srcstride == dststride. Reshape the
// flat transfer into a 2D descriptor with sub_width <= 65535. Produces a
// single descriptor, preserving async DMA behavior.
// Case 3 (stride overflow): srcstride or dststride > 65535. Issue rows
// one at a time. The first N-1 rows are pushed+popped synchronously;
// the last row is left async so the caller can pop it.
// ---------------------------------------------------------------------------
#define UDMA_MAX_FIELD_VAL 65535u
#if __HVX_ARCH__ < 75
static inline bool dma_queue_push_chained(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t width, size_t nrows) {
// Fast path: everything fits in 16 bits.
if (__builtin_expect(
width <= UDMA_MAX_FIELD_VAL &&
nrows <= UDMA_MAX_FIELD_VAL &&
src_stride <= UDMA_MAX_FIELD_VAL &&
dst_stride <= UDMA_MAX_FIELD_VAL, 1)) {
return dma_queue_push(q, dptr, dst_stride, src_stride, width, nrows);
// Overflow-safe DMA push: all 2d descriptor fields (row_size, nrows, src_stride, dst_stride) are 16-bit, max 65535.
// This version transparently handles values that exceed the 16-bit limit and submits chained DMA transtions.
#define DMA_MAX_FIELD_VAL 65535u
static inline bool dma_queue_push(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) {
// Fast path: everything fits in 16 bits
if (nrows == 0 || __builtin_expect(
row_size <= DMA_MAX_FIELD_VAL &&
nrows <= DMA_MAX_FIELD_VAL &&
src_stride <= DMA_MAX_FIELD_VAL &&
dst_stride <= DMA_MAX_FIELD_VAL, 1)) {
return dma_queue_push_single_2d(q, dptr, dst_stride, src_stride, row_size, nrows);
}
// Case 2: contiguous block (width == src_stride == dst_stride).
// Reshape total bytes into sub_width * sub_nrows where sub_width <= 65535.
if (width == src_stride && width == dst_stride) {
size_t total = width * nrows;
// Pick the largest 128-byte-aligned sub_width that divides total evenly.
size_t sub_width = UDMA_MAX_FIELD_VAL & ~(size_t)127; // 65408
while (sub_width > 0 && total % sub_width != 0) {
sub_width -= 128;
}
if (sub_width == 0) {
// Fallback: use original width (must fit) with adjusted nrows.
// This shouldn't happen for 128-aligned DMA sizes.
sub_width = width;
}
size_t sub_nrows = total / sub_width;
// Handle sub_nrows > 65535 by issuing chunked descriptors.
const uint8_t *src = (const uint8_t *)dptr.src;
uint8_t *dst = (uint8_t *)dptr.dst;
size_t rows_done = 0;
while (rows_done < sub_nrows) {
size_t chunk = sub_nrows - rows_done;
if (chunk > UDMA_MAX_FIELD_VAL) chunk = UDMA_MAX_FIELD_VAL;
dma_ptr p = dma_make_ptr(dst + rows_done * sub_width, src + rows_done * sub_width);
if (!dma_queue_push(q, p, sub_width, sub_width, sub_width, chunk))
return false;
rows_done += chunk;
// Complete all chunks without waiting except the last one, so the
// caller's single dma_queue_pop drains the final descriptor.
if (rows_done < sub_nrows)
dma_queue_pop_nowait(q);
}
return true;
// Contiguous block
// Use 1d DMA mode which supports sizes up to 24-bits (16MB)
if (nrows == 1 || (row_size == src_stride && row_size == dst_stride)) {
size_t total = row_size * nrows;
return dma_queue_push_single_1d(q, dptr, total);
}
// Case 3: stride overflow — fall back to row-by-row.
// Stride overflow — fall back to row-by-row.
{
const uint8_t *src = (const uint8_t *)dptr.src;
uint8_t *dst = (uint8_t *)dptr.dst;
const uint8_t *src = (const uint8_t *) dptr.src;
uint8_t *dst = (uint8_t *) dptr.dst;
for (size_t r = 0; r < nrows; ++r) {
dma_ptr p = dma_make_ptr(dst + r * dst_stride,
src + r * src_stride);
if (!dma_queue_push(q, p, 0, 0, width, 1))
return false;
if (r + 1 < nrows)
dma_queue_pop_nowait(q);
dma_ptr p = dma_make_ptr(dst + r * dst_stride, src + r * src_stride);
if (!dma_queue_push_single_1d(q, p, row_size))
return false;
if (r + 1 < nrows)
dma_queue_pop(q);
}
return true;
}
}
#else // HVX_ARCH >= 75
static inline bool dma_queue_push(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) {
// On v75 and up we always use 2d 24-bit mode
return dma_queue_push_single_2d(q, dptr, dst_stride, src_stride, row_size, nrows);
}
#endif
static inline bool dma_queue_push_ddr_to_vtcm(dma_queue * q, dma_ptr dptr, size_t dst_row_size, size_t src_row_size, size_t nrows) {
return dma_queue_push(q, dptr, dst_row_size, src_row_size, src_row_size, nrows);
}
static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_t dst_row_size, size_t src_row_size, size_t nrows) {
return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);
}
#ifdef __cplusplus
} // extern "C"
#endif

View File

@ -21,6 +21,15 @@ static inline void hex_dump_uint8_line(char * pref, const uint8_t * x, uint32_t
FARF(HIGH, "%s\n", str);
}
static inline void hex_dump_uint32_line(char * pref, const uint32_t * x, uint32_t n) {
char str[1024], *p = str, *p_end = str + sizeof(str);
p += snprintf(p, p_end - p, "%s: ", pref);
for (int i = 0; i < n; i++) {
p += snprintf(p, p_end - p, "%u, ", (unsigned int) x[i]);
}
FARF(HIGH, "%s\n", str);
}
static inline void hex_dump_int32_line(char * pref, const int32_t * x, uint32_t n) {
char str[1024], *p = str, *p_end = str + sizeof(str);
p += snprintf(p, p_end - p, "%s: ", pref);

View File

@ -727,7 +727,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
if (use_dma_activation) {
const size_t row_bytes = (size_t) params->k * sizeof(float);
const size_t stride_bytes = (size_t) params->act_stride * sizeof(float);
dma_queue_push_chained(ctx->dma[0],
dma_queue_push(ctx->dma[0],
dma_make_ptr(vtcm_f32_act, activation_chunk),
row_bytes, stride_bytes, row_bytes, n_rows);
dma_queue_pop(ctx->dma[0]);
@ -747,7 +747,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
{
const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols);
dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_curr, weight_group),
dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, weight_group),
fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first);
}
@ -765,7 +765,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols);
const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride;
dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk),
dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk),
fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next);
}
@ -891,7 +891,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
if (use_dma_activation) {
const size_t row_bytes = (size_t) k * sizeof(float);
const size_t stride_bytes = (size_t) act_stride * sizeof(float);
dma_queue_push_chained(ctx->dma[0],
dma_queue_push(ctx->dma[0],
dma_make_ptr(vtcm_f32_act, activation_chunk),
row_bytes, stride_bytes, row_bytes, n_rows);
dma_queue_pop(ctx->dma[0]);
@ -916,7 +916,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
{
const size_t n_cols_first = hex_smin(n, n_chunk_n_cols);
dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight),
dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight),
fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first);
}
@ -933,7 +933,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols);
const __fp16 *next_weight_chunk = permuted_weight + nc_next * weight_stride;
dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk),
dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk),
fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next);
}
@ -1104,7 +1104,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
// because UDMA roiwidth is 16-bit and total size can exceed 65535.
{
const size_t n_cols_first = hex_smin(n, n_chunk_n_cols);
dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first);
dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first);
}
for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) {
@ -1120,7 +1120,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
const uint8_t *next_weight_chunk = permuted_weight + nc_next * row_stride;
dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), row_stride, row_stride, row_stride, n_cols_next);
dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), row_stride, row_stride, row_stride, n_cols_next);
}
// Dequant + vscatter writes directly to [K, N] transposed tiles.
@ -1173,7 +1173,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
{
// Use 2D DMA (n_cols rows x row_stride) to avoid 16-bit roiwidth overflow.
const uint8_t *qweight_chunk_A0 = permuted_weight;
dma_queue_push_chained(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0);
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0);
}
{
@ -1191,7 +1191,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols);
if (1 < n_chunk_cnt) {
const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride;
dma_queue_push_chained(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1);
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1);
}
// C0
@ -1218,7 +1218,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
// issue A_{i+2}
if (i + 2 < n_chunk_cnt) {
const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride;
dma_queue_push_chained(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2);
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2);
}
// wait for HMX (C_{i}) -- C_{i} is done
@ -1443,7 +1443,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
{
const float *activation_block = x + mr * k + kk;
dma_queue_push_chained(ctx->dma[0],
dma_queue_push(ctx->dma[0],
dma_make_ptr(vtcm_scratch1, activation_block),
k_blk_sz * sizeof(float),
k * sizeof(float),
@ -1472,10 +1472,10 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
s.scale_width = nb_sub * HMX_X4X2_DBLK_SIZE;
// 2D DMA: quants sub-range
dma_queue_push_chained(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off),
dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off),
s.dst_stride, s.src_stride, s.quant_width, s.n_rows);
// 2D DMA: scales sub-range
dma_queue_push_chained(ctx->dma[0], dma_make_ptr(s.dst + s.quant_width, s.src + s.scale_off),
dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst + s.quant_width, s.src + s.scale_off),
s.dst_stride, s.src_stride, s.scale_width, s.n_rows);
}
TIMER_STOP(fetch);

View File

@ -15,12 +15,4 @@
#include "hvx-div.h"
#include "hvx-base.h"
#ifndef GATHER_TYPE
# if defined(__hexagon__)
# define GATHER_TYPE(_a) (intptr_t) _a
# else
# define GATHER_TYPE(_a) (HVX_Vector *) _a
# endif
#endif
#endif /* HVX_UTILS_H */

View File

@ -214,7 +214,7 @@ static int vtcm_alloc(struct htp_context * ctx) {
HAP_compute_res_attr_init(&attr);
HAP_compute_res_attr_set_serialize(&attr, 0);
HAP_compute_res_attr_set_cache_mode(&attr, 1);
HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, 0, vtcm_size);
HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, vtcm_size, vtcm_size); // single page
HAP_compute_res_attr_set_release_callback(&attr, vtcm_release_callback, (void *) ctx);
HAP_compute_res_attr_set_hmx_param(&attr, 1);
@ -319,7 +319,7 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
ctx->n_threads = n_hvx;
for (int i = 0; i < ctx->n_threads; i++) {
// see discussion https://github.com/ggml-org/llama.cpp/pull/18151#discussion_r2632388541
ctx->dma[i] = dma_queue_create(64);
ctx->dma[i] = dma_queue_create(128);
}
// init worker pool

View File

@ -151,7 +151,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
const int dr = scctx->nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = MIN(ir0 + dr, d_inner);
const int ir = ir1 - ir0;
const uint32_t ir = ir1 - ir0;
if (ir0 >= ir1) {
return; // No work for this thread
@ -205,10 +205,10 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
HVX_Vector acc_vec = Q6_V_vsplat_R(0);
for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])),
src0_gather_len, (*(const HVX_Vector *) src0_offsets));
Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)),
src1_gather_len, (*(const HVX_Vector *) src1_offsets));
uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]);
uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float);
Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets));
Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets));
HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
@ -222,10 +222,10 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
HVX_Vector acc_vec = Q6_V_vsplat_R(0);
for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])),
src0_gather_len, (*(const HVX_Vector *) src0_offsets));
Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)),
src1_gather_len, (*(const HVX_Vector *) src1_offsets));
uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]);
uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float);
Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets));
Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets));
HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);

View File

@ -71,12 +71,11 @@ if (GGML_CUDA_FA_ALL_QUANTS)
list(APPEND GGML_SOURCES_ROCM ${SRCS})
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
else()
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
list(APPEND GGML_SOURCES_ROCM
../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu
../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu
../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu)
endif()
ggml_add_backend_library(ggml-hip

View File

@ -246,6 +246,10 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal
case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break;
case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break;
case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break;
case GGML_UNARY_OP_FLOOR: op_num = OP_UNARY_NUM_FLOOR; break;
case GGML_UNARY_OP_CEIL: op_num = OP_UNARY_NUM_CEIL; break;
case GGML_UNARY_OP_ROUND: op_num = OP_UNARY_NUM_ROUND; break;
case GGML_UNARY_OP_TRUNC: op_num = OP_UNARY_NUM_TRUNC; break;
default: GGML_ABORT("fatal error");
} break;
default: GGML_ABORT("fatal error");
@ -1748,6 +1752,28 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_met
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_CONV_3D);
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(op->type == GGML_TYPE_F32);
char base[256];
char name[256];
snprintf(base, 256, "kernel_conv_3d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_UPSCALE);

View File

@ -148,6 +148,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);

View File

@ -1039,6 +1039,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_EXPM1:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_TRUNC:
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
default:
return false;
@ -1077,6 +1081,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32;
case GGML_OP_CONV_3D:
return ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]) &&
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_SUM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
case GGML_OP_TRI:
@ -1143,6 +1152,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
op->src[0]->ne[0] != 192 &&
op->src[0]->ne[0] != 256 &&
op->src[0]->ne[0] != 320 &&
op->src[0]->ne[0] != 512 &&
op->src[0]->ne[0] != 576) {
return false;
}

View File

@ -120,6 +120,10 @@
#define OP_UNARY_NUM_EXP 114
#define OP_UNARY_NUM_SOFTPLUS 115
#define OP_UNARY_NUM_EXPM1 116
#define OP_UNARY_NUM_FLOOR 117
#define OP_UNARY_NUM_CEIL 118
#define OP_UNARY_NUM_ROUND 119
#define OP_UNARY_NUM_TRUNC 120
#define OP_SUM_ROWS_NUM_SUM_ROWS 10
#define OP_SUM_ROWS_NUM_MEAN 11
@ -643,6 +647,42 @@ typedef struct {
int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
} ggml_metal_kargs_im2col;
typedef struct {
int32_t IW;
int32_t IH;
int32_t ID;
int32_t OW;
int32_t OH;
int32_t OD;
int32_t KW;
int32_t KH;
int32_t KD;
int32_t s0;
int32_t s1;
int32_t s2;
int32_t p0;
int32_t p1;
int32_t p2;
int32_t d0;
int32_t d1;
int32_t d2;
int32_t IC;
int32_t N;
int32_t OC;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_conv_3d;
typedef struct{
int32_t ne00;
uint64_t nb01;

View File

@ -394,6 +394,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
} break;
case GGML_OP_CONV_3D:
{
n_fuse = ggml_metal_op_conv_3d(ctx, idx);
} break;
case GGML_OP_UPSCALE:
{
n_fuse = ggml_metal_op_upscale(ctx, idx);
@ -3697,6 +3701,77 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_conv_3d(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
// 1. Extract standard dimensions and byte strides
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
// 2. Extract hyperparams from op_params
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
const int32_t s1 = ((const int32_t *)(op->op_params))[1];
const int32_t s2 = ((const int32_t *)(op->op_params))[2];
const int32_t p0 = ((const int32_t *)(op->op_params))[3];
const int32_t p1 = ((const int32_t *)(op->op_params))[4];
const int32_t p2 = ((const int32_t *)(op->op_params))[5];
const int32_t d0 = ((const int32_t *)(op->op_params))[6];
const int32_t d1 = ((const int32_t *)(op->op_params))[7];
const int32_t d2 = ((const int32_t *)(op->op_params))[8];
const int32_t IC = ((const int32_t *)(op->op_params))[9];
const int32_t N = ((const int32_t *)(op->op_params))[10];
const int32_t OC = ((const int32_t *)(op->op_params))[11];
// 3. Build the parameter struct using the macro-generated variables
ggml_metal_kargs_conv_3d args = {
/*.IW =*/ (int32_t)op->src[1]->ne[0],
/*.IH =*/ (int32_t)op->src[1]->ne[1],
/*.ID =*/ (int32_t)op->src[1]->ne[2],
/*.OW =*/ (int32_t)op->ne[0],
/*.OH =*/ (int32_t)op->ne[1],
/*.OD =*/ (int32_t)op->ne[2],
/*.KW =*/ (int32_t)op->src[0]->ne[0],
/*.KH =*/ (int32_t)op->src[0]->ne[1],
/*.KD =*/ (int32_t)op->src[0]->ne[2],
s0, s1, s2,
p0, p1, p2,
d0, d1, d2,
IC, N, OC,
nb00, nb01, nb02, nb03, // Weight strides
nb10, nb11, nb12, nb13, // Input strides
nb0, nb1, nb2, nb3 // Output strides
};
// 4. Fetch the JIT pipeline
auto pipeline = ggml_metal_library_get_pipeline_conv_3d(lib, op);
// 5. Grid mapping
int nth0 = 32; // Standard SIMD width for Apple Silicon
int nth1 = 1;
int nth2 = 1;
int64_t spatial_volume = args.OW * args.OH * args.OD;
int ntg0 = (spatial_volume + nth0 - 1) / nth0;
int ntg1 = args.OC;
int ntg2 = args.N;
// 6. Bind and Dispatch via the ggml C wrapper
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
ggml_metal_encoder_dispatch_threadgroups(enc, ntg0, ntg1, ntg2, nth0, nth1, nth2);
return 1;
}
int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

View File

@ -75,6 +75,7 @@ int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_2d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_3d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx);

View File

@ -1094,6 +1094,22 @@ kernel void kernel_unary_impl(
// TODO: precise implementation
dst_ptr[i0] = (T) (exp(x) - 1);
}
if (FC_OP == OP_UNARY_NUM_FLOOR) {
dst_ptr[i0] = (T) floor(x);
}
if (FC_OP == OP_UNARY_NUM_CEIL) {
dst_ptr[i0] = (T) ceil(x);
}
if (FC_OP == OP_UNARY_NUM_ROUND) {
dst_ptr[i0] = (T) round(x);
}
if (FC_OP == OP_UNARY_NUM_TRUNC) {
dst_ptr[i0] = (T) trunc(x);
}
}
#undef FC_OP
@ -4883,6 +4899,98 @@ kernel void kernel_upscale_bilinear_f32(
}
}
template <typename T>
kernel void kernel_conv_3d(
constant ggml_metal_kargs_conv_3d & args,
device const char * src0, // Weights [IC * OC, KD, KH, KW]
device const char * src1, // Inputs [IC * N, ID, IH, IW]
device char * dst, // Outputs [OC * N, OD, OH, OW]
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]) {
// 1. Un-flatten the spatial dimension from Grid X
int64_t spatial_idx = tgpig.x * 32 + tpitg.x;
if (spatial_idx >= args.OW * args.OH * args.OD) {
return; // Thread falls outside the spatial volume
}
int64_t od = spatial_idx / (args.OW * args.OH);
int64_t oh = (spatial_idx / args.OW) % args.OH;
int64_t ow = spatial_idx % args.OW;
// 2. Map Y to Channels, Z to Batch
int64_t oc = tgpig.y;
int64_t batch_idx = tgpig.z;
// 3. Calculate anchor coordinates in the Input volume
int64_t i_w_base = ow * args.s0 - args.p0;
int64_t i_h_base = oh * args.s1 - args.p1;
int64_t i_d_base = od * args.s2 - args.p2;
float sum = 0.0f;
// 4. Gather Loop (Iterate over Input Channels -> Depth -> Height -> Width)
for (int64_t ic = 0; ic < args.IC; ++ic) {
// ggml packs batch and channel together in the 4th dimension
int64_t src_cn_idx = batch_idx * args.IC + ic;
int64_t w_cn_idx = oc * args.IC + ic;
for (int64_t kz = 0; kz < args.KD; ++kz) {
int64_t id = i_d_base + kz * args.d2;
if (id < 0 || id >= args.ID) continue; // Boundary check (Padding)
for (int64_t ky = 0; ky < args.KH; ++ky) {
int64_t ih = i_h_base + ky * args.d1;
if (ih < 0 || ih >= args.IH) continue;
for (int64_t kx = 0; kx < args.KW; ++kx) {
int64_t iw = i_w_base + kx * args.d0;
if (iw < 0 || iw >= args.IW) continue;
// Convert multi-dimensional coordinates to flat byte offsets
int64_t w_idx = kx*args.nb00 + ky*args.nb01 + kz*args.nb02 + w_cn_idx*args.nb03;
int64_t i_idx = iw*args.nb10 + ih*args.nb11 + id*args.nb12 + src_cn_idx*args.nb13;
// Dereference memory and cast weights to f32 if they were f16
float w_val = (float)*(device const T*)((device const char*)src0 + w_idx);
float i_val = *(device const float*)((device const char*)src1 + i_idx);
sum += w_val * i_val;
}
}
}
}
// 5. Write the accumulated value out to RAM
int64_t dst_cn_idx = batch_idx * args.OC + oc;
int64_t d_idx = ow*args.nb0 + oh*args.nb1 + od*args.nb2 + dst_cn_idx*args.nb3;
*(device float*)(dst + d_idx) = sum;
}
// Explicit instantiations so the JIT compiler can find them by name
template [[host_name("kernel_conv_3d_f32_f32")]]
kernel void kernel_conv_3d<float>(
constant ggml_metal_kargs_conv_3d & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]);
// Explicit instantiation for f16 weights
template [[host_name("kernel_conv_3d_f16_f32")]]
kernel void kernel_conv_3d<half>(
constant ggml_metal_kargs_conv_3d & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]);
static inline float bicubic_weight1(float x) {
const float a = -0.75f;
return ((a + 2) * x - (a + 3)) * x * x + 1;
@ -6177,6 +6285,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
template [[host_name("kernel_flash_attn_ext_f32_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 320, 256>;
template [[host_name("kernel_flash_attn_ext_f32_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 512, 512>;
template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
@ -6192,6 +6301,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
template [[host_name("kernel_flash_attn_ext_f16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 320, 256>;
template [[host_name("kernel_flash_attn_ext_f16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 512, 512>;
template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
#if defined(GGML_METAL_HAS_BF16)
@ -6208,6 +6318,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
template [[host_name("kernel_flash_attn_ext_bf16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 320, 256>;
template [[host_name("kernel_flash_attn_ext_bf16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 512, 512>;
template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
#endif
@ -6224,6 +6335,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 320, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 512, 512>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
@ -6239,6 +6351,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 320, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 512, 512>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
@ -6254,6 +6367,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 320, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 512, 512>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
@ -6269,6 +6383,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 320, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 512, 512>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
@ -6284,6 +6399,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 320, 256>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 512, 512>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
#undef FA_TYPES
@ -6865,6 +6981,17 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk320_dv256")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 320, 256, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 320, 256, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 512, 512, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 512, 512, 1>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 512, 512, 1>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 512, 512, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 512, 512, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 512, 512, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 512, 512, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 512, 512, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
#if defined(GGML_METAL_HAS_BF16)

View File

@ -48,12 +48,11 @@ if (MUSAToolkit_FOUND)
list(APPEND GGML_SOURCES_MUSA ${SRCS})
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
else()
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
list(APPEND GGML_SOURCES_MUSA
../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu
../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu
../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu)
endif()
set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX)

View File

@ -89,6 +89,7 @@ set(GGML_OPENCL_KERNELS
mul_mv_q4_1_f32
mul_mv_q4_1_f32_flat
mul_mv_q4_k_f32
mul_mv_q4_k_f32_flat
mul_mv_q6_k_f32
mul_mv_q6_k_f32_flat
mul_mv_q8_0_f32
@ -107,11 +108,14 @@ set(GGML_OPENCL_KERNELS
mul_mm_q4_0_f32_l4_lm
mul_mm_q4_1_f32_l4_lm
mul_mm_q8_0_f32_l4_lm
mul_mm_q4_k_f32_l4_lm
mul_mm_q6_k_f32_l4_lm
mul_mm_q8_0_f32_8x4
gemv_noshuffle_q4_1_f32
gemm_noshuffle_q4_1_f32
gemv_noshuffle_general_q8_0_f32
gemv_noshuffle_q6_k_f32
gemm_noshuffle_q6_k_f32
mul
neg
norm

View File

@ -529,16 +529,19 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1;
cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans;
cl_kernel kernel_convert_block_q6_K_noshuffle, kernel_restore_block_q6_K_noshuffle;
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
cl_kernel kernel_convert_block_q4_0_noshuffle;
cl_kernel kernel_restore_block_q4_0_noshuffle;
cl_kernel kernel_convert_block_q4_1_noshuffle;
cl_kernel kernel_restore_block_q4_1_noshuffle;
cl_kernel kernel_convert_block_q4_K, kernel_restore_block_q4_K;
cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K;
cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
cl_kernel kernel_mul_mv_q4_1_f32;
cl_kernel kernel_mul_mv_q4_1_f32_flat;
cl_kernel kernel_mul_mv_q4_K_f32;
cl_kernel kernel_mul_mv_q4_K_f32_flat;
cl_kernel kernel_mul_mv_q6_K_f32;
cl_kernel kernel_mul_mv_q6_K_f32_flat;
cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat;
@ -578,6 +581,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_mul_mm_q4_0_f32_l4_lm;
cl_kernel kernel_mul_mm_q4_1_f32_l4_lm;
cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;
cl_kernel kernel_mul_mm_q4_k_f32_l4_lm;
cl_kernel kernel_mul_mm_q6_k_f32_l4_lm;
std::vector<ProfilingInfo> profiling_info;
@ -713,6 +717,8 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_gemm_noshuffle_q4_1_f32;
cl_kernel kernel_mul_mm_q8_0_f32_8x4;
cl_kernel CL_mul_mat_vec_q8_0_f32;
cl_kernel kernel_gemv_noshuffle_q6_K_f32;
cl_kernel kernel_gemm_noshuffle_q6_K_f32;
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
void free() {
@ -917,8 +923,12 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q8_0_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0_trans", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K_noshuffle", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K_noshuffle", &err), err));
GGML_LOG_CONT(".");
}
@ -1209,6 +1219,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// mul_mv_q4_k_f32_flat
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mv_q4_k_f32_flat.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mv_q4_k_f32_flat.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mul_mv_q4_K_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q4_K_f32_flat", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// mul_mv_q6_k_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@ -1482,6 +1509,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// mul_mm_q4_k_f32_l4_lm
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mm_q4_k_f32_l4_lm.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mm_q4_k_f32_l4_lm.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mul_mm_q4_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_k_f32_l4_lm", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// mul_mm_q6_k_f32_l4_lm
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@ -2603,6 +2647,45 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemm_moe_mxfp4_f32, "kernel_gemm_moe_mxfp4_f32", &err), err));
GGML_LOG_CONT(".");
}
// gemv_noshuffle_q6_k_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "gemv_noshuffle_q6_k_f32.cl.h"
};
#else
const std::string kernel_src = read_file("gemv_noshuffle_q6_k_f32.cl");
#endif
std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std +
" -cl-mad-enable ";
if (backend_ctx->has_vector_subgroup_broadcast) {
CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT ";
}
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts);
CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q6_K_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q6_K_f32", &err), err));
GGML_LOG_CONT(".");
}
// gemm_noshuffle_q6_k_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "gemm_noshuffle_q6_k_f32.cl.h"
};
#else
const std::string kernel_src = read_file("gemm_noshuffle_q6_k_f32.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q6_K_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q6_K_f32", &err), err));
GGML_LOG_CONT(".");
}
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
GGML_LOG_CONT("\n");
}
@ -3347,6 +3430,40 @@ struct ggml_tensor_extra_cl_q8_0 {
}
};
struct ggml_tensor_extra_cl_q4_K {
// Quantized values
cl_mem q = nullptr;
// Scales for each super block.
cl_mem s = nullptr;
// Scales
cl_mem d = nullptr;
// Min
cl_mem dm = nullptr;
~ggml_tensor_extra_cl_q4_K() {
reset();
}
void reset() {
if (q != nullptr) {
CL_CHECK(clReleaseMemObject(q));
q = nullptr;
}
if (s != nullptr) {
CL_CHECK(clReleaseMemObject(s));
s = nullptr;
}
if (d != nullptr) {
CL_CHECK(clReleaseMemObject(d));
d = nullptr;
}
if (dm != nullptr) {
CL_CHECK(clReleaseMemObject(dm));
dm = nullptr;
}
}
};
struct ggml_tensor_extra_cl_q6_K {
// Lower 4 bits of quantized weights.
cl_mem ql = nullptr;
@ -3956,6 +4073,12 @@ struct ggml_backend_opencl_buffer_context {
for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) {
delete e;
}
for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K) {
delete e;
}
for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K_in_use) {
delete e;
}
for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K) {
delete e;
}
@ -4039,6 +4162,21 @@ struct ggml_backend_opencl_buffer_context {
return extra;
}
ggml_tensor_extra_cl_q4_K * ggml_opencl_alloc_temp_tensor_extra_q4_K() {
ggml_tensor_extra_cl_q4_K * extra;
if (temp_tensor_extras_q4_K.empty()) {
extra = new ggml_tensor_extra_cl_q4_K();
} else {
extra = temp_tensor_extras_q4_K.back();
temp_tensor_extras_q4_K.pop_back();
}
temp_tensor_extras_q4_K_in_use.push_back(extra);
extra->reset();
return extra;
}
ggml_tensor_extra_cl_q6_K * ggml_opencl_alloc_temp_tensor_extra_q6_K() {
ggml_tensor_extra_cl_q6_K * extra;
if (temp_tensor_extras_q6_K.empty()) {
@ -4080,6 +4218,11 @@ struct ggml_backend_opencl_buffer_context {
}
temp_tensor_extras_q8_0_in_use.clear();
for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K_in_use) {
temp_tensor_extras_q4_K.push_back(e);
}
temp_tensor_extras_q4_K_in_use.clear();
for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) {
temp_tensor_extras_q6_K.push_back(e);
}
@ -4101,6 +4244,8 @@ struct ggml_backend_opencl_buffer_context {
std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4_in_use;
std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0;
std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0_in_use;
std::vector<ggml_tensor_extra_cl_q4_K *> temp_tensor_extras_q4_K;
std::vector<ggml_tensor_extra_cl_q4_K *> temp_tensor_extras_q4_K_in_use;
std::vector<ggml_tensor_extra_cl_q6_K *> temp_tensor_extras_q6_K;
std::vector<ggml_tensor_extra_cl_q6_K *> temp_tensor_extras_q6_K_in_use;
@ -4835,6 +4980,83 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
return;
}
if (tensor->type == GGML_TYPE_Q4_K) {
ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
// Allocate the new extra and create aliases from the original.
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
ggml_tensor_extra_cl_q4_K * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_K();
size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
size_t size_dm = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(3 * ggml_blck_size(tensor->type) / 64);
size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;
GGML_ASSERT(size_d + size_dm + size_s + size_q == ggml_nbytes(tensor) && "Incorrect tensor size");
cl_int err;
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
ggml_nbytes(tensor), NULL, &err);
CL_CHECK(err);
CL_CHECK(clEnqueueWriteBuffer(
queue, data_device, CL_TRUE, 0,
ggml_nbytes(tensor), data, 0, NULL, NULL));
cl_buffer_region region;
// Create subbuffer for d.
region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);
region.size = size_d;
extra->d = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
auto previous_origin = region.origin;
// Create subbuffer for mins.
region.origin = align_to(previous_origin + size_d, backend_ctx->alignment);
region.size = size_dm;
extra->dm = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
previous_origin = region.origin;
// Create subbuffer for s.
region.origin = align_to(previous_origin + size_dm, backend_ctx->alignment);
region.size = size_s;
extra->s = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
previous_origin = region.origin;
// Create subbuffer for quants.
region.origin = align_to(previous_origin + size_s, backend_ctx->alignment);
region.size = size_q;
extra->q = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm));
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {64, 1, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clReleaseMemObject(data_device));
tensor->extra = extra;
return;
}
if (tensor->type == GGML_TYPE_Q6_K) {
ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
@ -4851,61 +5073,58 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
"Incorrect tensor size");
cl_int err;
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
ggml_nbytes(tensor), NULL, &err);
CL_CHECK(err);
CL_CHECK(clEnqueueWriteBuffer(
queue, data_device, CL_TRUE, 0,
ggml_nbytes(tensor), data, 0, NULL, NULL));
cl_mem data_device;
CL_CHECK((data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, ggml_nbytes(tensor), NULL, &err), err));
CL_CHECK(clEnqueueWriteBuffer(queue, data_device, CL_TRUE, 0, ggml_nbytes(tensor), data, 0, NULL, NULL));
cl_buffer_region region;
// Subbuffer for ql
region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);
region.size = size_ql;
extra->ql = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
CL_CHECK((extra->ql = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
auto previous_origin = region.origin;
// Subbuffer for qh
region.origin = align_to(previous_origin + size_ql, backend_ctx->alignment);
region.size = size_qh;
extra->qh = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
previous_origin = region.origin;
// Subbuffer for scales
region.origin = align_to(previous_origin + size_qh, backend_ctx->alignment);
region.size = size_s;
extra->s = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
CL_CHECK((extra->s = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
previous_origin = region.origin;
// Create subbuffer for d.
region.origin = align_to(previous_origin + size_s, backend_ctx->alignment);
region.size = size_d;
extra->d = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
CL_CHECK((extra->d = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
previous_origin = region.origin;
// Flatten the weights
cl_kernel kernel = backend_ctx->kernel_convert_block_q6_K;
cl_kernel kernel;
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
kernel = backend_ctx->kernel_convert_block_q6_K;
if (use_adreno_kernels(backend_ctx, tensor)) {
kernel = backend_ctx->kernel_convert_block_q6_K_noshuffle;
}
#else
kernel = backend_ctx->kernel_convert_block_q6_K;
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d));
cl_uchar mask = 0xff;
cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type);
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &n_blk));
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64)*64, 1, 1};
size_t local_work_size[] = {64, 1, 1};
cl_event evt;
@ -4919,6 +5138,29 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
extra->size_d = size_d;
tensor->extra = extra;
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
if (use_adreno_kernels(backend_ctx, tensor)) {
cl_int M = tensor->ne[1]; // ne01
cl_int K = tensor->ne[0]; // ne00
// Transpose ql as ushort
transpose_2d_as_16b(backend_ctx,
extra->ql, extra->ql, size_ql, K/4, M);
// Transpose qh as uchar
transpose_2d_as_8b(backend_ctx,
extra->qh, extra->qh, size_qh, K/4, M);
// Transpose s as ushort
transpose_2d_as_16b(backend_ctx,
extra->s, extra->s, size_s, K/16/2, M);
// Transpose d as ushort
transpose_2d_as_16b(backend_ctx,
extra->d, extra->d, size_d, K/256, M);
}
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
return;
}
#endif // GGML_OPENCL_SOA_Q
@ -5245,24 +5487,111 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
CL_CHECK(clReleaseMemObject(data_device));
return;
}
if (tensor->type == GGML_TYPE_Q6_K) {
ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra;
if (tensor->type == GGML_TYPE_Q4_K) {
ggml_tensor_extra_cl_q4_K * extra = (ggml_tensor_extra_cl_q4_K *)tensor->extra;
cl_int err;
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
ggml_nbytes(tensor), NULL, &err);
CL_CHECK(err);
cl_kernel kernel = backend_ctx->kernel_restore_block_q6_K;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d));
cl_kernel kernel = backend_ctx->kernel_restore_block_q4_K;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->s));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->dm));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device));
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {1, 1, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clEnqueueReadBuffer(
queue, data_device, CL_TRUE, offset,
size, data, 0, NULL, NULL));
CL_CHECK(clReleaseMemObject(data_device));
return;
}
if (tensor->type == GGML_TYPE_Q6_K) {
ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra;
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
if (use_adreno_kernels(backend_ctx, tensor)) {
static ggml_cl_buffer buf_trans_ql;
static ggml_cl_buffer buf_trans_qh;
static ggml_cl_buffer buf_trans_s;
static ggml_cl_buffer buf_trans_d;
static ggml_cl_buffer buf_unpacked;
cl_int M = tensor->ne[1]; // ne01
cl_int K = tensor->ne[0]; // ne00
GGML_ASSERT(K % ggml_blck_size(tensor->type) == 0);
size_t size_ql = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;
size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/4;
size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/16;
size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
GGML_ASSERT(size_ql + size_qh + size_s + size_d == ggml_nbytes(tensor) && "Incorrect tensor size");
buf_trans_ql.allocate(backend_ctx->context, size_ql);
buf_trans_qh.allocate(backend_ctx->context, size_qh);
buf_trans_s.allocate(backend_ctx->context, size_s);
buf_trans_d.allocate(backend_ctx->context, size_d);
buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor));
// transpose ql, qh, s and d back
transpose_2d_as_16b(backend_ctx, extra->ql, buf_trans_ql.buffer, size_ql, M, K/4);
transpose_2d_as_8b(backend_ctx, extra->qh, buf_trans_qh.buffer, size_qh, M, K/4);
transpose_2d_as_16b(backend_ctx, extra->s, buf_trans_s.buffer, size_s, M, K/16/2);
transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/256);
// unpack
cl_uchar mask = 0xFF;
cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type);
cl_kernel kernel = backend_ctx->kernel_restore_block_q6_K_noshuffle;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_ql.buffer));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_qh.buffer));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_s.buffer));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_trans_d.buffer));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &buf_unpacked.buffer));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &n_blk));
size_t global_work_size[] = {(size_t)n_blk, 1, 1};
size_t local_work_size[] = {1, 1, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL));
return;
}
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
cl_int err;
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
ggml_nbytes(tensor), NULL, &err);
CL_CHECK(err);
cl_uchar mask = 0xFF;
cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type);
cl_kernel kernel = backend_ctx->kernel_restore_block_q6_K;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &n_blk));
size_t global_work_size[] = {(size_t)n_blk, 1, 1};
size_t local_work_size[] = {1, 1, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
global_work_size, local_work_size, 0, NULL, &evt));
@ -5553,6 +5882,8 @@ typedef struct {
static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2,
"wrong q4_0 block size/padding");
#define QK_MXFP4 32
#include <math.h>
#ifdef __cplusplus
#include "half.hpp"
@ -5596,7 +5927,7 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso
buf_d = malloc(size_e);
CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL));
CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_e, buf_d, 0, NULL, NULL));
CL_CHECK(clEnqueueReadBuffer(queue, extra->e, CL_TRUE, 0, size_e, buf_d, 0, NULL, NULL));
CL_CHECK(clFinish(queue));
} else {
// Read out the tensor from GPU memory.
@ -9331,6 +9662,196 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t
#endif
}
static void ggml_cl_mul_mat_q6_K_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(src1);
GGML_ASSERT(src1->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne1 = dst->ne[1];
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
cl_context context = backend_ctx->context;
cl_kernel kernel;
cl_int err;
cl_buffer_region region;
cl_image_format img_fmt;
cl_image_desc img_desc;
// subbuffer and image for activation
if (ne1 == 1) {
cl_mem ql_img = nullptr;
cl_mem qh_img = nullptr;
cl_mem b_sub_buffer = nullptr;
cl_mem b_img = nullptr;
// image for ql
img_fmt.image_channel_order = CL_R;
img_fmt.image_channel_data_type = CL_FLOAT;
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = ne01 * ne00 / 8;
img_desc.buffer = extra0_q6_K->ql;
CL_CHECK((ql_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
// image for qh
img_fmt.image_channel_order = CL_R;
img_fmt.image_channel_data_type = CL_HALF_FLOAT;
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = ne01 * ne00 / 8;
img_desc.buffer = extra0_q6_K->qh;
CL_CHECK((qh_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
region.origin = offset1;
region.size = ne00 * ne1 * sizeof(float);
CL_CHECK((b_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
img_fmt.image_channel_order = CL_RGBA;
img_fmt.image_channel_data_type = CL_FLOAT;
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = ne00 * ne1 / 4;
img_desc.buffer = b_sub_buffer;
CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
kernel = backend_ctx->kernel_gemv_noshuffle_q6_K_f32;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &ql_img));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qh_img));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne01));
size_t local_work_size[3] = {64, 4, 1};
size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
CL_CHECK(clReleaseMemObject(ql_img));
CL_CHECK(clReleaseMemObject(qh_img));
CL_CHECK(clReleaseMemObject(b_sub_buffer));
CL_CHECK(clReleaseMemObject(b_img));
} else {
cl_mem b_sub_buf;
cl_mem b_buf_trans;
cl_mem b_img;
cl_mem b_img_trans;
// subbuffer for activation
region.origin = offset1;
region.size = ne00 * ne1 * sizeof(float);
CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
// image for activation
img_fmt.image_channel_order = CL_RGBA;
img_fmt.image_channel_data_type = CL_FLOAT;
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = ne00 * ne1 / 4;
img_desc.buffer = b_sub_buf;
CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
// pad N to multiple of 8
int extra_elements = ne1 % 8;
int padding = 0;
if (extra_elements > 0){
padding = 8 - extra_elements;
}
// subbuffer for transposed activation
region.origin = 0;
region.size = ne00 * (ne1 + padding) * sizeof(float)/2;
backend_ctx->prealloc_act_trans.allocate(context, region.size);
CL_CHECK((b_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
// image for transposed activation
img_fmt.image_channel_order = CL_RGBA;
img_fmt.image_channel_data_type = CL_HALF_FLOAT;
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = ne00 * (ne1 + padding) / 4;
img_desc.buffer = b_buf_trans;
CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err));
// transpose activation
int height_B = ne1/4;
if (height_B == 0) {
height_B = 1;
}
int width_B = ne00/4;
int padded_height_B = (ne1 + padding) / 4;
kernel = backend_ctx->kernel_transpose_32_16;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B));
size_t local_size_t[2] = { 1, 16 };
size_t global_size_t[2] = { (size_t)width_B, (size_t)padded_height_B };
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_size_t, local_size_t, dst);
// gemm
kernel = backend_ctx->kernel_gemm_noshuffle_q6_K_f32;
int padded_N = ne1 + padding;
cl_ushort mask_f000 = 0xF000;
cl_uchar mask_c0 = 0xC0;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img_trans));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &padded_N));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ushort),&mask_f000));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_uchar), &mask_c0));
size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1};
size_t local_work_size[3] = {2, 128, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
CL_CHECK(clReleaseMemObject(b_sub_buf));
CL_CHECK(clReleaseMemObject(b_img));
CL_CHECK(clReleaseMemObject(b_buf_trans));
CL_CHECK(clReleaseMemObject(b_img_trans));
}
#else
GGML_UNUSED(backend);
GGML_UNUSED(src0);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
#endif
}
static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
@ -9357,6 +9878,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra;
ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra;
ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra;
ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)src0->extra;
ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra;
#endif
@ -9466,6 +9988,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
return;
}
// q6_K x fp32
if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32) {
ggml_cl_mul_mat_q6_K_f32_adreno(backend, src0, src1, dst);
return;
}
// q4_0 x fp32
if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) {
// TODO: remove duplicate definitions of image description + format -- move to top
@ -10005,6 +10533,50 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
return;
}
case GGML_TYPE_Q4_K: {
if (ne11 < 32) {
break;
}
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
break;
}
kernel = backend_ctx->kernel_mul_mm_q4_k_f32_l4_lm;
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
int batch_stride_a = ne00*ne01;
int batch_stride_b = ne10*ne11;
int batch_stride_d = ne0*ne1;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_K->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_K->s));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_K->d));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_K->dm));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_a
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_b
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne01)); // stride_d
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_a));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_b));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_d));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3));
// 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
size_t local_work_size[] = {(size_t)nth0, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
return;
}
case GGML_TYPE_Q6_K: {
if (ne11 < 32) {
break;
@ -10449,6 +11021,43 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K: {
#ifdef GGML_OPENCL_SOA_Q
kernel = backend_ctx->kernel_mul_mv_q4_K_f32_flat;
if (backend_ctx->gpu_family == INTEL) {
nth0 = 16;
nth1 = 1;
ndst = 4;
} else if (backend_ctx->gpu_family == ADRENO) {
nth0 = 64;
nth1 = 2;
ndst = 16;
} else {
GGML_ASSERT(false && "TODO: Unknown GPU");
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_K->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_K->s));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_K->d));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_K->dm));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &offset1));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3));
#else
kernel = backend_ctx->kernel_mul_mv_q4_K_f32;
if (backend_ctx->gpu_family == INTEL) {
@ -10482,6 +11091,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
#endif // GGML_OPENCL_SOA_Q
break;
}
case GGML_TYPE_Q5_K:

View File

@ -28,6 +28,7 @@
#define QK8_0 32
#define QR8_0 1
#define QK_K 256
#define K_SCALE_SIZE (3 * QK_K / 64)
#define K_QUANTS_PER_ITERATION 2
typedef char int8_t;
@ -55,6 +56,16 @@ struct block_q4_1 {
uchar qs[QK4_1 / 2]; // nibbles / quants
};
//------------------------------------------------------------------------------
// block_q4_k
//------------------------------------------------------------------------------
struct block_q4_K {
half d; // delta
half dm; // min
uchar s[K_SCALE_SIZE];
uchar q[QK_K / 2]; // nibbles / quants
};
//------------------------------------------------------------------------------
// block_q6_K
//------------------------------------------------------------------------------
@ -408,6 +419,62 @@ kernel void kernel_restore_block_q8_0_trans(
}
}
//------------------------------------------------------------------------------
// kernel_convert_block_q4_K
// Convert the block_q4_K format to 4 separate arrays (AOS -> SOA).
// This kernel does not deshuffle the bits.
// Each thread processes a super block.
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q4_K(
global struct block_q4_K * src0,
global uchar * dst_q,
global uchar * dst_s,
global half * dst_d,
global half * dst_dm
) {
global struct block_q4_K * b = (global struct block_q4_K *) src0 + get_global_id(0);
global uchar * q = (global uchar *) dst_q + QK_K/2*get_global_id(0);
global uchar * s = (global uchar *) dst_s + K_SCALE_SIZE*get_global_id(0);
global half * d = (global half *) dst_d + get_global_id(0);
global half * dm = (global half *) dst_dm + get_global_id(0);
*d = b->d;
*dm = b->dm;
for (int i = 0; i < QK_K/2; ++i) {
q[i] = b->q[i];
}
for (int i = 0; i < K_SCALE_SIZE; ++i) {
s[i] = b->s[i];
}
}
// Restore block_q4_K from flattened arrays.
// Each thread processes a super block.
kernel void kernel_restore_block_q4_K(
global uchar * src_q,
global uchar * src_s,
global half * src_d,
global half * src_dm,
global struct block_q4_K * dst
) {
global struct block_q4_K * b = (global struct block_q4_K *) dst + get_global_id(0);
global uchar * q = (global uchar *) src_q + QK_K/2*get_global_id(0);
global uchar * s = (global uchar *) src_s + K_SCALE_SIZE*get_global_id(0);
global half * d = (global half *) src_d + get_global_id(0);
global half * dm = (global half *) src_dm + get_global_id(0);
b->d = *d;
b->dm = *dm;
for (int i = 0; i < QK_K/2; ++i) {
b->q[i] = q[i];
}
for (int i = 0; i < K_SCALE_SIZE; ++i) {
b->s[i] = s[i];
}
}
//------------------------------------------------------------------------------
// kernel_convert_block_q6_K
// Convert the block_q6_K format to 3 separate arrays (AOS -> SOA).
@ -419,8 +486,13 @@ kernel void kernel_convert_block_q6_K(
global uchar * dst_ql,
global uchar * dst_qh,
global char * dst_s,
global half * dst_d
global half * dst_d,
uchar mask_lsb_8,
ulong n_blk
) {
if (get_global_id(0) >= n_blk) {
return;
}
global struct block_q6_K * b = (global struct block_q6_K *) src0 + get_global_id(0);
global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0);
global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0);
@ -447,8 +519,13 @@ kernel void kernel_restore_block_q6_K(
global uchar * dst_qh,
global char * dst_s,
global half * dst_d,
global struct block_q6_K * dst
global struct block_q6_K * dst,
uchar mask_lsb_8,
ulong n_blk
) {
if (get_global_id(0) >= n_blk) {
return;
}
global struct block_q6_K * b = (global struct block_q6_K *) dst + get_global_id(0);
global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0);
global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0);
@ -467,3 +544,117 @@ kernel void kernel_restore_block_q6_K(
b->scales[i] = s[i];
}
}
kernel void kernel_convert_block_q6_K_noshuffle(
global struct block_q6_K * src0,
global uchar * dst_ql,
global uchar * dst_qh,
global char * dst_s,
global half * dst_d,
uchar mask_lsb_8,
ulong n_blk
) {
if (get_global_id(0) >= n_blk) {
return;
}
global struct block_q6_K * b = (global struct block_q6_K *) src0 + get_global_id(0);
global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0);
global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0);
global char * s = (global char *) dst_s + QK_K/16*get_global_id(0);
global half * d = (global half *) dst_d + get_global_id(0);
*d = b->d;
for (int i = 0; i < QK_K/2/4; ++i) {
uchar x0 = b->ql[i*2 + 0] & mask_lsb_8;
uchar x1 = b->ql[i*2 + 1] & mask_lsb_8;
ql[i + 0] = (x0 & 0x0F) | ((x1 & 0x0F) << 4);
ql[i + 32] = ((x0 & 0xF0) >> 4) | (x1 & 0xF0);
uchar x2 = b->ql[i*2 + 0 + 64] & mask_lsb_8;
uchar x3 = b->ql[i*2 + 1 + 64] & mask_lsb_8;
ql[i + 64] = (x2 & 0x0F) | ((x3 & 0x0F) << 4);
ql[i + 96] = ((x2 & 0xF0) >> 4) | (x3 & 0xF0);
}
for (int i = 0; i < QK_K/4/8; ++i) {
uchar x0 = b->qh[i*4 + 0] & mask_lsb_8;
uchar x1 = b->qh[i*4 + 1] & mask_lsb_8;
uchar x2 = b->qh[i*4 + 2] & mask_lsb_8;
uchar x3 = b->qh[i*4 + 3] & mask_lsb_8;
qh[i + 0] = (x0 & 0x03) | ((x1 & 0x03) << 2) | ((x2 & 0x03) << 4) | ((x3 & 0x03) << 6);
qh[i + 8] = ((x0 & 0x0C) >> 2) | (x1 & 0x0C) | ((x2 & 0x0C) << 2) | ((x3 & 0x0C) << 4);
qh[i + 16] = ((x0 & 0x30) >> 4) | ((x1 & 0x30) >> 2) | (x2 & 0x30) | ((x3 & 0x30) << 2);
qh[i + 24] = ((x0 & 0xC0) >> 6) | ((x1 & 0xC0) >> 4) | ((x2 & 0xC0) >> 2) | (x3 & 0xC0);
uchar x4 = b->qh[i*4 + 0 + 32] & mask_lsb_8;
uchar x5 = b->qh[i*4 + 1 + 32] & mask_lsb_8;
uchar x6 = b->qh[i*4 + 2 + 32] & mask_lsb_8;
uchar x7 = b->qh[i*4 + 3 + 32] & mask_lsb_8;
qh[i + 32] = (x4 & 0x03) | ((x5 & 0x03) << 2) | ((x6 & 0x03) << 4) | ((x7 & 0x03) << 6);
qh[i + 40] = ((x4 & 0x0C) >> 2) | (x5 & 0x0C) | ((x6 & 0x0C) << 2) | ((x7 & 0x0C) << 4);
qh[i + 48] = ((x4 & 0x30) >> 4) | ((x5 & 0x30) >> 2) | (x6 & 0x30) | ((x7 & 0x30) << 2);
qh[i + 56] = ((x4 & 0xC0) >> 6) | ((x5 & 0xC0) >> 4) | ((x6 & 0xC0) >> 2) | (x7 & 0xC0);
}
for (int i = 0; i < QK_K/16; ++i) {
s[i] = b->scales[i];
}
}
kernel void kernel_restore_block_q6_K_noshuffle(
global uchar * src_ql,
global uchar * src_qh,
global char * src_s,
global half * src_d,
global struct block_q6_K * dst,
uchar mask_lsb_8,
ulong n_blk
) {
if (get_global_id(0) >= n_blk) {
return;
}
global struct block_q6_K * b = (global struct block_q6_K *) dst + get_global_id(0);
global uchar * ql = (global uchar *) src_ql + QK_K/2*get_global_id(0);
global uchar * qh = (global uchar *) src_qh + QK_K/4*get_global_id(0);
global char * s = (global char *) src_s + QK_K/16*get_global_id(0);
global half * d = (global half *) src_d + get_global_id(0);
b->d = *d;
for (int i = 0; i < QK_K/2/4; ++i) {
uchar x0 = ql[i + 0] & mask_lsb_8;
uchar x1 = ql[i + 32] & mask_lsb_8;
b->ql[i*2 + 0] = (x0 & 0x0F) | ((x1 & 0x0F) << 4);
b->ql[i*2 + 1] = ((x0 & 0xF0) >> 4) | (x1 & 0xF0);
uchar x2 = ql[i + 64] & mask_lsb_8;
uchar x3 = ql[i + 96] & mask_lsb_8;
b->ql[i*2 + 0 + 64] = (x2 & 0x0F) | ((x3 & 0x0F) << 4);
b->ql[i*2 + 1 + 64] = ((x2 & 0xF0) >> 4) | (x3 & 0xF0);
}
for (int i = 0; i < QK_K/4/8; ++i) {
uchar x0 = qh[i + 0] & mask_lsb_8;
uchar x1 = qh[i + 8] & mask_lsb_8;
uchar x2 = qh[i + 16] & mask_lsb_8;
uchar x3 = qh[i + 24] & mask_lsb_8;
b->qh[i*4 + 0] = (x0 & 0x03) | ((x1 & 0x03) << 2) | ((x2 & 0x03) << 4) | ((x3 & 0x03) << 6);
b->qh[i*4 + 1] = ((x0 & 0x0C) >> 2) | (x1 & 0x0C) | ((x2 & 0x0C) << 2) | ((x3 & 0x0C) << 4);
b->qh[i*4 + 2] = ((x0 & 0x30) >> 4) | ((x1 & 0x30) >> 2) | (x2 & 0x30) | ((x3 & 0x30) << 2);
b->qh[i*4 + 3] = ((x0 & 0xC0) >> 6) | ((x1 & 0xC0) >> 4) | ((x2 & 0xC0) >> 2) | (x3 & 0xC0);
uchar x4 = qh[i + 0 + 32] & mask_lsb_8;
uchar x5 = qh[i + 8 + 32] & mask_lsb_8;
uchar x6 = qh[i + 16 + 32] & mask_lsb_8;
uchar x7 = qh[i + 24 + 32] & mask_lsb_8;
b->qh[i*4 + 0 + 32] = (x4 & 0x03) | ((x5 & 0x03) << 2) | ((x6 & 0x03) << 4) | ((x7 & 0x03) << 6);
b->qh[i*4 + 1 + 32] = ((x4 & 0x0C) >> 2) | (x5 & 0x0C) | ((x6 & 0x0C) << 2) | ((x7 & 0x0C) << 4);
b->qh[i*4 + 2 + 32] = ((x4 & 0x30) >> 4) | ((x5 & 0x30) >> 2) | (x6 & 0x30) | ((x7 & 0x30) << 2);
b->qh[i*4 + 3 + 32] = ((x4 & 0xC0) >> 6) | ((x5 & 0xC0) >> 4) | ((x6 & 0xC0) >> 2) | (x7 & 0xC0);
}
for (int i = 0; i < QK_K/16; ++i) {
b->scales[i] = s[i];
}
}

View File

@ -0,0 +1,140 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_128
#endif
kernel void kernel_gemm_noshuffle_q6_K_f32(
global const ushort * src0_ql,
global const uchar * src0_qh,
global const ushort * src0_s,
global const half * src0_d,
read_only image1d_buffer_t src1,
global float * dst,
ulong offsetd,
int m,
int n,
int k,
int n_no_padding,
ushort mask_f000,
uchar mask_c0
) {
dst = (global float *)( (global char *)dst + offsetd );
int m_4 = m >> 2;
int n_4 = n >> 2;
int gy = get_global_id(0); // n
int gx = get_global_id(1); // m
int gx_2 = gx << 2;
half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0;
half8 B;
half4 dequantized_weights;
global const ushort * ptr_ql = src0_ql + gx_2;
global const uchar * ptr_qh = src0_qh + gx_2;
global const ushort * ptr_s = src0_s + gx_2;
global const half * ptr_d = src0_d + gx_2;
for (int i = 0; i < k; i += 4) {
// load 4x elements (ushort) of ql on M, each ushort contains 4 weights
// 4x ushort correspons to 4 rows on M
ushort4 bits4 = vload4(0, ptr_ql + (i/4)*m); // ql packed in 4s in ushort
uchar4 bits2 = vload4(0, ptr_qh + (i/4)*m); // qh packed in 4s in uchar
// load 4 consecutive scales
char8 scale_s_8 = as_char8(vload4(0, ptr_s + (i/16/2)*m)); // 1 char scale every 16 elements, packed in 2s
char4 scale_s = ((i/16) % 2) == 0 ? scale_s_8.s0246 : scale_s_8.s1357; // transposed as ushort, 2 blocks
half4 scale_d = vload4(0, ptr_d + (i/256)*m); // 1 half scale every 256 elements
// j=0
// load 2x 4 elements of activations on N, corresponding to 8 rows on N
B.s0123 = read_imageh(src1, gy*2 + (i + 0)*n_4 + 0);
B.s4567 = read_imageh(src1, gy*2 + (i + 0)*n_4 + 1);
dequantized_weights.s0 = (convert_half((bits4.s0 & 0x000F) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0;
dequantized_weights.s1 = (convert_half((bits4.s1 & 0x000F) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s1;
dequantized_weights.s2 = (convert_half((bits4.s2 & 0x000F) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s2;
dequantized_weights.s3 = (convert_half((bits4.s3 & 0x000F) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s3;
c0 += B * dequantized_weights.s0;
c1 += B * dequantized_weights.s1;
c2 += B * dequantized_weights.s2;
c3 += B * dequantized_weights.s3;
// j=1
B.s0123 = read_imageh(src1, gy*2 + (i + 1)*n_4 + 0);
B.s4567 = read_imageh(src1, gy*2 + (i + 1)*n_4 + 1);
dequantized_weights.s0 = (convert_half((((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2))) - 32.f) * scale_s.s0 * scale_d.s0;
dequantized_weights.s1 = (convert_half((((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2))) - 32.f) * scale_s.s1 * scale_d.s1;
dequantized_weights.s2 = (convert_half((((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2))) - 32.f) * scale_s.s2 * scale_d.s2;
dequantized_weights.s3 = (convert_half((((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2))) - 32.f) * scale_s.s3 * scale_d.s3;
c0 += B * dequantized_weights.s0;
c1 += B * dequantized_weights.s1;
c2 += B * dequantized_weights.s2;
c3 += B * dequantized_weights.s3;
// j=2
B.s0123 = read_imageh(src1, gy*2 + (i + 2)*n_4 + 0);
B.s4567 = read_imageh(src1, gy*2 + (i + 2)*n_4 + 1);
dequantized_weights.s0 = (convert_half((((bits4.s0 & 0x0F00) >> 8) | (bits2.s0 & 0x30))) - 32.f) * scale_s.s0 * scale_d.s0;
dequantized_weights.s1 = (convert_half((((bits4.s1 & 0x0F00) >> 8) | (bits2.s1 & 0x30))) - 32.f) * scale_s.s1 * scale_d.s1;
dequantized_weights.s2 = (convert_half((((bits4.s2 & 0x0F00) >> 8) | (bits2.s2 & 0x30))) - 32.f) * scale_s.s2 * scale_d.s2;
dequantized_weights.s3 = (convert_half((((bits4.s3 & 0x0F00) >> 8) | (bits2.s3 & 0x30))) - 32.f) * scale_s.s3 * scale_d.s3;
c0 += B * dequantized_weights.s0;
c1 += B * dequantized_weights.s1;
c2 += B * dequantized_weights.s2;
c3 += B * dequantized_weights.s3;
// j=3
B.s0123 = read_imageh(src1, gy*2 + (i + 3)*n_4 + 0);
B.s4567 = read_imageh(src1, gy*2 + (i + 3)*n_4 + 1);
dequantized_weights.s0 = (convert_half((((bits4.s0 & mask_f000) >> 12) | ((bits2.s0 & mask_c0) >> 2))) - 32.f) * scale_s.s0 * scale_d.s0;
dequantized_weights.s1 = (convert_half((((bits4.s1 & mask_f000) >> 12) | ((bits2.s1 & mask_c0) >> 2))) - 32.f) * scale_s.s1 * scale_d.s1;
dequantized_weights.s2 = (convert_half((((bits4.s2 & mask_f000) >> 12) | ((bits2.s2 & mask_c0) >> 2))) - 32.f) * scale_s.s2 * scale_d.s2;
dequantized_weights.s3 = (convert_half((((bits4.s3 & mask_f000) >> 12) | ((bits2.s3 & mask_c0) >> 2))) - 32.f) * scale_s.s3 * scale_d.s3;
c0 += B * dequantized_weights.s0;
c1 += B * dequantized_weights.s1;
c2 += B * dequantized_weights.s2;
c3 += B * dequantized_weights.s3;
}
int idx = (gy<<3)*m + (gx<<2);
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
}
}

View File

@ -0,0 +1,293 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define NSUBGROUPS 4
#define SUBGROUP_SIZE 64
#define dequantize_block_acc_bcast_8_hi(total_sum, bits4, bits2, scale_d, scale_s, y) \
float8 shared_y; \
shared_y = sub_group_broadcast(y, 0); \
total_sum.s0 += ((float)(((bits4.s0 & 0x000F) ) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s0; \
total_sum.s0 += ((float)(((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s1; \
total_sum.s0 += ((float)(((bits4.s0 & 0x0F00) >> 8) | ((bits2.s0 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s2; \
total_sum.s0 += ((float)(((bits4.s0 & 0xF000) >> 12) | ((bits2.s0 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s3; \
total_sum.s0 += ((float)(((bits4.s2 & 0x000F) ) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s4; \
total_sum.s0 += ((float)(((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s5; \
total_sum.s0 += ((float)(((bits4.s2 & 0x0F00) >> 8) | ((bits2.s2 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s6; \
total_sum.s0 += ((float)(((bits4.s2 & 0xF000) >> 12) | ((bits2.s2 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s7; \
total_sum.s1 += ((float)(((bits4.s1 & 0x000F) ) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s0; \
total_sum.s1 += ((float)(((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s1; \
total_sum.s1 += ((float)(((bits4.s1 & 0x0F00) >> 8) | ((bits2.s1 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s2; \
total_sum.s1 += ((float)(((bits4.s1 & 0xF000) >> 12) | ((bits2.s1 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s3; \
total_sum.s1 += ((float)(((bits4.s3 & 0x000F) ) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s4; \
total_sum.s1 += ((float)(((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s5; \
total_sum.s1 += ((float)(((bits4.s3 & 0x0F00) >> 8) | ((bits2.s3 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s6; \
total_sum.s1 += ((float)(((bits4.s3 & 0xF000) >> 12) | ((bits2.s3 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s7; \
shared_y = sub_group_broadcast(y, 1); \
total_sum.s0 += ((float)(((bits4.s4 & 0x000F) ) | ((bits2.s4 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s0; \
total_sum.s0 += ((float)(((bits4.s4 & 0x00F0) >> 4) | ((bits2.s4 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s1; \
total_sum.s0 += ((float)(((bits4.s4 & 0x0F00) >> 8) | ((bits2.s4 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s2; \
total_sum.s0 += ((float)(((bits4.s4 & 0xF000) >> 12) | ((bits2.s4 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s3; \
total_sum.s0 += ((float)(((bits4.s6 & 0x000F) ) | ((bits2.s6 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s4; \
total_sum.s0 += ((float)(((bits4.s6 & 0x00F0) >> 4) | ((bits2.s6 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s5; \
total_sum.s0 += ((float)(((bits4.s6 & 0x0F00) >> 8) | ((bits2.s6 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s6; \
total_sum.s0 += ((float)(((bits4.s6 & 0xF000) >> 12) | ((bits2.s6 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s7; \
total_sum.s1 += ((float)(((bits4.s5 & 0x000F) ) | ((bits2.s5 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s0; \
total_sum.s1 += ((float)(((bits4.s5 & 0x00F0) >> 4) | ((bits2.s5 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s1; \
total_sum.s1 += ((float)(((bits4.s5 & 0x0F00) >> 8) | ((bits2.s5 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s2; \
total_sum.s1 += ((float)(((bits4.s5 & 0xF000) >> 12) | ((bits2.s5 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s3; \
total_sum.s1 += ((float)(((bits4.s7 & 0x000F) ) | ((bits2.s7 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s4; \
total_sum.s1 += ((float)(((bits4.s7 & 0x00F0) >> 4) | ((bits2.s7 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s5; \
total_sum.s1 += ((float)(((bits4.s7 & 0x0F00) >> 8) | ((bits2.s7 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s6; \
total_sum.s1 += ((float)(((bits4.s7 & 0xF000) >> 12) | ((bits2.s7 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s7; \
#define dequantize_block_acc_bcast_8_lo(total_sum, bits4, bits2, scale_d, scale_s, y) \
shared_y = sub_group_broadcast(y, 2); \
total_sum.s0 += ((float)(((bits4.s0 & 0x000F) ) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s0; \
total_sum.s0 += ((float)(((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s1; \
total_sum.s0 += ((float)(((bits4.s0 & 0x0F00) >> 8) | ((bits2.s0 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s2; \
total_sum.s0 += ((float)(((bits4.s0 & 0xF000) >> 12) | ((bits2.s0 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s3; \
total_sum.s0 += ((float)(((bits4.s2 & 0x000F) ) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s4; \
total_sum.s0 += ((float)(((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s5; \
total_sum.s0 += ((float)(((bits4.s2 & 0x0F00) >> 8) | ((bits2.s2 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s6; \
total_sum.s0 += ((float)(((bits4.s2 & 0xF000) >> 12) | ((bits2.s2 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s7; \
total_sum.s1 += ((float)(((bits4.s1 & 0x000F) ) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s0; \
total_sum.s1 += ((float)(((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s1; \
total_sum.s1 += ((float)(((bits4.s1 & 0x0F00) >> 8) | ((bits2.s1 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s2; \
total_sum.s1 += ((float)(((bits4.s1 & 0xF000) >> 12) | ((bits2.s1 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s3; \
total_sum.s1 += ((float)(((bits4.s3 & 0x000F) ) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s4; \
total_sum.s1 += ((float)(((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s5; \
total_sum.s1 += ((float)(((bits4.s3 & 0x0F00) >> 8) | ((bits2.s3 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s6; \
total_sum.s1 += ((float)(((bits4.s3 & 0xF000) >> 12) | ((bits2.s3 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s7; \
shared_y = sub_group_broadcast(y, 3); \
total_sum.s0 += ((float)(((bits4.s4 & 0x000F) ) | ((bits2.s4 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s0; \
total_sum.s0 += ((float)(((bits4.s4 & 0x00F0) >> 4) | ((bits2.s4 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s1; \
total_sum.s0 += ((float)(((bits4.s4 & 0x0F00) >> 8) | ((bits2.s4 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s2; \
total_sum.s0 += ((float)(((bits4.s4 & 0xF000) >> 12) | ((bits2.s4 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s3; \
total_sum.s0 += ((float)(((bits4.s6 & 0x000F) ) | ((bits2.s6 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s4; \
total_sum.s0 += ((float)(((bits4.s6 & 0x00F0) >> 4) | ((bits2.s6 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s5; \
total_sum.s0 += ((float)(((bits4.s6 & 0x0F00) >> 8) | ((bits2.s6 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s6; \
total_sum.s0 += ((float)(((bits4.s6 & 0xF000) >> 12) | ((bits2.s6 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s7; \
total_sum.s1 += ((float)(((bits4.s5 & 0x000F) ) | ((bits2.s5 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s0; \
total_sum.s1 += ((float)(((bits4.s5 & 0x00F0) >> 4) | ((bits2.s5 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s1; \
total_sum.s1 += ((float)(((bits4.s5 & 0x0F00) >> 8) | ((bits2.s5 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s2; \
total_sum.s1 += ((float)(((bits4.s5 & 0xF000) >> 12) | ((bits2.s5 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s3; \
total_sum.s1 += ((float)(((bits4.s7 & 0x000F) ) | ((bits2.s7 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s4; \
total_sum.s1 += ((float)(((bits4.s7 & 0x00F0) >> 4) | ((bits2.s7 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s5; \
total_sum.s1 += ((float)(((bits4.s7 & 0x0F00) >> 8) | ((bits2.s7 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s6; \
total_sum.s1 += ((float)(((bits4.s7 & 0xF000) >> 12) | ((bits2.s7 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s7; \
#define dequantize_block_acc_bcast_1_hi(total_sum, bits4, bits2, scale_d, scale_s, y) \
float shared_y; \
shared_y = sub_group_broadcast(y.s0, 0); \
total_sum.s0 += ((float)(((bits4.s0 & 0x000F) ) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s1 & 0x000F) ) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s1, 0); \
total_sum.s0 += ((float)(((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s2, 0); \
total_sum.s0 += ((float)(((bits4.s0 & 0x0F00) >> 8) | ((bits2.s0 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s1 & 0x0F00) >> 8) | ((bits2.s1 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s3, 0); \
total_sum.s0 += ((float)(((bits4.s0 & 0xF000) >> 12) | ((bits2.s0 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s1 & 0xF000) >> 12) | ((bits2.s1 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s4, 0); \
total_sum.s0 += ((float)(((bits4.s2 & 0x000F) ) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s3 & 0x000F) ) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s5, 0); \
total_sum.s0 += ((float)(((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s6, 0); \
total_sum.s0 += ((float)(((bits4.s2 & 0x0F00) >> 8) | ((bits2.s2 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s3 & 0x0F00) >> 8) | ((bits2.s3 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s7, 0); \
total_sum.s0 += ((float)(((bits4.s2 & 0xF000) >> 12) | ((bits2.s2 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s3 & 0xF000) >> 12) | ((bits2.s3 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s0, 1); \
total_sum.s0 += ((float)(((bits4.s4 & 0x000F) ) | ((bits2.s4 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s5 & 0x000F) ) | ((bits2.s5 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s1, 1); \
total_sum.s0 += ((float)(((bits4.s4 & 0x00F0) >> 4) | ((bits2.s4 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s5 & 0x00F0) >> 4) | ((bits2.s5 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s2, 1); \
total_sum.s0 += ((float)(((bits4.s4 & 0x0F00) >> 8) | ((bits2.s4 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s5 & 0x0F00) >> 8) | ((bits2.s5 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s3, 1); \
total_sum.s0 += ((float)(((bits4.s4 & 0xF000) >> 12) | ((bits2.s4 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s5 & 0xF000) >> 12) | ((bits2.s5 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s4, 1); \
total_sum.s0 += ((float)(((bits4.s6 & 0x000F) ) | ((bits2.s6 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s7 & 0x000F) ) | ((bits2.s7 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s5, 1); \
total_sum.s0 += ((float)(((bits4.s6 & 0x00F0) >> 4) | ((bits2.s6 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s7 & 0x00F0) >> 4) | ((bits2.s7 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s6, 1); \
total_sum.s0 += ((float)(((bits4.s6 & 0x0F00) >> 8) | ((bits2.s6 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s7 & 0x0F00) >> 8) | ((bits2.s7 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s7, 1); \
total_sum.s0 += ((float)(((bits4.s6 & 0xF000) >> 12) | ((bits2.s6 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s7 & 0xF000) >> 12) | ((bits2.s7 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \
#define dequantize_block_acc_bcast_1_lo(total_sum, bits4, bits2, scale_d, scale_s, y) \
shared_y = sub_group_broadcast(y.s0, 2); \
total_sum.s0 += ((float)(((bits4.s0 & 0x000F) ) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s1 & 0x000F) ) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s1, 2); \
total_sum.s0 += ((float)(((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s2, 2); \
total_sum.s0 += ((float)(((bits4.s0 & 0x0F00) >> 8) | ((bits2.s0 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s1 & 0x0F00) >> 8) | ((bits2.s1 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s3, 2); \
total_sum.s0 += ((float)(((bits4.s0 & 0xF000) >> 12) | ((bits2.s0 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s1 & 0xF000) >> 12) | ((bits2.s1 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s4, 2); \
total_sum.s0 += ((float)(((bits4.s2 & 0x000F) ) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s3 & 0x000F) ) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s5, 2); \
total_sum.s0 += ((float)(((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s6, 2); \
total_sum.s0 += ((float)(((bits4.s2 & 0x0F00) >> 8) | ((bits2.s2 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s3 & 0x0F00) >> 8) | ((bits2.s3 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s7, 2); \
total_sum.s0 += ((float)(((bits4.s2 & 0xF000) >> 12) | ((bits2.s2 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s3 & 0xF000) >> 12) | ((bits2.s3 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s0, 3); \
total_sum.s0 += ((float)(((bits4.s4 & 0x000F) ) | ((bits2.s4 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s5 & 0x000F) ) | ((bits2.s5 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s1, 3); \
total_sum.s0 += ((float)(((bits4.s4 & 0x00F0) >> 4) | ((bits2.s4 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s5 & 0x00F0) >> 4) | ((bits2.s5 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s2, 3); \
total_sum.s0 += ((float)(((bits4.s4 & 0x0F00) >> 8) | ((bits2.s4 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s5 & 0x0F00) >> 8) | ((bits2.s5 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s3, 3); \
total_sum.s0 += ((float)(((bits4.s4 & 0xF000) >> 12) | ((bits2.s4 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s5 & 0xF000) >> 12) | ((bits2.s5 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s4, 3); \
total_sum.s0 += ((float)(((bits4.s6 & 0x000F) ) | ((bits2.s6 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s7 & 0x000F) ) | ((bits2.s7 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s5, 3); \
total_sum.s0 += ((float)(((bits4.s6 & 0x00F0) >> 4) | ((bits2.s6 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s7 & 0x00F0) >> 4) | ((bits2.s7 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s6, 3); \
total_sum.s0 += ((float)(((bits4.s6 & 0x0F00) >> 8) | ((bits2.s6 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s7 & 0x0F00) >> 8) | ((bits2.s7 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s7, 3); \
total_sum.s0 += ((float)(((bits4.s6 & 0xF000) >> 12) | ((bits2.s6 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \
total_sum.s1 += ((float)(((bits4.s7 & 0xF000) >> 12) | ((bits2.s7 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \
#if defined(ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_gemv_noshuffle_q6_K_f32(
read_only image1d_buffer_t src0_ql,
read_only image1d_buffer_t src0_qh,
global half2 * src0_s,
global half2 * src0_d,
read_only image1d_buffer_t src1,
global float * dst,
ulong offsetd,
int ne00,
int ne01
) {
int grp = get_local_id(1);
int gid = get_global_id(0);
ushort slid = get_sub_group_local_id();
int nb = ne00 / 32;
uint4 reg_a_l;
ushort4 reg_a_h;
half2 reg_d;
char4 reg_s;
float8 reg_b;
float2 total_sum = 0.0f;
int line_stride_a = ne01 / 2;
int block_stride_a = NSUBGROUPS * ne01;
for (int k = grp; k < nb; k += NSUBGROUPS) {
reg_d = src0_d[gid + k/8 * line_stride_a];
reg_s = as_char4(src0_s[gid + k * line_stride_a]);
if (slid < 4) {
reg_b.s0123 = read_imagef(src1, 0 + slid*2 + k*8);
reg_b.s4567 = read_imagef(src1, 1 + slid*2 + k*8);
}
reg_a_l.s0 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*0).x;
reg_a_l.s1 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*1).x;
reg_a_l.s2 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*2).x;
reg_a_l.s3 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*3).x;
reg_a_h.s0 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*0).x);
reg_a_h.s1 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*1).x);
reg_a_h.s2 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*2).x);
reg_a_h.s3 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*3).x);
#ifdef VECTOR_SUB_GROUP_BROADCAT
dequantize_block_acc_bcast_8_hi(total_sum, as_ushort8(reg_a_l), as_uchar8(reg_a_h), reg_d, reg_s, reg_b);
#else
dequantize_block_acc_bcast_1_hi(total_sum, as_ushort8(reg_a_l), as_uchar8(reg_a_h), reg_d, reg_s, reg_b);
#endif // VECTOR_SUB_GROUP_BROADCAT
reg_a_l.s0 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*4).x;
reg_a_l.s1 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*5).x;
reg_a_l.s2 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*6).x;
reg_a_l.s3 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*7).x;
reg_a_h.s0 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*4).x);
reg_a_h.s1 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*5).x);
reg_a_h.s2 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*6).x);
reg_a_h.s3 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*7).x);
#ifdef VECTOR_SUB_GROUP_BROADCAT
dequantize_block_acc_bcast_8_lo(total_sum, as_ushort8(reg_a_l), as_uchar8(reg_a_h), reg_d, reg_s, reg_b);
#else
dequantize_block_acc_bcast_1_lo(total_sum, as_ushort8(reg_a_l), as_uchar8(reg_a_h), reg_d, reg_s, reg_b);
#endif // VECTOR_SUB_GROUP_BROADCAT
}
local float2 reduce_lm[SUBGROUP_SIZE * 3];
if (grp == 1) {
reduce_lm[SUBGROUP_SIZE*0 + slid] = total_sum;
}
if (grp == 2) {
reduce_lm[SUBGROUP_SIZE*1 + slid] = total_sum;
}
if (grp == 3) {
reduce_lm[SUBGROUP_SIZE*2 + slid] = total_sum;
}
barrier(CLK_LOCAL_MEM_FENCE);
if (grp == 0) {
total_sum += reduce_lm[SUBGROUP_SIZE*0 + slid];
}
if (grp == 0) {
total_sum += reduce_lm[SUBGROUP_SIZE*1 + slid];
}
if (grp == 0) {
total_sum += reduce_lm[SUBGROUP_SIZE*2 + slid];
}
if (grp == 0) {
dst = (global float*)((global char*)dst + offsetd);
vstore2(total_sum, 0, &(dst[gid * 2]));
}
}

View File

@ -0,0 +1,179 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define LOAD_VEC_A 4
#define LOAD_VEC_B 4
#define BM 64
#define BN 64
#define BK 32
#define TM 4
#define TN 8
kernel void kernel_mul_mm_q4_k_f32_l4_lm(
global uchar4 * src0_q,
global uchar * src0_s,
global half * src0_d,
global half * src0_dm,
global float4 * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne11,
int ne12,
int stride_a,
int stride_b,
int stride_d,
int batch_stride_a,
int batch_stride_b,
int batch_stride_d,
int r2,
int r3
) {
src1 = (global float4*)((global char*)src1 + offset1);
dst = (global float *)((global char*)dst + offsetd);
local float buf_a[BM * BK];
local float buf_b[BN * BK];
const int batch_idx = get_global_id(2);
const int i13 = batch_idx / ne12;
const int i12 = batch_idx % ne12;
const int i03 = i13 / r3;
const int i02 = i12 / r2;
const int batch_idx_a = i03 * ne02 + i02;
const int ir = get_group_id(0);
const int ic = get_group_id(1);
const int tid = get_local_id(0);
const int th_r = tid % (BM / TM);
const int th_c = tid / (BM / TM);
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
float sums[TM * TN];
float cache_a[TM];
float cache_b[TN];
for (int i = 0; i < TM * TN; i++) {
sums[i] = 0.0f;
}
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (ir*BM + loadc_a + l < ne01) {
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
int ib = idx / 64;
int iqs = (idx % 64) * 2;
int n = iqs / 32;
int b = (iqs % 32) / 16;
int is = 2 * n + b;
int qsi = n * 32 + (iqs % 16) * 2;
char * scales = src0_s + ib * 12;
int scidx0 = (is < 4) ? is : (is + 4);
int scidx1 = (is < 4) ? is : (is - 4);
int scidxmask1 = (is < 4) ? 0x30 : 0xC0;
int scidxshift1 = (is < 4) ? 0 : 2;
int mbidx0 = is + 4;
int mbidx1 = (is < 4) ? is + 4 : is;
int mbidxmask0 = (is < 4) ? 0xF : 0xF0;
int mbidxshift0 = (is < 4) ? 0 : 4;
int mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
int mbidxshift1 = (is < 4) ? 0 : 2;
uchar sc = (scales[scidx0] & 0xF) | ((scales[scidx1] & scidxmask1) >> scidxshift1);
uchar mbyte = ((scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((scales[mbidx1] & mbidxmask1) >> mbidxshift1);
float d = (float)src0_d[ib] * (float)sc;
float m = -(float)src0_dm[ib] * (float)mbyte;
global uchar4 * qs = src0_q + ib*32 + (qsi >> 2);
uchar4 q = *qs;
float4 v1 = (convert_float4((uchar4)((q.s0 >> (b * 4))&0x0F, (q.s1 >> (b * 4))&0x0F, (q.s2 >> (b * 4))&0x0F, (q.s3 >> (b * 4))&0x0F)))*d + m;
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v1.s0;
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v1.s1;
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v1.s2;
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v1.s3;
} else {
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f;
}
}
for (int l = 0; l < BN; l += loadstride_b) {
if (ic*BN + loadc_b + l < ne11) {
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
} else {
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
for (int i = 0; i < BK; i++) {
for (int j = 0; j < TM; j++) {
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
}
for (int j = 0; j < TN; j++) {
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
}
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
const int sums_idx = cc*TM + cr;
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
const int dr = ir * BM + th_r * TM;
const int dc = ic * BN + th_c * TN;
const int offsets = batch_idx * batch_stride_d;
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
if (dr + cr < ne01 && dc + cc < ne11) {
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
}
}
}
}

View File

@ -0,0 +1,196 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
//------------------------------------------------------------------------------
// block_q4_K
//------------------------------------------------------------------------------
#define QK_K 256
#define BLOCK_Q4K_SIZE 144
#define K_SCALE_SIZE 12
// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
typedef struct {
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uchar qs[QK_K/2]; // 4-bit quants
} block_q4_K;
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH
#ifdef INTEL_GPU
#define N_DST 4 // number of rows each SIMD group works on
#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 16 // SIMD group size
#elif defined (ADRENO_GPU)
#define N_DST 16
#define N_SIMDGROUP 2
#define N_SIMDWIDTH 64
#endif
#undef BLOCK_STRIDE
// number of (super) blocks each subgroup processes
// each thread in a subgroup processes a block (32 weights)
#define BLOCK_STRIDE (N_SIMDWIDTH/8)
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q4_K_f32_flat(
global uchar * src0_q,
global uchar * src0_s,
global half * src0_d,
global half * src0_dm,
global char * src1,
int offset1,
global char * dst,
int offsetd,
int ne00,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
int ne1,
int r2,
int r3
) {
src1 = src1 + offset1;
dst = dst + offsetd;
ushort kmask1 = 0x3f3f;
ushort kmask2 = 0x0f0f;
ushort kmask3 = 0xc0c0;
int ix = get_sub_group_local_id()/8;
int it = get_sub_group_local_id()%8;
int iq = it/4;
int ir = it%4;
int nb = ne00/QK_K;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
int i12 = im%ne12;
int i13 = im/ne12;
int offset_src0 = (first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03)/BLOCK_Q4K_SIZE;
uint blk = nb01 / BLOCK_Q4K_SIZE;
global uchar * blk_q = (global uchar *)src0_q + offset_src0*(QK_K/2);
global uchar * blk_s = (global uchar *)src0_s + offset_src0*K_SCALE_SIZE;
global half * blk_d = (global half *)src0_d + offset_src0;
global half * blk_dm = (global half *)src0_dm + offset_src0;
int offset_src1 = r1*nb11 + (i12)*nb12 + (i13)*nb13;
global float * y = (global float *)(src1 + offset_src1);
float yl[16];
float yh[16];
float sumf[N_DST] = {0.f};
float all_sum;
global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
ushort sc16[4];
uchar * sc8 = (uchar *)sc16;
for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
float4 sumy = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; ++i) {
yl[i+0] = y4[i+0];
sumy.s0 += yl[i+0];
yl[i+8] = y4[i+32];
sumy.s1 += yl[i+8];
yh[i+0] = y4[i+128];
sumy.s2 += yh[i+0];
yh[i+8] = y4[i+160];
sumy.s3 += yh[i+8];
}
global ushort * q1 = (global ushort *)(blk_q + ib * (QK_K/2)) + (16 * iq + 4 * ir);
global ushort * sc = (global ushort *)(blk_s + ib * K_SCALE_SIZE) + iq;
global half * d = blk_d + ib;
global half * dm = blk_dm + ib;
for (int row = 0; row < N_DST; row++) {
sc16[0] = sc[0] & kmask1;
sc16[1] = sc[2] & kmask1;
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
global ushort * q2 = q1 + 32;
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; i += 2) {
acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F);
acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00);
acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0);
acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000);
acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F);
acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00);
acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0);
acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000);
}
float dall = *d;
float dmin = *dm;
sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] +
(acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f +
(acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] +
(acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) -
dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]);
q1 += blk*64;
sc += blk*6;
d += blk;
dm += blk;
}
y4 += BLOCK_STRIDE * QK_K;
}
global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = sub_group_reduce_add(sumf[row]);
if (first_row + row < ne01) {
if (get_sub_group_local_id() == 0) {
dst_f32[first_row + row] = all_sum;
}
}
}
}

View File

@ -97,6 +97,8 @@ struct ggml_backend_openvino_buffer_context {
ov_buffer = std::make_shared<ov::intel_gpu::ocl::USMTensor>(std::move(usm_tensor));
} else {
data = ggml_aligned_malloc(size);
GGML_ASSERT(data);
memset(data, 0, size);
ov_buffer = std::make_shared<ov::Tensor>(ov::element::u8, ov::Shape{size}, data);
}

View File

@ -1162,12 +1162,18 @@ ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rp
return nullptr;
}
// Fix: Prevent division by zero if blck_size is 0 (e.g., deprecated types)
if (ggml_blck_size((enum ggml_type)tensor->type) == 0) {
GGML_LOG_ERROR("[%s] invalid tensor type received (blck_size is 0): %u\n", __func__, tensor->type);
return nullptr;
}
ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
// ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type
if (result == nullptr) {
GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\\n", __func__, tensor->type);
GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\n", __func__, tensor->type);
return nullptr;
}
@ -1437,7 +1443,9 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
const rpc_tensor * tensor = it_ptr->second;
struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
if (result == nullptr) {
if (result == nullptr || result->buffer == nullptr) {
GGML_LOG_ERROR("[%s] invalid tensor: null %s (id=%" PRIu64 ")\n",
__func__, result == nullptr ? "tensor" : "buffer", id);
return nullptr;
}
tensor_map[id] = result;

View File

@ -4667,22 +4667,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
if (a->ne[3] != b->ne[3]) {
return false;
}
ggml_type a_type = a->type;
if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||
a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S ||
a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M
) {
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
return false;
}
}
ggml_type src0_type = op->src[0]->type;
if (src0_type == GGML_TYPE_BF16 ) {
// TODO: support GGML_TYPE_BF16
// FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
return false;
}
// TODO: The configuration below needs more work to be supported with oneDNN
if (ggml_is_permuted(a) && !ggml_is_contiguous(a) &&

View File

@ -4604,12 +4604,42 @@ static void ggml_vk_load_shaders(vk_device& device) {
{"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"},
{"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"},
};
const bool use_subgroup_reduce = device->subgroup_arithmetic;
for (uint32_t si = 0; si < 3; si++) {
const uint32_t S_V = gdn_sizes[si];
GGML_ASSERT(is_pow2(S_V));
uint32_t lanes_per_column;
if (S_V >= 128u && device->subgroup_clustered) {
lanes_per_column = 8u;
} else {
// Use largest power-of-two that divides both S_V and subgroup_size so that
// (1) S_V % lanes_per_column == 0 and (2) S_V % (subgroup_size / lanes_per_column) == 0.
// This means we don't need extra bounds checking logic in the shader.
lanes_per_column = std::min(S_V, device->subgroup_size);
}
const bool need_clustered_shader = lanes_per_column != 1 && (lanes_per_column < device->subgroup_size);
size_t gdn_len;
const void * gdn_data;
if (use_subgroup_reduce && need_clustered_shader) {
gdn_len = gated_delta_net_f32_len;
gdn_data = (const void *)gated_delta_net_f32_data;
} else if (use_subgroup_reduce) {
gdn_len = gated_delta_net_f32_nocluster_len;
gdn_data = (const void *)gated_delta_net_f32_nocluster_data;
} else {
gdn_len = gated_delta_net_f32_shmem_len;
gdn_data = (const void *)gated_delta_net_f32_shmem_data;
}
const uint32_t cols_per_wg = device->subgroup_size / lanes_per_column;
const std::array<uint32_t, 3> wg_denoms = {1u, 1u, cols_per_wg};
for (uint32_t kda = 0; kda < 2; kda++) {
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data,
"main", 7, sizeof(vk_op_gated_delta_net_push_constants),
{1, 1, 1}, {gdn_sizes[si], kda}, 1);
gdn_names[si][kda], gdn_len, gdn_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_reduce, device->subgroup_size);
}
}
}
@ -10438,7 +10468,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
pc, { H, n_seqs, 1u });
pc, { H, n_seqs, S_v });
}
static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
@ -16018,6 +16048,7 @@ static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev)
case 0xE20C: // B570
return 18;
case 0xE20B: // B580
case 0xE211: // Pro B60
return 20;
default:
return 0;

View File

@ -1,11 +1,25 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_basic : enable
#if USE_SUBGROUP_CLUSTERED
#extension GL_KHR_shader_subgroup_clustered : enable
#endif
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif
// Caller guarantees valid spec constants: S_V % COLS_PER_WG == 0 and S_V % LANES_PER_COLUMN == 0,
// so no bounds checking is needed.
layout(constant_id = 0) const uint S_V = 128;
layout(constant_id = 1) const uint KDA = 0;
layout(constant_id = 2) const uint SUBGROUP_SIZE = 32;
layout(constant_id = 3) const uint LANES_PER_COLUMN = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
const uint COLS_PER_WG = SUBGROUP_SIZE / LANES_PER_COLUMN;
const uint ROWS_PER_LANE = S_V / LANES_PER_COLUMN;
layout(local_size_x_id = 2, local_size_y = 1, local_size_z = 1) in;
layout(push_constant) uniform Parameters {
uint H;
@ -27,14 +41,61 @@ layout(binding = 4) readonly buffer BetaBuf { FLOAT_TYPE data_beta[]; };
layout(binding = 5) readonly buffer StateBuf { FLOAT_TYPE data_state[]; };
layout(binding = 6) buffer DstBuf { FLOAT_TYPE data_dst[]; };
shared FLOAT_TYPE s_k[S_V];
shared FLOAT_TYPE s_q[S_V];
shared FLOAT_TYPE s_g[S_V]; // KDA only: cached exp(g[i])
#if !USE_SUBGROUP_ADD && !USE_SUBGROUP_CLUSTERED
shared FLOAT_TYPE temp[SUBGROUP_SIZE];
// This does a reduction across groups of LANES_PER_COLUMN
FLOAT_TYPE reduce_add_shmem(FLOAT_TYPE partial) {
const uint lane = gl_SubgroupInvocationID;
temp[lane] = partial;
barrier();
[[unroll]] for (uint s = LANES_PER_COLUMN / 2u; s > 0; s >>= 1u) {
FLOAT_TYPE other = temp[lane ^ s];
barrier();
temp[lane] += other;
barrier();
}
const FLOAT_TYPE result = temp[lane];
barrier();
return result;
}
#endif
// clusterSize for subgroupClusteredAdd must be a compile-time constant; branch on spec constant
FLOAT_TYPE reduce_partial(FLOAT_TYPE partial) {
switch (LANES_PER_COLUMN) {
case 1u:
return partial;
#if USE_SUBGROUP_CLUSTERED
// Workaround for GLSL requiring a literal constant for the cluster size.
// The branches should all fold away.
case 2u:
return subgroupClusteredAdd(partial, 2u);
case 4u:
return subgroupClusteredAdd(partial, 4u);
case 8u:
return subgroupClusteredAdd(partial, 8u);
case 16u:
return subgroupClusteredAdd(partial, 16u);
case 32u:
return subgroupClusteredAdd(partial, 32u);
case 64u:
return subgroupClusteredAdd(partial, 64u);
#endif
default:
#if USE_SUBGROUP_ADD
return subgroupAdd(partial);
#else
return reduce_add_shmem(partial);
#endif
}
}
void main() {
const uint head_id = gl_WorkGroupID.x;
const uint seq_id = gl_WorkGroupID.y;
const uint col = gl_LocalInvocationID.x;
const uint seq_id = gl_WorkGroupID.y;
const uint lane = gl_SubgroupInvocationID % LANES_PER_COLUMN;
const uint col = gl_WorkGroupID.z * COLS_PER_WG + (gl_SubgroupInvocationID / LANES_PER_COLUMN);
const uint iq1 = head_id % neq1;
const uint iq3 = seq_id / rq3;
@ -42,9 +103,9 @@ void main() {
const uint state_size = S_V * S_V;
const uint state_base = (seq_id * H + head_id) * state_size;
FLOAT_TYPE state[S_V];
[[unroll]] for (uint i = 0; i < S_V; i++) {
state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]);
FLOAT_TYPE s_shard[ROWS_PER_LANE];
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]);
}
uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
@ -53,76 +114,56 @@ void main() {
const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;
const uint k_off = q_off;
const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;
s_q[col] = FLOAT_TYPE(data_q[q_off + col]);
s_k[col] = FLOAT_TYPE(data_k[k_off + col]);
const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1;
if (KDA != 0) {
const uint g_base = gb_off * S_V;
s_g[col] = exp(FLOAT_TYPE(data_g[g_base + col]));
}
barrier();
const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]);
FLOAT_TYPE k_reg[ROWS_PER_LANE];
FLOAT_TYPE q_reg[ROWS_PER_LANE];
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
const uint i = r * LANES_PER_COLUMN + lane;
k_reg[r] = FLOAT_TYPE(data_k[k_off + i]);
q_reg[r] = FLOAT_TYPE(data_q[q_off + i]);
}
FLOAT_TYPE g_exp[ROWS_PER_LANE];
if (KDA == 0) {
const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off]));
FLOAT_TYPE kv_col = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
kv_col += dot(
vec4(state[i], state[i+1], state[i+2], state[i+3]),
vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3])
);
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
g_exp[r] = g_val;
}
FLOAT_TYPE delta_col = (v_val - g_val * kv_col) * beta_val;
FLOAT_TYPE attn_col = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
sv = g_val * sv + kv * delta_col;
state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
}
data_dst[attn_off + col] = attn_col * scale;
} else {
FLOAT_TYPE kv_col = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
kv_col += dot(gv * sv, kv);
const uint g_base = gb_off * S_V;
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
const uint i = r * LANES_PER_COLUMN + lane;
g_exp[r] = exp(FLOAT_TYPE(data_g[g_base + i]));
}
}
FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
FLOAT_TYPE attn_col = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
sv = gv * sv + kv * delta_col;
state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
FLOAT_TYPE kv_shard = 0.0;
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
kv_shard += g_exp[r] * s_shard[r] * k_reg[r];
}
FLOAT_TYPE kv_col = reduce_partial(kv_shard);
attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
}
FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
FLOAT_TYPE attn_partial = 0.0;
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
s_shard[r] = g_exp[r] * s_shard[r] + k_reg[r] * delta_col;
attn_partial += s_shard[r] * q_reg[r];
}
FLOAT_TYPE attn_col = reduce_partial(attn_partial);
if (lane == 0) {
data_dst[attn_off + col] = attn_col * scale;
}
attn_off += S_V * H;
barrier();
}
[[unroll]] for (uint i = 0; i < S_V; i++) {
data_dst[s_off + state_base + col * S_V + i] = state[i];
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r];
}
}

View File

@ -987,7 +987,9 @@ void process_shaders() {
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}}));
string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "1"}}));
string_to_spv("gated_delta_net_f32_nocluster", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "0"}}));
string_to_spv("gated_delta_net_f32_shmem", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "0"}, {"USE_SUBGROUP_CLUSTERED", "0"}}));
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

View File

@ -301,6 +301,8 @@ class Keys:
IMAGE_SIZE = "clip.vision.image_size"
IMAGE_MIN_PIXELS = "clip.vision.image_min_pixels"
IMAGE_MAX_PIXELS = "clip.vision.image_max_pixels"
PREPROC_MIN_TILES = "clip.vision.preproc_min_tiles"
PREPROC_MAX_TILES = "clip.vision.preproc_max_tiles"
PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size"
PATCH_SIZE = "clip.vision.patch_size"
EMBEDDING_LENGTH = "clip.vision.embedding_length"
@ -3869,6 +3871,8 @@ class LlamaFileType(IntEnum):
# MOSTLY_Q4_0_8_8 = 35 # removed from gguf files, use Q4_0 and runtime repack
MOSTLY_TQ1_0 = 36 # except 1d tensors
MOSTLY_TQ2_0 = 37 # except 1d tensors
MOSTLY_MXFP4_MOE = 38 # except 1d tensors
MOSTLY_NVFP4 = 39 # except 1d tensors
GUESSED = 1024 # not specified in the model file

View File

@ -1156,6 +1156,12 @@ class GGUFWriter:
def add_vision_min_pixels(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.IMAGE_MIN_PIXELS, value)
def add_vision_preproc_max_tiles(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.PREPROC_MAX_TILES, value)
def add_vision_preproc_min_tiles(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.PREPROC_MIN_TILES, value)
def add_vision_preproc_image_size(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.PREPROC_IMAGE_SIZE, value)
@ -1300,7 +1306,7 @@ class GGUFWriter:
else:
raise ValueError("Invalid GGUF metadata value type or value")
return kv_data
return bytes(kv_data)
@staticmethod
def format_n_bytes_to_str(num: int) -> str:

View File

@ -138,7 +138,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
if isinstance(meta_noop, tuple):
dtype, shape = meta_noop
assert callable(shape)
res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape))
res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape)) # ty: ignore[call-top-callable]
else:
res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)

View File

@ -91,11 +91,11 @@ class __Quant(ABC):
def __init_subclass__(cls, qtype: GGMLQuantizationType) -> None:
cls.qtype = qtype
cls.block_size, cls.type_size = GGML_QUANT_SIZES[qtype]
cls.__quantize_lazy = LazyNumpyTensor._wrap_fn(
cls.__quantize_lazy: Any = LazyNumpyTensor._wrap_fn(
cls.__quantize_array,
meta_noop=(np.uint8, cls.__shape_to_bytes)
)
cls.__dequantize_lazy = LazyNumpyTensor._wrap_fn(
cls.__dequantize_lazy: Any = LazyNumpyTensor._wrap_fn(
cls.__dequantize_array,
meta_noop=(np.float32, cls.__shape_from_bytes)
)

View File

@ -11,33 +11,33 @@ from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVa
try:
from sentencepiece import SentencePieceProcessor
except ImportError:
SentencePieceProcessor = None
SentencePieceProcessor: Any = None
try:
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # type: ignore[import-not-found]
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # type: ignore[import-not-found]
from mistral_common.tokens.tokenizers.utils import ( # type: ignore[import-not-found]
_filter_valid_tokenizer_files,
)
from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.sentencepiece import ( # type: ignore[import-not-found]
SentencePieceTokenizer,
)
except ImportError:
_mistral_common_installed = False
MistralTokenizer = None
Tekkenizer = None
SentencePieceTokenizer = None
_filter_valid_tokenizer_files = None
MistralTokenizer: Any = None
Tekkenizer: Any = None
SentencePieceTokenizer: Any = None
_filter_valid_tokenizer_files: Any = None
else:
_mistral_common_installed = True
try:
from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.utils import ( # type: ignore[import-not-found]
get_one_valid_tokenizer_file,
)
except ImportError:
# We still want the conversion to work with older mistral-common versions.
get_one_valid_tokenizer_file = None
get_one_valid_tokenizer_file: Any = None
import gguf
@ -703,7 +703,7 @@ class MistralVocab(Vocab):
tokenizer_file_path = base_path / tokenizer_file
self.tokenizer = MistralTokenizer.from_file(
self.tokenizer: Any = MistralTokenizer.from_file(
tokenizer_file_path
).instruct_tokenizer.tokenizer
self.tokenizer_type = (

View File

@ -0,0 +1,61 @@
{#- Copyright 2025-present the Unsloth team. All rights reserved. #}
{#- Licensed under the Apache License, Version 2.0 (the "License") #}
{#- Edits made by Unsloth to make it work for most inference engines #}
{# ───── defaults ───── #}
{%- if enable_thinking is not defined -%}
{%- set enable_thinking = true -%}
{%- endif -%}
{# ───── reasoning mode ───── #}
{%- if enable_thinking -%}
{%- set reasoning_mode = "/think" -%}
{%- else -%}
{%- set reasoning_mode = "/no_think" -%}
{%- endif -%}
{# ───── header (system message) ───── #}
{{- "<|im_start|>system\n" -}}
{%- if messages[0].role == "system" -%}
{%- set system_message = messages[0].content -%}
{%- if "/no_think" in system_message -%}
{%- set reasoning_mode = "/no_think" -%}
{%- elif "/think" in system_message -%}
{%- set reasoning_mode = "/think" -%}
{%- endif -%}
{%- set custom_instructions = system_message.replace("/no_think", "") -%}
{%- set custom_instructions = custom_instructions.replace("/think", "") -%}
{%- set custom_instructions = custom_instructions.rstrip() -%}
{%- endif -%}
{{- "## Metadata\n\n" -}}
{{- "Knowledge Cutoff Date: June 2025\n" -}}
{{- "Reasoning Mode: " + reasoning_mode + "\n\n" -}}
{{- "## Custom Instructions\n\n" -}}
{%- if custom_instructions -%}
{{- custom_instructions + "\n\n" -}}
{%- elif reasoning_mode == "/think" -%}
{{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracking, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> Thought section </think> Solution section. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion.\n\n" -}}
{%- else -%}
{{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}}
{%- endif -%}
{{- "<|im_end|>\n" -}}
{# ───── main loop ───── #}
{%- for message in messages -%}
{%- set content = message.content if message.content is string else "" -%}
{%- if message.role == "user" -%}
{{ "<|im_start|>" + message.role + "\n" + content + "<|im_end|>\n" }}
{%- elif message.role == "assistant" -%}
{%- if reasoning_mode == "/think" -%}
{{ "<|im_start|>assistant\n" + content.lstrip("\n") + "<|im_end|>\n" }}
{%- else -%}
{{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" + content.lstrip("\n") + "<|im_end|>\n" }}
{%- endif -%}
{%- elif message.role == "tool" -%}
{{ "<|im_start|>" + "user\n" + content + "<|im_end|>\n" }}
{%- endif -%}
{%- endfor -%}
{# ───── generation prompt ───── #}
{%- if add_generation_prompt -%}
{%- if reasoning_mode == "/think" -%}
{{ "<|im_start|>assistant\n" }}
{%- else -%}
{{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" }}
{%- endif -%}
{%- endif -%}

View File

@ -1,5 +1,5 @@
{
"extraPaths": ["gguf-py", "examples/model-conversion/scripts"],
"extraPaths": ["gguf-py", "examples/model-conversion/scripts", "examples/model-conversion/scripts/utils"],
"pythonVersion": "3.9",
"pythonPlatform": "All",
"reportUnusedImport": "warning",

View File

@ -684,6 +684,7 @@ else:
sys.exit(1)
assert isinstance(hexsha8_baseline, str)
name_baseline = bench_data.get_commit_name(hexsha8_baseline)
hexsha8_compare = name_compare = None
@ -717,6 +718,7 @@ else:
parser.print_help()
sys.exit(1)
assert isinstance(hexsha8_compare, str)
name_compare = bench_data.get_commit_name(hexsha8_compare)
# Get tool-specific configuration

View File

@ -241,10 +241,10 @@ class CodeEditor(QPlainTextEdit):
if not self.isReadOnly():
selection = QTextEdit.ExtraSelection()
line_color = QColorConstants.Yellow.lighter(160)
selection.format.setBackground(line_color) # pyright: ignore[reportAttributeAccessIssue]
selection.format.setProperty(QTextFormat.Property.FullWidthSelection, True) # pyright: ignore[reportAttributeAccessIssue]
selection.cursor = self.textCursor() # pyright: ignore[reportAttributeAccessIssue]
selection.cursor.clearSelection() # pyright: ignore[reportAttributeAccessIssue]
selection.format.setBackground(line_color) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
selection.format.setProperty(QTextFormat.Property.FullWidthSelection, True) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
selection.cursor = self.textCursor() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
selection.cursor.clearSelection() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
extra_selections.append(selection)
self.setExtraSelections(extra_selections)
@ -262,8 +262,8 @@ class CodeEditor(QPlainTextEdit):
)
extra = QTextEdit.ExtraSelection()
extra.format.setBackground(color.lighter(160)) # pyright: ignore[reportAttributeAccessIssue]
extra.cursor = cursor # pyright: ignore[reportAttributeAccessIssue]
extra.format.setBackground(color.lighter(160)) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
extra.cursor = cursor # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
self.setExtraSelections(self.extraSelections() + [extra])
@ -274,8 +274,8 @@ class CodeEditor(QPlainTextEdit):
cursor.select(QTextCursor.SelectionType.LineUnderCursor)
extra = QTextEdit.ExtraSelection()
extra.format.setBackground(color.lighter(160)) # pyright: ignore[reportAttributeAccessIssue]
extra.cursor = cursor # pyright: ignore[reportAttributeAccessIssue]
extra.format.setBackground(color.lighter(160)) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
extra.cursor = cursor # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
self.setExtraSelections(self.extraSelections() + [extra])
@ -395,8 +395,8 @@ class JinjaTester(QMainWindow):
ensure_ascii=ensure_ascii,
)
)
env.globals["strftime_now"] = lambda format: datetime.now().strftime(format)
env.globals["raise_exception"] = raise_exception
env.globals["strftime_now"] = lambda format: datetime.now().strftime(format) # ty: ignore[invalid-assignment]
env.globals["raise_exception"] = raise_exception # ty: ignore[invalid-assignment]
try:
template = env.from_string(template_str)
output = template.render(context)

View File

@ -189,6 +189,7 @@ def benchmark(
data: list[dict] = []
assert isinstance(prompts, list)
for i, p in enumerate(prompts):
if seed_offset >= 0:
random.seed(3 * (seed_offset + 1000 * i) + 1)

View File

@ -48,5 +48,5 @@ adb $adbserial $adbhost shell " \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$ndev $nhvx $opmask $verbose $experimental $profile $hb ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--batch-size 128 -ngl 99 $cli_opts $@ \
--ubatch-size 256 -fa 1 -ngl 99 $cli_opts $@ \
"

Some files were not shown because too many files have changed in this diff Show More