Merge branch 'master' into xsn/server_tools

This commit is contained in:
Xuan Son Nguyen 2026-03-21 00:40:59 +01:00
commit 3c5dac1ddf
128 changed files with 21291 additions and 26666 deletions

87
.github/workflows/ai-issues.yml vendored Normal file
View File

@ -0,0 +1,87 @@
name: AI review (issues)
on:
issues:
types: [opened]
jobs:
find-related:
if: github.event.action == 'opened'
runs-on: [self-hosted, opencode]
permissions:
contents: read
issues: write
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
fetch-depth: 1
- name: Find related
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
OPENCODE_PERMISSION: |
{
"bash": {
"*": "deny",
"gh issue*": "allow",
"gh search issues*": "allow"
},
"webfetch": "deny"
}
run: |
rm AGENTS.md
rm CLAUDE.md
timeout 5m opencode run -m llama.cpp-dgx/ai-review-issues-find-similar --thinking "A new issue has been created:
Issue number: ${{ github.event.issue.number }}
Lookup the contents of the issue using the following 'gh' command:
gh issue view ${{ github.event.issue.number }} --json title,body,url,number
Next, perform the following task and then post a SINGLE comment (if needed).
---
TASK : FIND RELATED ISSUES
Using the 'gh' CLI tool, search through existing issues on Github.
Find related or similar issues to the newly created one and list them.
Do not list the new issue itself (it is #${{ github.event.issue.number }}).
Consider:
1. Similar titles or descriptions
2. Same error messages or symptoms
3. Related functionality or components
4. Similar feature requests
---
POSTING YOUR COMMENT:
Based on your findings, post a SINGLE comment on issue #${{ github.event.issue.number }}. Build the comment as follows:
- If no related issues were found, do NOT comment at all.
- If related issues were found, include a section listing them with links using the following format:
[comment]
This issue might be similar or related to the following issue(s):
- #[related_issue_number]: [brief description of how they are related]
- #[related_issue_number]: [brief description of how they are related]
...
_This comment was auto-generated locally using **$GA_ENGINE** on **$GA_MACHINE**_
[/comment]
Remember:
- Do not include the comment tags in your actual comment.
- Post at most ONE comment combining all findings.
- If you didn't find issues that are related enough, post nothing.
- You have access only to the 'gh' CLI tool - don't try to use other tools.
- If the output from a tool call is too long, try to limit down the search.
"

80
.github/workflows/hip-quality-check.yml vendored Normal file
View File

@ -0,0 +1,80 @@
name: HIP quality check
on:
workflow_dispatch: # allows manual triggering
push:
branches:
- master
paths: [
'.github/workflows/hip-quality-check.yml',
'**/*.cu',
'**/*.cuh'
]
pull_request:
types: [opened, synchronize, reopened]
paths: [
'.github/workflows/hip-quality-check.yml',
'**/*.cu',
'**/*.cuh'
]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
env:
GGML_NLOOP: 3
GGML_N_THREADS: 1
LLAMA_LOG_COLORS: 1
LLAMA_LOG_PREFIX: 1
LLAMA_LOG_TIMESTAMPS: 1
jobs:
ubuntu-22-hip-quality-check:
runs-on: ubuntu-22.04
container: rocm/dev-ubuntu-22.04:7.2
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: Dependencies
id: depends
run: |
sudo apt-get update
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libssl-dev python3
- name: ccache
uses: ggml-org/ccache-action@v1.2.21
with:
key: ubuntu-22-hip-quality-check
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build with Werror
id: cmake_build
run: |
cmake -B build -S . \
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
-DGPU_TARGETS=gfx908 \
-DGGML_HIP=ON \
-DGGML_HIP_EXPORT_METRICS=Off \
-DCMAKE_HIP_FLAGS="-Werror -Wno-tautological-compare" \
-DCMAKE_BUILD_TYPE=Release
cd build
make -j $(nproc)
- name: Check for major VGPR spills
id: vgpr_check
run: |
cmake -B build -S . \
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
-DGPU_TARGETS=gfx908 \
-DGGML_HIP=ON \
-DGGML_HIP_EXPORT_METRICS=On \
-DCMAKE_HIP_FLAGS="" \
-DCMAKE_BUILD_TYPE=Release
cd build
make -j $(nproc) 2>&1 | tee metrics.log | grep -v 'Rpass-analysis=kernel-resource-usage\|remark:\|^$'
python3 ../scripts/hip/gcn-cdna-vgpr-check.py metrics.log

View File

@ -178,6 +178,8 @@ Maintainers reserve the right to decline review or close pull requests for any r
- New code should follow the guidelines (coding, naming, etc.) outlined in this document. Exceptions are allowed in isolated, backend-specific parts of the code that do not interface directly with the `ggml` interfaces.
_(NOTE: for legacy reasons, existing code is not required to follow this guideline)_
- For changes in server, please make sure to refer to the [server development documentation](./tools/server/README-dev.md)
# Documentation
- Documentation is a community effort

View File

@ -1830,23 +1830,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_sparam());
add_opt(common_arg(
{"--grammar"}, "GRAMMAR",
string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()),
"BNF-like grammar to constrain generations (see samples in grammars/ dir)",
[](common_params & params, const std::string & value) {
params.sampling.grammar = value;
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, value};
}
).set_sparam());
add_opt(common_arg(
{"--grammar-file"}, "FNAME",
"file to read grammar from",
[](common_params & params, const std::string & value) {
params.sampling.grammar = read_file(value);
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, read_file(value)};
}
).set_sparam());
add_opt(common_arg(
{"-j", "--json-schema"}, "SCHEMA",
"JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead",
[](common_params & params, const std::string & value) {
params.sampling.grammar = json_schema_to_grammar(json::parse(value));
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, json_schema_to_grammar(json::parse(value))};
}
).set_sparam());
add_opt(common_arg(
@ -1863,7 +1863,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
std::istreambuf_iterator<char>(),
std::back_inserter(schema)
);
params.sampling.grammar = json_schema_to_grammar(json::parse(schema));
params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, json_schema_to_grammar(json::parse(schema))};
}
).set_sparam());
add_opt(common_arg(
@ -3502,7 +3502,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
throw std::invalid_argument("unknown speculative decoding type without draft model");
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SPEC_TYPE"));
add_opt(common_arg(
{"--spec-ngram-size-n"}, "N",
string_format("ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)", params.speculative.ngram_size_n),

View File

@ -1,3 +1,4 @@
#include "chat-auto-parser-helpers.h"
#include "chat-auto-parser.h"
#include "chat-peg-parser.h"
#include "chat.h"
@ -23,13 +24,13 @@ static void foreach_function(const json & tools, const std::function<void(const
namespace autoparser {
parser_build_context::parser_build_context(common_chat_peg_builder & p, const templates_params & inputs) :
parser_build_context::parser_build_context(common_chat_peg_builder & p, const generation_params & inputs) :
p(p),
inputs(inputs),
reasoning_parser(p.eps()) {}
common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl,
const struct templates_params & inputs) {
const struct generation_params & inputs) {
// Run differential analysis to extract template structure
struct autoparser autoparser;
autoparser.analyze_template(tmpl);
@ -37,17 +38,16 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
}
common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl,
const struct templates_params & inputs,
const struct generation_params & inputs,
const autoparser & autoparser) {
// Build the parser using the analysis results
auto parser = autoparser.build_parser(inputs);
// Create the result structure
common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.preserved_tokens = autoparser.preserved_tokens;
data.parser = parser.save();
auto parser = autoparser.build_parser(inputs);
data.parser = parser.save();
// Build grammar if tools are present
bool has_tools =
@ -82,44 +82,38 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
return data;
}
common_peg_arena autoparser::build_parser(const templates_params & inputs) const {
common_peg_arena autoparser::build_parser(const generation_params & inputs) const {
if (!analysis_complete) {
throw std::invalid_argument("Cannot call build_parser on autoparser without performing analysis first, call analyze_template(...)");
}
return build_chat_peg_parser([&](common_chat_peg_builder & p) {
// If the template uses Python dict format (single-quoted strings in JSON structures),
// pre-register a json-string rule that accepts both quote styles. This must happen
// before any call to p.json() so that all JSON parsing inherits the flexible rule.
if (tools.format.uses_python_dicts) {
p.rule("json-string", p.quoted_string());
}
parser_build_context ctx(p, inputs);
bool extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
bool enable_thinking = inputs.enable_thinking;
ctx.extracting_reasoning = extract_reasoning && enable_thinking && reasoning.mode != reasoning_mode::NONE;
ctx.extracting_reasoning = extract_reasoning && reasoning.mode != reasoning_mode::NONE;
ctx.content = &content;
// Build reasoning parser
ctx.reasoning_parser = reasoning.build_parser(ctx);
auto parser = p.eps();
bool has_tools = inputs.tools.is_array() && !inputs.tools.empty();
bool has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
if (has_response_format) {
auto response_format = p.rule("response-format", p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
return ctx.reasoning_parser + p.space() + p.choice({
parser = ctx.reasoning_parser + p.space() + p.choice({
p.literal("```json") + p.space() + response_format + p.space() + p.literal("```"),
response_format
}) + p.end();
} else if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) {
parser = tools.build_parser(ctx);
} else {
parser = content.build_parser(ctx);
}
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) {
return tools.build_parser(ctx);
}
return content.build_parser(ctx);
parser = wrap_for_generation_prompt(p, parser, inputs, reasoning.start);
return parser;
});
}
@ -130,24 +124,15 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co
return p.eps();
}
bool thinking_forced_open = (mode == reasoning_mode::FORCED_OPEN);
bool thinking_forced_closed = (mode == reasoning_mode::FORCED_CLOSED);
if (thinking_forced_open || thinking_forced_closed) {
// Thinking is forced open OR forced closed with enable_thinking=true
// In both cases, expect only the closing tag (opening was in template)
// However, since we might have incorrectly detected the open/close pattern,
// we admit an optional starting marker
return p.optional(p.literal(start)) + p.reasoning(p.until(end)) + end;
}
if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) {
// Standard tag-based reasoning OR tools-only mode (reasoning appears with tools)
// Both use the same tag-based pattern if markers are available
if (!start.empty() && !end.empty()) {
return p.optional(start + p.reasoning(p.until(end)) + end);
if (!end.empty()) {
if (!start.empty()) {
// Standard tag-based: optional(<think>reasoning</think>)
return p.optional(start + p.reasoning(p.until(end)) + end + p.space());
}
// Delimiter-style (empty start)
return p.optional(p.reasoning(p.until(end)) + end + p.space());
}
} else if (mode == reasoning_mode::DELIMITER) {
return p.optional(p.reasoning(p.until(end)) + end);
}
return p.eps();
@ -335,7 +320,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
"tool-" + name + "-arg-" + param_name + "-schema",
param_schema, true)) :
p.tool_arg_json_value(p.schema(
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, format.uses_python_dicts)) +
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
p.space()) +
p.tool_arg_close(p.literal(arguments.value_suffix)));
@ -384,7 +369,9 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
call_id_section) + p.space() + args_seq;
matched_atomic = true;
} else if (!arguments.name_prefix.empty() && properties.size() > 0) {
} else if (!arguments.name_prefix.empty() && !required_parsers.empty()) {
// Only peek for an arg tag when there are required args that must follow.
// When all args are optional, the model may emit no arg tags at all (#20650).
func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
call_id_section + p.space() + p.peek(p.literal(arguments.name_prefix))) + args_seq;
matched_atomic = true;

View File

@ -1,9 +1,11 @@
#include "chat-auto-parser-helpers.h"
#include "chat-auto-parser.h"
#include "chat-peg-parser.h"
#include "chat.h"
#include "log.h"
#include "nlohmann/json.hpp"
#include "peg-parser.h"
#include <cctype>
#include <numeric>
@ -186,6 +188,21 @@ diff_split calculate_diff_split(const std::string & left, const std::string & ri
result.suffix = "";
// pick prefix = all as representation
}
// When left has no unique content (result.left is empty), left is entirely
// shared with right. The simultaneous prefix/suffix segment matching can
// incorrectly consume trailing segments of left as suffix when those same
// segments also appear at the end of right (e.g. "\n" at the end of both
// the shared content and the generation prompt). This rotates the diff.
// Fix: if left is a prefix of right, enforce that directly.
if (result.left.empty() && !result.right.empty() &&
left.size() <= right.size() &&
right.substr(0, left.size()) == left) {
result.prefix = left;
result.suffix = "";
result.right = right.substr(left.size());
}
return result;
}
@ -291,10 +308,26 @@ std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segm
return result;
}
common_peg_parser wrap_for_generation_prompt(common_chat_peg_builder & p,
const common_peg_parser & prs,
const autoparser::generation_params & inputs,
const std::string & reasoning_start) {
auto parser = prs;
if (!inputs.generation_prompt.empty()) {
size_t end_pos = inputs.generation_prompt.size();
if (!reasoning_start.empty() && inputs.generation_prompt.find(reasoning_start) != std::string::npos) {
end_pos = inputs.generation_prompt.find(reasoning_start);
}
std::string cut_genprompt = inputs.generation_prompt.substr(0, end_pos);
parser = p.literal(cut_genprompt) + parser;
}
return parser;
}
namespace autoparser {
std::string apply_template(const common_chat_template & tmpl, const template_params & params) {
templates_params tmpl_params;
generation_params tmpl_params;
tmpl_params.messages = params.messages;
tmpl_params.tools = params.tools;
tmpl_params.add_generation_prompt = params.add_generation_prompt;

View File

@ -1,6 +1,7 @@
#pragma once
#include "chat-auto-parser.h"
#include "peg-parser.h"
#include <functional>
#include <optional>
#include <string>
@ -57,6 +58,11 @@ std::vector<segment> segmentize_markers(const std::string & text);
// (MARKER, "</function>"), (MARKER, "</tool_call>") ]
std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segments);
// Wrap parser with generation prompt parser
common_peg_parser wrap_for_generation_prompt(common_chat_peg_builder & p,
const common_peg_parser & prs,
const autoparser::generation_params & inputs,
const std::string & reasoning_start = {});
namespace autoparser {
// Apply a template with the given parameters, returning the rendered string (empty on failure)

View File

@ -50,7 +50,7 @@ namespace autoparser {
// High-level params for parser generation
// ============================================================================
struct templates_params {
struct generation_params {
json messages;
json tools;
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
@ -62,6 +62,7 @@ struct templates_params {
bool add_generation_prompt = false;
bool enable_thinking = true;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
std::string generation_prompt;
json extra_context;
bool add_bos = false;
bool add_eos = false;
@ -77,11 +78,7 @@ struct templates_params {
// Reasoning handling mode (derived from R1-R3 comparisons)
enum class reasoning_mode {
NONE, // No reasoning markers detected
TAG_BASED, // Standard tag-based: <think>...</think>
DELIMITER, // Delimiter-based: [BEGIN FINAL RESPONSE] (reasoning ends at delimiter)
FORCED_OPEN, // Template ends with open reasoning tag (empty start, non-empty end)
FORCED_CLOSED, // Template ends with open reasoning tag on enabled thinking but
// with both opened and closed tag for disabled thinking
TAG_BASED, // Tag-based: <think>...</think> (start can be empty for delimiter-style)
TOOLS_ONLY // Only reason on tool calls, not on normal content
};
@ -91,12 +88,6 @@ inline std::ostream & operator<<(std::ostream & os, const reasoning_mode & mode)
return os << "NONE";
case reasoning_mode::TAG_BASED:
return os << "TAG_BASED";
case reasoning_mode::DELIMITER:
return os << "DELIMITER";
case reasoning_mode::FORCED_OPEN:
return os << "FORCED_OPEN";
case reasoning_mode::FORCED_CLOSED:
return os << "FORCED_CLOSED";
case reasoning_mode::TOOLS_ONLY:
return os << "TOOLS_ONLY";
default:
@ -184,7 +175,6 @@ struct tool_format_analysis {
bool fun_name_is_key = false; // In JSON format function name is JSON key, i.e. { "<funname>": { ... arguments ... } }
bool tools_array_wrapped = false; // Tool calls wrapped in JSON array [...]
bool uses_python_dicts = false; // Tool call args use Python dict format (single-quoted strings)
std::string function_field = "function";
std::string name_field = "name";
@ -225,12 +215,12 @@ struct analyze_content;
struct parser_build_context {
common_chat_peg_builder & p;
const templates_params & inputs;
const generation_params & inputs;
common_peg_parser reasoning_parser;
bool extracting_reasoning = false;
const analyze_content * content = nullptr;
parser_build_context(common_chat_peg_builder & p, const templates_params & inputs);
parser_build_context(common_chat_peg_builder & p, const generation_params & inputs);
};
// ============================================================================
@ -260,6 +250,7 @@ struct analyze_reasoning : analyze_base {
analyze_reasoning() = default;
analyze_reasoning(const common_chat_template & tmpl, bool supports_tools);
analyze_reasoning(std::string start_, std::string end_) : start(std::move(start_)), end(std::move(end_)) {}
common_peg_parser build_parser(parser_build_context & ctx) const override;
@ -381,7 +372,7 @@ struct autoparser {
void analyze_template(const common_chat_template & tmpl);
// Build the PEG parser for this template
common_peg_arena build_parser(const templates_params & inputs) const;
common_peg_arena build_parser(const generation_params & inputs) const;
private:
// Collect tokens from entire analysis to preserve
@ -395,10 +386,10 @@ struct autoparser {
class peg_generator {
public:
static common_chat_params generate_parser(const common_chat_template & tmpl,
const struct templates_params & inputs);
const struct generation_params & inputs);
static common_chat_params generate_parser(const common_chat_template & tmpl,
const struct templates_params & inputs,
const struct generation_params & inputs,
const autoparser & autoparser);
};

View File

@ -2,6 +2,7 @@
#include "chat-auto-parser-helpers.h"
#include "chat-peg-parser.h"
#include "chat.h"
#include "common.h"
#include "log.h"
#include "nlohmann/json.hpp"
#include "peg-parser.h"
@ -31,8 +32,9 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
[](const common_chat_template & tmpl, autoparser & analysis) -> void {
if (tmpl.src.find("content.split('</think>')") != std::string::npos &&
tmpl.src.find("reasoning_content") == std::string::npos &&
tmpl.src.find("<SPECIAL_12>") == std::string::npos &&
analysis.reasoning.mode == reasoning_mode::NONE) {
analysis.reasoning.mode = reasoning_mode::FORCED_OPEN;
analysis.reasoning.mode = reasoning_mode::TAG_BASED;
analysis.reasoning.start = "<think>";
analysis.reasoning.end = "</think>";
analysis.preserved_tokens.push_back("<think>");
@ -185,7 +187,6 @@ void autoparser::analyze_template(const common_chat_template & tmpl) {
LOG_DBG("func_name_prefix: '%s'\n", tools.function.name_prefix.c_str());
LOG_DBG("func_name_suffix: '%s'\n", tools.function.name_suffix.c_str());
LOG_DBG("func_close: '%s'\n", tools.function.close.c_str());
LOG_DBG("python_dict_format: %s\n", tools.format.uses_python_dicts ? "true" : "false");
LOG_DBG("arg_name_prefix: '%s'\n", tools.arguments.name_prefix.c_str());
LOG_DBG("arg_name_suffix: '%s'\n", tools.arguments.name_suffix.c_str());
LOG_DBG("arg_value_prefix: '%s'\n", tools.arguments.value_prefix.c_str());
@ -295,16 +296,12 @@ void analyze_reasoning::compare_reasoning_presence() {
}
if (result.result.success()) {
if (!result.tags["pre"].empty() && !result.tags["post"].empty()) {
if (parser_wrapped.parse_anywhere_and_extract(diff.right).result.success()) { // both tags in the diff = no forced close
mode = reasoning_mode::TAG_BASED;
} else {
mode = reasoning_mode::FORCED_CLOSED;
}
mode = reasoning_mode::TAG_BASED;
start = trim_whitespace(result.tags["pre"]);
end = result.tags["post"];
end = trim_trailing_whitespace(result.tags["post"]);
} else if (!result.tags["post"].empty()) {
mode = reasoning_mode::DELIMITER;
end = result.tags["post"];
mode = reasoning_mode::TAG_BASED;
end = trim_trailing_whitespace(result.tags["post"]);
}
}
}
@ -331,53 +328,30 @@ void analyze_reasoning::compare_thinking_enabled() {
const auto & diff = comparison->diff;
std::string left_trimmed = trim_whitespace(diff.left);
std::string right_trimmed = trim_whitespace(diff.right);
if (left_trimmed.empty() && !diff.right.empty()) {
std::string right_trimmed = trim_whitespace(diff.right);
if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) {
if (start.empty()) {
start = right_trimmed;
mode = reasoning_mode::FORCED_OPEN;
mode = reasoning_mode::TAG_BASED;
}
}
} else if (right_trimmed.empty() && !diff.left.empty()) {
if (!left_trimmed.empty() && string_ends_with(comparison->output_A, left_trimmed)) {
if (end.empty()) {
auto seg = prune_whitespace_segments(segmentize_markers(comparison->output_A));
if (seg.size() >= 2 && seg[seg.size() - 1].value == left_trimmed && seg[seg.size() - 2].type == segment_type::MARKER) {
start = seg[seg.size() - 2].value;
}
end = left_trimmed;
mode = reasoning_mode::TAG_BASED;
}
}
}
if (start.empty() && !end.empty()) {
mode = reasoning_mode::DELIMITER;
}
// Check for FORCED_CLOSED: when enable_thinking=false produces both start and end markers,
// but enable_thinking=true produces only the start marker
if (!comparison->output_A.empty() && !comparison->output_B.empty()) {
auto parser_start = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
return p.literal(start) + p.space() + p.literal(end) + p.rest();
});
auto parser_start_end = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
return p.tag("pre", p.literal(start)) + p.space() + p.negate(p.literal(end)) + p.rest();
});
if (!start.empty() && parser_start_end.parse_anywhere_and_extract(comparison->output_A).result.success() &&
parser_start.parse_anywhere_and_extract(comparison->output_B).result.success()) {
mode = reasoning_mode::FORCED_CLOSED;
} else if (!end.empty()) { // we extract the starting marker now since we didn't get it earlier
auto result = parser_start_end.parse_anywhere_and_extract(comparison->output_A);
if (result.result.success()) {
start = result.tags["pre"];
mode = reasoning_mode::FORCED_CLOSED;
}
}
}
if (start.empty() && end.empty()) { // we might still have the case of "just open" and "just close"
if (!diff.left.empty() && !diff.right.empty()) {
auto seg_A = segmentize_markers(trim_trailing_whitespace(diff.left));
auto seg_B = segmentize_markers(trim_trailing_whitespace(diff.right));
if (seg_A.size() == 1 && seg_B.size() == 1) {
mode = reasoning_mode::FORCED_CLOSED;
start = seg_B[0].value;
end = seg_A[0].value;
}
}
if (mode == reasoning_mode::NONE && start.empty() && !end.empty()) {
mode = reasoning_mode::TAG_BASED;
}
}
@ -426,16 +400,16 @@ void analyze_reasoning::compare_reasoning_scope() {
auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B);
if (result.result.success()) {
start = result.tags["pre"];
end = result.tags["post"];
end = trim_trailing_whitespace(result.tags["post"]);
} else {
auto parser_delimiter = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space())));
});
result = parser_delimiter.parse_anywhere_and_extract(comparison->output_B);
if (result.result.success()) {
end = result.tags["post"];
end = trim_trailing_whitespace(result.tags["post"]);
} else {
LOG_DBG(ANSI_ORANGE "%s: Unable to extracft reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__);
LOG_DBG(ANSI_ORANGE "%s: Unable to extract reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__);
mode = reasoning_mode::NONE;
}
}
@ -600,33 +574,23 @@ void analyze_tools::analyze_tool_call_format(const std::string & haystack,
return;
}
enum class json_quote_style { NONE, DOUBLE_QUOTES, SINGLE_QUOTES };
auto in_json_haystack = [&haystack](const std::string & needle) -> json_quote_style {
auto in_json_haystack = [&haystack](const std::string & needle) -> bool {
auto parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
return p.choice({ p.literal("{"), p.literal(":") }) << p.choice({
p.tag("sq", p.literal("'") + p.literal(needle) + p.literal("'")),
p.tag("dq", p.literal("\"") + p.literal(needle) + p.literal("\"")) });
});
auto result = parser.parse_anywhere_and_extract(haystack);
if (!result.result.success()) {
return json_quote_style::NONE;
}
return result.tags.count("sq") && !result.tags["sq"].empty()
? json_quote_style::SINGLE_QUOTES
: json_quote_style::DOUBLE_QUOTES;
return result.result.success();
};
auto fun_quote = in_json_haystack(fun_name_needle);
auto arg_quote = in_json_haystack(arg_name_needle);
if (fun_quote != json_quote_style::NONE) {
if (fun_quote) {
// no need to check further, we're in JSON land
format.mode = tool_format::JSON_NATIVE;
format.uses_python_dicts = (fun_quote == json_quote_style::SINGLE_QUOTES);
} else if (arg_quote != json_quote_style::NONE) {
} else if (arg_quote) {
format.mode = tool_format::TAG_WITH_JSON;
format.uses_python_dicts = (arg_quote == json_quote_style::SINGLE_QUOTES);
} else {
format.mode = tool_format::TAG_WITH_TAGGED;
}

View File

@ -229,6 +229,20 @@ void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena,
result.tool_calls.push_back(pending_tool_call.value());
pending_tool_call.reset();
}
// Discard whitespace-only reasoning content (e.g. from <think></think> prefill)
if (!result.reasoning_content.empty()) {
bool all_whitespace = true;
for (char c : result.reasoning_content) {
if (c != ' ' && c != '\n' && c != '\r' && c != '\t') {
all_whitespace = false;
break;
}
}
if (all_whitespace) {
result.reasoning_content.clear();
}
}
}
void common_chat_peg_mapper::map(const common_peg_ast_node & node) {

View File

@ -1,5 +1,6 @@
#include "chat.h"
#include "chat-auto-parser-helpers.h"
#include "chat-auto-parser.h"
#include "chat-peg-parser.h"
#include "common.h"
@ -22,6 +23,7 @@
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
using json = nlohmann::ordered_json;
@ -760,7 +762,7 @@ static void foreach_parameter(const json &
std::string common_chat_template_direct_apply(
const common_chat_template & tmpl,
const autoparser::templates_params & inputs,
const autoparser::generation_params & inputs,
const std::optional<json> & messages_override,
const std::optional<json> & tools_override,
const std::optional<json> & additional_context) {
@ -811,7 +813,7 @@ std::string common_chat_template_direct_apply(
}
static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl,
const autoparser::templates_params & inputs) {
const autoparser::generation_params & inputs) {
common_chat_params data;
// Build up messages to follow the format: https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512/blob/main/chat_template.jinja
@ -876,8 +878,8 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
// Response format parser
if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
// Ministral wants to emit json surrounded by code fences
return reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema))
<< "```";
return wrap_for_generation_prompt(p, reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```",
inputs, "[THINK]");
}
// Tool call parser
@ -897,12 +899,13 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls));
return reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls;
return wrap_for_generation_prompt(p, reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls,
inputs, "[THINK]");
}
// Content only parser
include_grammar = false;
return reasoning << p.content(p.rest());
return wrap_for_generation_prompt(p, reasoning << p.content(p.rest()), inputs, "[THINK]");
});
data.parser = parser.save();
@ -928,7 +931,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
}
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl,
const autoparser::templates_params & inputs) {
const autoparser::generation_params & inputs) {
common_chat_params data;
// Copy reasoning to the "thinking" field as expected by the gpt-oss template
@ -936,7 +939,9 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
for (auto msg : inputs.messages) {
if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
msg["thinking"] = msg.at("reasoning_content");
msg.erase("content");
if (msg.contains("tool_calls") && msg.at("tool_calls").is_array() && !msg.at("tool_calls").empty()) {
msg.erase("content");
}
}
adjusted_messages.push_back(msg);
}
@ -986,7 +991,8 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
p.literal("<|channel|>final") + constraint + p.literal("<|message|>") +
p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
return response_format | (analysis + p.zero_or_more(start + analysis) + start + response_format);
return wrap_for_generation_prompt(p, response_format | (analysis + p.zero_or_more(start + analysis) + start + response_format),
inputs, "<|channel|>");
}
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
@ -1018,10 +1024,12 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
return tool_call | ( any + p.zero_or_more(start + any) + start + tool_call);
}
return tool_call | final_msg | (any + p.zero_or_more(start + any) + start + (tool_call | final_msg));
return wrap_for_generation_prompt(p, tool_call | final_msg | (any + p.zero_or_more(start + any) + start + (tool_call | final_msg)),
inputs, "<|channel|>");
}
return final_msg | (any + p.zero_or_more(start + any) + start + final_msg);
return wrap_for_generation_prompt(p, final_msg | (any + p.zero_or_more(start + any) + start + final_msg),
inputs, "<|channel|>");
});
data.parser = parser.save();
@ -1049,7 +1057,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
// Functionary v3.2 - uses recipient-based format: >>>recipient\n{content}
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl,
const autoparser::templates_params & inputs) {
const autoparser::generation_params & inputs) {
common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
@ -1070,13 +1078,13 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
// Build content parser for >>>all\n{content}
// When tools are present, content stops before the next ">>>" (tool call)
// When no tools, content goes until end
auto content_until_tool = p.literal(">>>all\n") + p.content(p.until(">>>"));
auto content_until_end = p.literal(">>>all\n") + p.content(p.rest());
auto content_until_tool = p.literal("all\n") + p.content(p.until(">>>"));
auto content_until_end = p.literal("all\n") + p.content(p.rest());
// If no tools or tool_choice is NONE, just parse content
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
// When no tools, just match the prefix and capture everything after
return content_until_end + p.end();
return wrap_for_generation_prompt(p, content_until_end + p.end(), inputs);
}
// Build tool call parsers for each available function
@ -1088,7 +1096,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
// Tool format: >>>function_name\n{json_args}
auto tool_parser = p.tool(
p.tool_open(p.literal(">>>") + p.tool_name(p.literal(name)) + p.literal("\n")) +
p.tool_open(p.tool_name(p.literal(name)) + p.literal("\n")) +
p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema))
);
@ -1099,17 +1107,20 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
auto tools_only = p.trigger_rule("tools", p.one_or_more(tool_choice));
auto content_and_tools = content_until_tool + tools_only;
auto ret = p.eps();
if (inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED) {
if (inputs.parallel_tool_calls) {
return p.choice({ content_and_tools, tools_only }) + p.end();
ret = p.choice({ content_and_tools, tools_only }) + p.end();
} else {
ret = p.choice({ content_until_tool + tool_choice, tools_only }) + p.end();
}
return p.choice({ content_until_tool + tool_choice, tools_only }) + p.end();
} else if (inputs.parallel_tool_calls) {
ret = p.choice({ content_and_tools, content_only, tools_only }) + p.end();
} else {
auto content_and_tool = content_until_tool + tool_choice;
ret = p.choice({ content_and_tool, content_only, tool_choice }) + p.end();
}
if (inputs.parallel_tool_calls) {
return p.choice({ content_and_tools, content_only, tools_only }) + p.end();
}
auto content_and_tool = content_until_tool + tool_choice;
return p.choice({ content_and_tool, content_only, tool_choice }) + p.end();
return wrap_for_generation_prompt(p, ret, inputs);
});
data.parser = parser.save();
@ -1139,14 +1150,12 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
// Kimi K2 Thinking - uses unique tool call ID format: functions.<name>:<index>
// The ID contains both the function name and an incrementing counter
static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl,
const autoparser::templates_params & inputs) {
const autoparser::generation_params & inputs) {
common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
data.thinking_start_tag = "<think>";
data.thinking_end_tag = "</think>";
data.preserved_tokens = {
"<|tool_calls_section_begin|>",
"<|tool_calls_section_end|>",
@ -1161,6 +1170,18 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
const std::string SECTION_BEGIN = "<|tool_calls_section_begin|>";
const std::string SECTION_END = "<|tool_calls_section_end|>";
const std::string CALL_BEGIN = "<|tool_call_begin|>";
const std::string ARGS_BEGIN = "<|tool_call_argument_begin|>";
const std::string CALL_END = "<|tool_call_end|>";
const std::string THINK_START = "<think>";
const std::string THINK_END = "</think>";
data.thinking_start_tag = THINK_START;
data.thinking_end_tag = THINK_END;
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
// Kimi K2 Thinking format:
// - Reasoning: <think>{reasoning}</think>
@ -1172,16 +1193,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
// <|tool_calls_section_end|>
// The ID format is: functions.<function_name>:<counter> where counter is 0, 1, 2, ...
// Tool call markers
const std::string SECTION_BEGIN = "<|tool_calls_section_begin|>";
const std::string SECTION_END = "<|tool_calls_section_end|>";
const std::string CALL_BEGIN = "<|tool_call_begin|>";
const std::string ARGS_BEGIN = "<|tool_call_argument_begin|>";
const std::string CALL_END = "<|tool_call_end|>";
const std::string THINK_START = "<think>";
const std::string THINK_END = "</think>";
// Tool call markers
auto end = p.end();
// Note: this model is CRAZY. It can diverge from its supposed tool calling pattern in so many ways it's not funny.
@ -1193,7 +1205,8 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
// Content only parser (no tools)
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return reasoning + p.content(p.rest()) + end;
return wrap_for_generation_prompt(p, reasoning + p.content(p.rest()) + end,
inputs, THINK_START);
}
// Build tool call parsers for each available function
@ -1229,7 +1242,8 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
auto content_before_tools = p.content(p.until_one_of({ SECTION_BEGIN, CALL_BEGIN }));
return reasoning + content_before_tools + tool_calls + end;
return wrap_for_generation_prompt(p, reasoning + content_before_tools + tool_calls + end,
inputs, THINK_START);
});
data.parser = parser.save();
@ -1259,7 +1273,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
// - Tool calls: <|tool_call_start|>[function_name(arg1="value1", arg2="value2")]<|tool_call_end|>
// Tool calls can appear multiple times (parallel tool calls)
static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl,
const autoparser::templates_params & inputs) {
const autoparser::generation_params & inputs) {
common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
@ -1278,13 +1292,15 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
const std::string TOOL_CALL_START = "<|tool_call_start|>";
const std::string TOOL_CALL_END = "<|tool_call_end|>";
const std::string THINK_START = "<think>";
const std::string THINK_END = "</think>";
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
data.thinking_start_tag = THINK_START;
data.thinking_end_tag = THINK_END;
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
auto end = p.end();
auto reasoning = p.eps();
@ -1293,7 +1309,8 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
}
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return reasoning + p.content(p.rest()) + end;
return wrap_for_generation_prompt(p, reasoning + p.content(p.rest()) + end, inputs,
THINK_START);
}
auto tool_calls = p.rule("tool-calls",
@ -1305,7 +1322,8 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
auto content = p.content(p.until(TOOL_CALL_START));
return reasoning + content + tool_calls + end;
return wrap_for_generation_prompt(p, reasoning + content + tool_calls + end, inputs,
THINK_START);
});
data.parser = parser.save();
@ -1331,7 +1349,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
static common_chat_params common_chat_params_init_gigachat_v3(
const common_chat_template & tmpl,
const autoparser::templates_params & inputs) {
const autoparser::generation_params & inputs) {
common_chat_params data;
@ -1345,9 +1363,10 @@ static common_chat_params common_chat_params_init_gigachat_v3(
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
auto tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n";
const auto *tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n";
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
auto ret = p.eps();
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
// Build a choice of all available tools
auto tool_choice = p.choice();
@ -1370,13 +1389,14 @@ static common_chat_params common_chat_params_init_gigachat_v3(
auto tool_call = p.rule("tool-call", p.literal(tool_call_start_prefix) + tool_choice);
auto tool_calls = p.trigger_rule("tool-call-root", p.repeat(tool_call, /* min = */ min_calls, /* max = */ max_calls));
return p.content(p.until("<|message_sep|>\n\n")) << tool_calls;
ret = p.content(p.until("<|message_sep|>\n\n")) << tool_calls;
} else {
// Content only parser
include_grammar = false;
ret = p.content(p.rest());
}
// Content only parser
include_grammar = false;
return p.content(p.rest());
return wrap_for_generation_prompt(p, ret, inputs);
});
data.parser = parser.save();
@ -1471,87 +1491,10 @@ static json common_chat_extra_context() {
return ctx;
}
static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs) {
autoparser::templates_params params;
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
? *tmpls->template_tool_use
: *tmpls->template_default;
const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps();
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
params.add_generation_prompt = inputs.add_generation_prompt;
params.tool_choice = inputs.tool_choice;
params.reasoning_format = inputs.reasoning_format;
params.enable_thinking = inputs.enable_thinking;
params.grammar = inputs.grammar;
params.now = inputs.now;
params.add_bos = tmpls->add_bos;
params.add_eos = tmpls->add_eos;
if (src.find("<|channel|>") == std::string::npos) {
// map developer to system for all models except for GPT-OSS
workaround::map_developer_role_to_system(params.messages);
}
if (!tmpl.original_caps().supports_system_role) {
workaround::system_message_not_supported(params.messages);
}
if (tmpl.original_caps().supports_tool_calls) {
// some templates will require the content field in tool call messages
// to still be non-null, this puts an empty string everywhere where the
// content field is null
workaround::requires_non_null_content(params.messages);
}
if (tmpl.original_caps().supports_object_arguments) {
workaround::func_args_not_string(params.messages);
}
params.extra_context = common_chat_extra_context();
for (auto el : inputs.chat_template_kwargs) {
params.extra_context[el.first] = json::parse(el.second);
}
if (!inputs.json_schema.empty()) {
params.json_schema = json::parse(inputs.json_schema);
}
// if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
// LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
// params.parallel_tool_calls = false;
// } else {
params.parallel_tool_calls = inputs.parallel_tool_calls;
//}
if (params.tools.is_array()) {
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
throw std::runtime_error("Cannot specify grammar with tools");
}
if (caps.supports_tool_calls && !caps.supports_tools) {
LOG_WRN(
"Template supports tool calls but does not natively describe tools. The fallback behaviour used may "
"produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
}
}
if (inputs.force_pure_content) {
LOG_WRN("Forcing pure content template, will not render reasoning or tools separately.");
// Create the result structure
common_chat_params data;
auto params_copy = params;
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
data.prompt = common_chat_template_direct_apply(tmpl, params_copy);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
auto parser = build_chat_peg_parser([](common_chat_peg_builder &p) {
return p.content(p.rest());
});
data.parser = parser.save();
return data;
}
static std::optional<common_chat_params> try_specialized_template(
const common_chat_template & tmpl,
const std::string & src,
const autoparser::generation_params & params) {
// Ministral/Mistral Large 3 - uses special reasoning structure fixes, can't use autoparser
// Note: Mistral Small 3.2 uses [CALL_ID] which Ministral doesn't have, so we can distinguish them
if (src.find("[SYSTEM_PROMPT]") != std::string::npos && src.find("[TOOL_CALLS]") != std::string::npos &&
@ -1592,14 +1535,105 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
// GigaChatV3 format detection
if (src.find("<|role_sep|>") != std::string::npos &&
src.find("<|message_sep|>") != std::string::npos &&
src.find("<|function_call|>") == std::string::npos
) {
src.find("<|function_call|>") == std::string::npos) {
LOG_DBG("Using specialized template: GigaChatV3\n");
return common_chat_params_init_gigachat_v3(tmpl, params);
}
return std::nullopt;
}
static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs) {
autoparser::generation_params params;
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
const auto & tmpl =
params.tools.is_array() && tmpls->template_tool_use ? *tmpls->template_tool_use : *tmpls->template_default;
const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps();
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
params.tool_choice = inputs.tool_choice;
params.reasoning_format = inputs.reasoning_format;
params.enable_thinking = inputs.enable_thinking;
params.grammar = inputs.grammar;
params.now = inputs.now;
params.add_bos = tmpls->add_bos;
params.add_eos = tmpls->add_eos;
if (src.find("<|channel|>") == std::string::npos) {
// map developer to system for all models except for GPT-OSS
workaround::map_developer_role_to_system(params.messages);
}
if (!tmpl.original_caps().supports_system_role) {
workaround::system_message_not_supported(params.messages);
}
if (tmpl.original_caps().supports_tool_calls) {
// some templates will require the content field in tool call messages
// to still be non-null, this puts an empty string everywhere where the
// content field is null
workaround::requires_non_null_content(params.messages);
}
if (tmpl.original_caps().supports_object_arguments) {
workaround::func_args_not_string(params.messages);
}
params.add_generation_prompt = false;
std::string no_gen_prompt = common_chat_template_direct_apply(tmpl, params);
params.add_generation_prompt = true;
std::string gen_prompt = common_chat_template_direct_apply(tmpl, params);
auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
params.generation_prompt = diff.right;
params.add_generation_prompt = inputs.add_generation_prompt;
params.extra_context = common_chat_extra_context();
for (auto el : inputs.chat_template_kwargs) {
params.extra_context[el.first] = json::parse(el.second);
}
if (!inputs.json_schema.empty()) {
params.json_schema = json::parse(inputs.json_schema);
}
params.parallel_tool_calls = inputs.parallel_tool_calls;
if (params.tools.is_array()) {
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
throw std::runtime_error("Cannot specify grammar with tools");
}
if (caps.supports_tool_calls && !caps.supports_tools) {
LOG_WRN(
"Template supports tool calls but does not natively describe tools. The fallback behaviour used may "
"produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
}
}
if (inputs.force_pure_content) {
LOG_WRN("Forcing pure content template, will not render reasoning or tools separately.");
// Create the result structure
common_chat_params data;
auto params_copy = params;
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
data.prompt = common_chat_template_direct_apply(tmpl, params_copy);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.generation_prompt = params.generation_prompt;
auto parser = build_chat_peg_parser([&params](common_chat_peg_builder &p) {
return wrap_for_generation_prompt(p, p.content(p.rest()), params);
});
data.parser = parser.save();
return data;
}
if (auto result = try_specialized_template(tmpl, src, params)) {
result->generation_prompt = params.generation_prompt;
return *result;
}
try {
LOG_DBG("Using differential autoparser\n");
LOG_DBG("%s: using differential autoparser\n", __func__);
struct autoparser::autoparser autoparser;
autoparser.analyze_template(tmpl);
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
@ -1607,13 +1641,11 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
if (auto_params.supports_thinking) {
auto_params.thinking_start_tag = autoparser.reasoning.start;
auto_params.thinking_end_tag = autoparser.reasoning.end;
// FORCED_OPEN and FORCED_CLOSED both put <think> in the generation prompt
// (FORCED_CLOSED forces empty <think></think> when thinking is disabled,
// but forces <think> open when thinking is enabled)
auto_params.thinking_forced_open =
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_OPEN ||
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_CLOSED;
}
auto_params.generation_prompt = params.generation_prompt;
common_peg_arena arena;
arena.load(auto_params.parser);
LOG_DBG("%s: generated parser:\n%s\n\nparser generation prompt: %s\n", __func__, arena.dump(arena.root()).c_str(), auto_params.generation_prompt.c_str());
return auto_params;
} catch (const std::exception & e) {
throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what());
@ -1711,14 +1743,18 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
LOG_DBG("No parser definition detected, assuming pure content parser.");
}
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), input.c_str());
const std::string effective_input = params.generation_prompt.empty()
? input
: params.generation_prompt + input;
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), effective_input.c_str());
common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_LENIENT;
if (params.debug) {
flags |= COMMON_PEG_PARSE_FLAG_DEBUG;
}
common_peg_parse_context ctx(input, flags);
common_peg_parse_context ctx(effective_input, flags);
auto result = parser.parse(ctx);
if (result.fail()) {
@ -1738,7 +1774,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
return msg;
}
throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(result.end) + ": " +
input.substr(result.end));
effective_input.substr(result.end));
}
common_chat_msg msg;

View File

@ -24,7 +24,7 @@ using json = nlohmann::ordered_json;
struct common_chat_templates;
namespace autoparser {
struct templates_params;
struct generation_params;
} // namespace autoparser
struct common_chat_tool_call {
@ -212,7 +212,7 @@ struct common_chat_params {
std::string prompt;
std::string grammar;
bool grammar_lazy = false;
bool thinking_forced_open = false;
std::string generation_prompt;
bool supports_thinking = false;
std::string thinking_start_tag; // e.g., "<think>"
std::string thinking_end_tag; // e.g., "</think>"
@ -229,14 +229,14 @@ struct common_chat_parser_params {
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
bool reasoning_in_content = false;
bool thinking_forced_open = false;
std::string generation_prompt;
bool parse_tool_calls = true;
bool debug = false; // Enable debug output for PEG parser
common_peg_arena parser = {};
common_chat_parser_params() = default;
common_chat_parser_params(const common_chat_params & chat_params) {
format = chat_params.format;
thinking_forced_open = chat_params.thinking_forced_open;
format = chat_params.format;
generation_prompt = chat_params.generation_prompt;
}
};
@ -302,7 +302,7 @@ std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_tem
std::string common_chat_template_direct_apply(
const common_chat_template & tmpl,
const autoparser::templates_params & inputs,
const autoparser::generation_params & inputs,
const std::optional<json> & messages_override = std::nullopt,
const std::optional<json> & tools_override = std::nullopt,
const std::optional<json> & additional_context = std::nullopt);

View File

@ -3,12 +3,14 @@
#pragma once
#include "ggml-opt.h"
#include "ggml.h"
#include "llama-cpp.h"
#include <set>
#include <sstream>
#include <string>
#include <string_view>
#include <variant>
#include <vector>
#include <map>
@ -178,6 +180,43 @@ enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type
};
// Grammar type enumeration
enum common_grammar_type {
COMMON_GRAMMAR_TYPE_NONE, // no grammar set
COMMON_GRAMMAR_TYPE_USER, // user-provided GBNF (--grammar / "grammar" API field)
COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, // auto-generated from JSON schema (--json-schema / "json_schema" API field)
COMMON_GRAMMAR_TYPE_TOOL_CALLS, // auto-generated by chat template parser for function calling
};
// Grammar variant struct with type and grammar string
struct common_grammar {
common_grammar_type type = COMMON_GRAMMAR_TYPE_NONE;
std::string grammar;
// Default constructor - no grammar
common_grammar() = default;
// Constructor with type and grammar string
common_grammar(common_grammar_type t, std::string g) : type(t), grammar(std::move(g)) {
GGML_ASSERT(type != COMMON_GRAMMAR_TYPE_NONE || !grammar.empty());
}
// Check if a grammar is set
bool empty() const { return type == COMMON_GRAMMAR_TYPE_NONE || grammar.empty(); }
};
// Returns the raw grammar string, or empty string if no grammar is set.
inline const std::string & common_grammar_value(const common_grammar & g) {
return g.grammar;
}
// Returns true when the generation_prompt should be prefilled into the grammar sampler.
// Only output-format and tool-call grammars need prefill; user-supplied grammars must not be prefilled.
inline bool common_grammar_needs_prefill(const common_grammar & g) {
return g.type == COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT
|| g.type == COMMON_GRAMMAR_TYPE_TOOL_CALLS;
}
// sampling parameters
struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
@ -228,7 +267,7 @@ struct common_params_sampling {
COMMON_SAMPLER_TYPE_TEMPERATURE,
};
std::string grammar; // optional BNF-like grammar to constrain sampling
common_grammar grammar; // optional grammar constraint (user / output-format / tool-calls)
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
std::set<llama_token> preserved_tokens;
@ -236,10 +275,15 @@ struct common_params_sampling {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
// The assistant generation prompt already prefilled into the prompt.
// Fed to the grammar sampler (to advance past pre-existing tokens) and used
// to determine the reasoning budget sampler's initial state.
// Only applied when the grammar is of output-format or tool-calls type.
std::string generation_prompt;
// reasoning budget sampler parameters
// these are populated by the server/CLI based on chat template params
int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget
bool reasoning_budget_activate_immediately = false;
std::vector<llama_token> reasoning_budget_start; // start tag token sequence
std::vector<llama_token> reasoning_budget_end; // end tag token sequence
std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag)

View File

@ -451,7 +451,7 @@ struct value_array_t : public value_t {
}
protected:
virtual bool equivalent(const value_t & other) const override {
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), value_equivalence());
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), other.val_arr.end(), value_equivalence());
}
};
using value_array = std::shared_ptr<value_array_t>;
@ -587,7 +587,7 @@ struct value_object_t : public value_t {
}
protected:
virtual bool equivalent(const value_t & other) const override {
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), value_equivalence());
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), other.val_obj.end(), value_equivalence());
}
};
using value_object = std::shared_ptr<value_object_t>;

View File

@ -163,9 +163,15 @@ static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
ctx->force_pos = 0;
}
// forward declaration for use in clone
static struct llama_sampler * common_reasoning_budget_init_state(
const struct llama_vocab * vocab, const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens, const std::vector<llama_token> & forced_tokens,
int32_t budget, common_reasoning_budget_state initial_state);
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
return common_reasoning_budget_init(
return common_reasoning_budget_init_state(
ctx->vocab,
ctx->start_matcher.tokens,
ctx->end_matcher.tokens,
@ -191,13 +197,13 @@ static struct llama_sampler_i common_reasoning_budget_i = {
/* .backend_set_input = */ nullptr,
};
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state) {
static struct llama_sampler * common_reasoning_budget_init_state(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state) {
// promote COUNTING with budget <= 0 to FORCING
if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
initial_state = REASONING_BUDGET_FORCING;
@ -217,3 +223,41 @@ struct llama_sampler * common_reasoning_budget_init(
}
);
}
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
const std::vector<llama_token> & prefill_tokens) {
// Determine initial state from prefill: COUNTING if the prefill begins with
// the start sequence but does not also contain the end sequence after it.
common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE;
if (!prefill_tokens.empty() && !start_tokens.empty() &&
prefill_tokens.size() >= start_tokens.size() &&
std::equal(start_tokens.begin(), start_tokens.end(), prefill_tokens.begin())) {
initial_state = REASONING_BUDGET_COUNTING;
// If the end sequence also follows the start in the prefill, reasoning
// was opened and immediately closed — stay IDLE.
if (!end_tokens.empty() &&
prefill_tokens.size() >= start_tokens.size() + end_tokens.size()) {
auto end_start = prefill_tokens.end() - (ptrdiff_t) end_tokens.size();
if (end_start >= prefill_tokens.begin() + (ptrdiff_t) start_tokens.size() &&
std::equal(end_tokens.begin(), end_tokens.end(), end_start)) {
initial_state = REASONING_BUDGET_IDLE;
}
}
}
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
}
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state) {
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
}

View File

@ -24,14 +24,26 @@ enum common_reasoning_budget_state {
// DONE: passthrough forever
//
// Parameters:
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
// start_tokens - token sequence that activates counting
// end_tokens - token sequence for natural deactivation
// forced_tokens - token sequence forced when budget expires
// budget - max tokens allowed in the reasoning block
// initial_state - initial state of the sampler (e.g. IDLE or COUNTING)
// note: COUNTING with budget <= 0 is promoted to FORCING
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
// start_tokens - token sequence that activates counting
// end_tokens - token sequence for natural deactivation
// forced_tokens - token sequence forced when budget expires
// budget - max tokens allowed in the reasoning block
// prefill_tokens - tokens already present in the prompt (generation prompt);
// used to determine the initial state: COUNTING if they begin
// with start_tokens (but don't also end with end_tokens),
// IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING.
//
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
const std::vector<llama_token> & prefill_tokens = {});
// Variant that takes an explicit initial state (used by tests and clone).
// COUNTING with budget <= 0 is promoted to FORCING.
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,

View File

@ -1,13 +1,16 @@
#include "sampling.h"
#include "common.h"
#include "ggml.h"
#include "log.h"
#include "reasoning-budget.h"
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstring>
#include <unordered_map>
#include <vector>
// the ring buffer works similarly to std::deque, but with a fixed capacity
// TODO: deduplicate with llama-impl.h
@ -189,9 +192,10 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
std::vector<llama_sampler *> samplers;
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
const std::string & grammar_str = common_grammar_value(params.grammar);
if (grammar_str.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
grmr = llama_sampler_init_llg(vocab, "lark", grammar_str.c_str());
#else
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
@ -240,17 +244,46 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
trigger_patterns_c.push_back(regex.c_str());
}
if (!params.grammar.empty()) {
if (!grammar_str.empty()) {
if (params.grammar_lazy) {
grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
grmr = llama_sampler_init_grammar_lazy_patterns(vocab, grammar_str.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size());
} else {
grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
grmr = llama_sampler_init_grammar(vocab, grammar_str.c_str(), "root");
}
}
}
// Feed generation prompt tokens to the grammar sampler so it advances past
// tokens the template already placed in the prompt.
// Only applies to output-format and tool-call grammars; user-supplied grammars must not be prefilled.
std::vector<llama_token> prefill_tokens;
if (!params.generation_prompt.empty() && common_grammar_needs_prefill(params.grammar)) {
GGML_ASSERT(vocab != nullptr);
prefill_tokens = common_tokenize(vocab, params.generation_prompt, false, true);
if (!prefill_tokens.empty()) {
std::string first_token = common_token_to_piece(vocab, prefill_tokens[0], true);
if (std::isspace(first_token[0]) && !std::isspace(params.generation_prompt[0])) {
// Some tokenizers will add a space before the first special token, need to remove
prefill_tokens = std::vector<llama_token>(prefill_tokens.begin() + 1, prefill_tokens.end());
}
}
if (grmr) {
try {
for (const auto & token : prefill_tokens) {
llama_sampler_accept(grmr, token);
LOG_DBG("%s: accepted prefill token (%d)\n", __func__, token);
}
} catch (std::exception &e) {
LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__,
common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str());
throw e;
}
}
}
// reasoning budget sampler — added first so it can force tokens before other samplers
if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
samplers.push_back(common_reasoning_budget_init(
@ -259,7 +292,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
params.reasoning_budget_end,
params.reasoning_budget_forced,
params.reasoning_budget_tokens,
params.reasoning_budget_activate_immediately ? REASONING_BUDGET_COUNTING : REASONING_BUDGET_IDLE));
prefill_tokens));
}
if (params.has_logit_bias()) {

View File

@ -1062,6 +1062,10 @@ class TextModel(ModelBase):
self.gguf_writer.add_head_count_kv(n_head_kv)
logger.info(f"gguf: key-value head count = {n_head_kv}")
if self.hparams.get("is_causal") is False:
self.gguf_writer.add_causal_attention(False)
logger.info("gguf: causal attention = False")
# TODO: Handle "sliding_attention" similarly when models start implementing it
rope_params = self.rope_parameters.get("full_attention", self.rope_parameters)
if (rope_type := rope_params.get("rope_type")) is not None:

View File

@ -14,7 +14,7 @@ The unified auto-parser uses a pure differential, compositional approach (inspir
**Analysis + Parser Building in Two Steps**:
1. `autoparser::autoparser tmpl_analysis(tmpl)` — runs all differential comparisons and populates the analysis structs
2. `autoparser::peg_generator::generate_parser(tmpl, params, tmpl_analysis)` — uses the analysis to build a PEG parser and optional GBNF grammar
2. `autoparser::peg_generator::generate_parser(tmpl, generation_params, tmpl_analysis)` — uses the analysis to build a PEG parser and optional GBNF grammar
## Data Structures
@ -34,7 +34,7 @@ All structs are defined in [common/chat-auto-parser.h](common/chat-auto-parser.h
### `analyze_tools` and its sub-structs
- [common/chat-auto-parser.h:176-194](common/chat-auto-parser.h#L176-L194) — `tool_format_analysis`: `mode` enum, `section_start/end`, `per_call_start/end`, JSON field names (`function_field`, `name_field`, `args_field`, `id_field`, `gen_id_field`), and format flags (`fun_name_is_key`, `tools_array_wrapped`, `uses_python_dicts`)
- [common/chat-auto-parser.h:176-194](common/chat-auto-parser.h#L176-L194) — `tool_format_analysis`: `mode` enum, `section_start/end`, `per_call_start/end`, JSON field names (`function_field`, `name_field`, `args_field`, `id_field`, `gen_id_field`), and format flags (`fun_name_is_key`, `tools_array_wrapped`)
- [common/chat-auto-parser.h:196-200](common/chat-auto-parser.h#L196-L200) — `tool_function_analysis`: `name_prefix`, `name_suffix`, `close` markers around function names
- [common/chat-auto-parser.h:202-210](common/chat-auto-parser.h#L202-L210) — `tool_arguments_analysis`: `start/end` container markers, `name_prefix/suffix`, `value_prefix/suffix`, `separator`
- [common/chat-auto-parser.h:212-217](common/chat-auto-parser.h#L212-L217) — `tool_id_analysis`: `pos` enum, `prefix`/`suffix` markers around call ID values
@ -47,12 +47,21 @@ All structs are defined in [common/chat-auto-parser.h](common/chat-auto-parser.h
| Value | Description |
|-----------------|-----------------------------------------------------------------------------------|
| `NONE` | No reasoning markers detected |
| `TAG_BASED` | Standard tag-based: `<think>...</think>` |
| `DELIMITER` | Delimiter-based: reasoning ends at a delimiter (e.g., `[BEGIN FINAL RESPONSE]`) |
| `FORCED_OPEN` | Template ends with open reasoning tag when `enable_thinking=true` |
| `FORCED_CLOSED` | `enable_thinking=false` emits both tags; `enable_thinking=true` emits only start |
| `TAG_BASED` | Tag-based: `<think>...</think>` (start can be empty for delimiter-style formats) |
| `TOOLS_ONLY` | Reasoning only appears in tool call responses, not plain content |
**Generation Prompt & Reasoning Prefill**: Computed in `common_chat_templates_apply_jinja` before invoking either the specialized handlers or the auto-parser, by rendering the template twice — once with `add_generation_prompt=false` and once with `add_generation_prompt=true` — and storing the diff suffix as `generation_params::generation_prompt`. This string is propagated into `common_chat_params::generation_prompt` and `common_chat_parser_params::generation_prompt`.
The generation prompt is prepended to model output before PEG parsing via `wrap_for_generation_prompt()`. The portion *before* the reasoning start marker (if any) is prepended as a literal to ensure any boilerplate added by the template is consumed. The full string is also fed to the grammar sampler via `llama_sampler_accept` (stored in `common_params_sampling::grammar_prefill`), advancing the grammar past tokens already in the prompt. It is used to determine the reasoning budget sampler's initial state — COUNTING if the prefill tokens begin with the reasoning start sequence (but don't also contain the end sequence), IDLE otherwise.
**`grammar_prefill`** (`common_params_sampling`): The generation prompt string tokenized and accepted by the grammar sampler at init time. Only applied when `grammar_external` is false (i.e., the grammar was not set explicitly by the user).
Three outcomes for reasoning-prefill handling (in `generate_parser()`):
1. **Start+end in generation prompt** (e.g. `<think></think>\n`): the parser sees reasoning as opened and immediately closed; whitespace-only reasoning content is discarded.
2. **Only start in generation prompt** (e.g. `<think>\n`): the parser sees reasoning as already open.
3. **Start marker present but not at the end** (e.g. Apriel's `<|begin_assistant|>` followed by boilerplate): the marker is a template artifact; the start literal is cleared so reasoning uses delimiter-style (end-only). For templates that ignore `add_generation_prompt` (empty diff), the rendered `data.prompt` is used as fallback — but only for non-TOOLS_ONLY modes, since in TOOLS_ONLY the start tag is model-generated and may appear in prior conversation turns.
**`content_mode`**: How the template wraps assistant content.
| Value | Description |
@ -261,16 +270,16 @@ Text is segmentized into markers and non-marker fragments using `segmentize_mark
- Searches `diff.right` (output with reasoning) for the reasoning content needle
- Uses PEG parsers to find surrounding markers:
- If both pre/post markers found in `diff.right``TAG_BASED` (both tags visible in diff = no forced close)
- If both found but post marker only in the full output B → `FORCED_CLOSED`
- If only post marker found → `DELIMITER`
- If both pre/post markers found in `diff.right``TAG_BASED`
- If both found but post marker only in the full output B → `TAG_BASED` (template forces markers; handled via prefill)
- If only post marker found → `TAG_BASED` (delimiter-style, empty start)
- Sets `reasoning.start` and `reasoning.end`
**R2 — `compare_thinking_enabled()`**: Compares `enable_thinking=false` vs `true` with a generation prompt.
- Detects `FORCED_OPEN`: `enable_thinking=true` adds a non-empty marker at the end of the prompt (where model will start generating) — sets `reasoning.start`, mode = `FORCED_OPEN`
- Detects `FORCED_CLOSED`: `enable_thinking=false` produces both start+end markers; `enable_thinking=true` produces only start marker
- Handles the reverse case: if both start and end are still empty, looks for a single-segment diff on each side to extract both markers
- Detects template-added reasoning markers: `enable_thinking=true` appends a non-empty marker → sets `reasoning.start`, mode = `TAG_BASED`
- Handles the reverse case (`enable_thinking=false` appends the marker instead): extracts both start (from the preceding segment) and end markers; mode = `TAG_BASED`
- The reasoning prefill (markers added by the template) is later extracted in `common_chat_templates_apply_jinja` and prepended to model output before parsing
**R3 — `compare_reasoning_scope()`**: Compares assistant message with reasoning+text-content vs reasoning+tool-calls.
@ -343,7 +352,7 @@ Classification logic:
A workaround array in `common/chat-diff-analyzer.cpp` applies post-hoc patches after analysis. Each workaround is a lambda that inspects the template source and overrides analysis results. Current workarounds:
1. **Old Qwen/DeepSeek thinking templates** — source contains `content.split('</think>')`: sets `reasoning.mode = FORCED_OPEN` with `<think>`/`</think>` markers if no reasoning was detected
1. **Old Qwen/DeepSeek thinking templates** — source contains `content.split('</think>')` but not `<SPECIAL_12>`: sets `reasoning.mode = TAG_BASED` with `<think>`/`</think>` markers if no reasoning was detected
2. **Granite 3.3** — source contains specific "Write your thoughts" text: forces `TAG_BASED` reasoning with `<think>`/`</think>` and `WRAPPED_WITH_REASONING` content with `<response>`/`</response>`
3. **Cohere Command R+** — source contains `<|CHATBOT_TOKEN|>`: sets `ALWAYS_WRAPPED` content mode if no content start is already set
4. **Functionary 3.1** — source contains `set has_code_interpreter`: forces `PLAIN` content, specific `per_call_start/end`, clears preserved tokens to only keep Functionary-specific markers
@ -355,12 +364,13 @@ Each analyzer struct (`analyze_reasoning`, `analyze_content`, `analyze_tools`) i
#### Reasoning Parser (`analyze_reasoning::build_parser`)
| Mode | Parser |
|-----------------------------------|---------------------------------------------------------------------|
| Not extracting reasoning | `eps()` |
| `FORCED_OPEN` or `FORCED_CLOSED` | `reasoning(until(end)) + end` — opening tag was in the prompt |
| `TAG_BASED` or `TOOLS_ONLY` | `optional(start + reasoning(until(end)) + end)` |
| `DELIMITER` | `optional(reasoning(until(end)) + end)` — no start marker |
| Mode | Parser |
|-----------------------------------------------|---------------------------------------------------------------------------|
| Not extracting reasoning | `eps()` |
| `TAG_BASED` or `TOOLS_ONLY` (non-empty start) | `optional(start + reasoning(until(end)) + end + space())` |
| `TAG_BASED` or `TOOLS_ONLY` (empty start) | `optional(reasoning(until(end)) + end + space())` — delimiter-style |
Note: The start marker may be empty either because the analyzer detected delimiter-style reasoning, or because `generate_parser()` cleared a template artifact start marker (see Generation Prompt & Reasoning Prefill above). Whitespace-only reasoning content (e.g. from a `<think></think>` prefill) is discarded by the mapper.
#### Content Parser (`analyze_content::build_parser`)
@ -410,9 +420,7 @@ All three tool parsers return:
reasoning + optional(content(until(trigger_marker))) + tool_calls + end()
```
### Python Dict Format
When `format.uses_python_dicts` is true (detected when single-quoted strings appear in JSON argument context), `build_parser()` pre-registers a `json-string` rule that accepts both single-quoted and double-quoted strings. This is done before any `p.json()` call so all JSON parsing inherits the flexible rule.
Each returned parser is wrapped by `wrap_for_generation_prompt()`, which prepends a literal for any boilerplate prefix of the generation prompt (the portion before the reasoning start marker).
## Mapper
@ -421,22 +429,22 @@ When `format.uses_python_dicts` is true (detected when single-quoted strings app
- **Buffered arguments**: Before `tool_name` is known, argument text goes to `args_buffer`; once the name is set, the buffer is flushed to `current_tool->arguments`
- **`args_target()`**: Returns a reference to whichever destination is currently active (buffer or tool args), eliminating branching
- **`closing_quote_pending`**: Tracks whether a closing `"` needs to be appended when a string argument value is finalized (for schema-declared string types in tagged format)
- **Quote normalization**: Python-style quotes (`'key': 'value'`) are converted to JSON (`"key": "value"`)
- **Whitespace-only reasoning**: Reasoning content that consists entirely of whitespace (e.g. from a `<think></think>` prefill) is cleared so the message shows no reasoning
- **Brace auto-closing**: At tool close, unclosed `{` braces are closed automatically
## Files
| File | Purpose |
|-------------------------------------------|----------------------------------------------------------------------|
| `common/chat-auto-parser.h` | All analysis structs, enums, `autoparser`, `peg_generator`, `templates_params` |
| `common/chat-auto-parser-generator.cpp` | Parser generator: `generate_parser()` and `build_parser()` methods |
| `common/chat-diff-analyzer.cpp` | Differential analysis implementation and workarounds |
| `common/chat-auto-parser-helpers.h/cpp` | `calculate_diff_split()`, `segmentize_markers()`, |
| | `compare_variants()`, string helpers |
| `common/chat-peg-parser.h/cpp` | `common_chat_peg_builder`, `common_chat_peg_mapper`, and helpers |
| `common/chat.cpp` | Entry point: `common_chat_templates_apply_jinja()` |
| `tools/parser/debug-template-parser.cpp` | Debug tool for template analysis |
| `tools/parser/template-analysis.cpp` | Template analysis tool |
| File | Purpose |
|-------------------------------------------|---------------------------------------------------------------------------------|
| `common/chat-auto-parser.h` | All analysis structs, enums, `autoparser`, `peg_generator`, `generation_params` |
| `common/chat-auto-parser-generator.cpp` | Parser generator: `generate_parser()` and `build_parser()` methods |
| `common/chat-diff-analyzer.cpp` | Differential analysis implementation and workarounds |
| `common/chat-auto-parser-helpers.h/cpp` | `calculate_diff_split()`, `segmentize_markers()`, `compare_variants()`, |
| | `wrap_for_generation_prompt()`, string helpers |
| `common/chat-peg-parser.h/cpp` | `common_chat_peg_builder`, `common_chat_peg_mapper`, and helpers |
| `common/chat.cpp` | Entry point: `common_chat_templates_apply_jinja()` |
| `tools/parser/debug-template-parser.cpp` | Debug tool for template analysis |
| `tools/parser/template-analysis.cpp` | Template analysis tool |
## Testing & Debugging
@ -516,10 +524,10 @@ To support a new template format:
## Edge Cases and Quirks
1. **Forced Thinking**: When `enable_thinking=true` and the model prompt ends with an open reasoning tag (e.g., `<think>`), the parser enters forced thinking mode and immediately expects reasoning content without waiting for a start marker.
1. **Generation Prompt & Reasoning Prefill**: The generation prompt is extracted by diffing `add_generation_prompt=false` vs `true` in `common_chat_templates_apply_jinja`, so it contains exactly what the template appends — avoiding false positives from prior conversation turns.
2. **Per-Call vs Per-Section Markers**: Some templates wrap each tool call individually (`per_call_start/end`); others wrap the entire section (`section_start/end`). T2 (`check_per_call_markers()`) disambiguates by checking if the second call in a two-call output starts with the section marker.
3. **Python Dict Format**: The Seed template family uses single-quoted JSON (`'key': 'value'`). The `uses_python_dicts` flag causes the PEG builder to register a flexible `json-string` rule accepting both quote styles before any JSON rules are built.
4. **Tag Boundary Fixing**: `calculate_diff_split()` iteratively adjusts prefix/suffix boundaries to avoid splitting `<tag>` or `[marker]` tokens, ensuring clean extraction.
5. **Call ID Side Effects**: When a call ID is detected, `per_call_end` may have been incorrectly set to include the call ID suffix. T7 clears `per_call_end` in this case.
6. **Tool Analysis Gating**: `analyze_tools` is only constructed (and all tool analysis phases run) when `jinja_caps.supports_tool_calls` is true. Within tool analysis, `check_per_call_markers()` (T2) only runs if `jinja_caps.supports_parallel_tool_calls`.
7. **`analyze_arguments()` Gating**: Within tool analysis, A1 and A2 (argument name/value marker extraction) only run for `TAG_WITH_TAGGED` format. `extract_argument_separator()` and `extract_args_markers()` run for all non-`JSON_NATIVE` formats.
3. **Tag Boundary Fixing**: `calculate_diff_split()` iteratively adjusts prefix/suffix boundaries to avoid splitting `<tag>` or `[marker]` tokens, ensuring clean extraction.
4. **Call ID Side Effects**: When a call ID is detected, `per_call_end` may have been incorrectly set to include the call ID suffix. T7 clears `per_call_end` in this case.
5. **Tool Analysis Gating**: `analyze_tools` is only constructed (and all tool analysis phases run) when `jinja_caps.supports_tool_calls` is true. Within tool analysis, `check_per_call_markers()` (T2) only runs if `jinja_caps.supports_parallel_tool_calls`.
6. **`analyze_arguments()` Gating**: Within tool analysis, A1 and A2 (argument name/value marker extraction) only run for `TAG_WITH_TAGGED` format. `extract_argument_separator()` and `extract_args_markers()` run for all non-`JSON_NATIVE` formats.
7. **Undetected Tool Format**: If `analyze_tools` concludes tool calling is supported but cannot determine the format, `build_parser()` logs an error and returns `eps()` (graceful degradation) rather than aborting.

View File

@ -28,6 +28,9 @@ Additionally, there the following images, similar to the above:
- `ghcr.io/ggml-org/llama.cpp:full-vulkan`: Same as `full` but compiled with Vulkan support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:light-vulkan`: Same as `light` but compiled with Vulkan support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:server-vulkan`: Same as `server` but compiled with Vulkan support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:full-openvino`: Same as `full` but compiled with OpenVino support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:light-openvino`: Same as `light` but compiled with OpenVino support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:server-openvino`: Same as `server` but compiled with OpenVino support. (platforms: `linux/amd64`)
The GPU enabled images are not currently tested by CI beyond being built. They are not built with any variation from the ones in the Dockerfiles defined in [.devops/](../.devops/) and the GitHub Action defined in [.github/workflows/docker.yml](../.github/workflows/docker.yml). If you need different settings (for example, a different CUDA, ROCm or MUSA library, you'll need to build the images locally for now).

View File

@ -12,9 +12,9 @@ Legend:
- 🟡 Partially supported by this backend
- ❌ Not supported by this backend
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | WebGPU | ZenDNN | zDNN |
| Operation | BLAS | CANN | CPU | CUDA | MTL | OpenCL | SYCL | Vulkan | WebGPU | ZenDNN | zDNN |
|-----------|------|------|------|------|------|------|------|------|------|------|------|
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| ABS | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
@ -23,63 +23,63 @@ Legend:
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| DIAG | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| DIAG | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| GELU | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | ❌ | ❌ |
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | ❌ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
| NEG | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| OUT_PROD | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | 🟡 |
| PAD | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| POOL_1D | ❌ | ❌ | ✅ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| POOL_1D | ❌ | ❌ | ✅ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
@ -91,31 +91,31 @@ Legend:
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SET | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
| SET_ROWS | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ❌ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SUM | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1544,8 +1544,8 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx,
end = 2 * ((n_head - 1) - n_head_log2) + 1;
step = 2;
count = n_head - n_head_log2;
aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), m1, count, start, end + 1, step,
dtype);
aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * ggml_type_size(dtype), m1, count, start, end + 1,
step, dtype);
}
}
@ -1788,9 +1788,11 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0]; // src
ggml_tensor * src1 = dst->src[1]; // index
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16
|| dst->type == GGML_TYPE_BF16);
switch (src0->type) {
case GGML_TYPE_BF16:
case GGML_TYPE_F16:
case GGML_TYPE_F32:
if (src0->type == dst->type) {
@ -1881,6 +1883,7 @@ void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
break;
}
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
{
acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t));
@ -1891,7 +1894,7 @@ void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
}
acl_tensor_ptr src_trans_tensor = ggml_cann_create_tensor(
src_trans_buffer, ACL_FLOAT16, ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type));
aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1,
dst->type);
@ -1965,7 +1968,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor *
// Only check env once.
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
if (weight_to_nz && is_matmul_weight(weight)) {
if (weight_to_nz && weight->type != GGML_TYPE_BF16 && is_matmul_weight(weight)) {
acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
} else {
acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
@ -2146,6 +2149,9 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
switch (type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
#ifndef ASCEND_310P
case GGML_TYPE_BF16:
#endif
ggml_cann_mat_mul_fp(ctx, dst);
break;
case GGML_TYPE_Q4_0:
@ -2943,6 +2949,27 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
// Rotate full tensor (no tail), using trans tensors
GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(), acl_cos_reshape_tensor.get(),
acl_sin_reshape_tensor.get(), acl_mode, acl_dst_trans_tensor.get());
} else if (src0->data == dst->data && !ggml_is_contiguous(src0)) {
// In-place on non-contiguous tensor: RotaryPositionEmbedding cannot safely
// read and write the same non-contiguous buffer. Use contiguous temporaries.
size_t contiguous_nb[GGML_MAX_DIMS];
contiguous_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
contiguous_nb[i] = contiguous_nb[i - 1] * src0->ne[i - 1];
}
int64_t total_elements = ggml_nelements(src0);
ggml_cann_pool_alloc inplace_src_alloc(ctx.pool(), total_elements * sizeof(float));
ggml_cann_pool_alloc inplace_dst_alloc(ctx.pool(), total_elements * sizeof(float));
acl_tensor_ptr acl_src_contig = ggml_cann_create_tensor(inplace_src_alloc.get(), ACL_FLOAT, sizeof(float),
src0->ne, contiguous_nb, GGML_MAX_DIMS);
acl_tensor_ptr acl_dst_contig = ggml_cann_create_tensor(inplace_dst_alloc.get(), ACL_FLOAT, sizeof(float),
dst->ne, contiguous_nb, GGML_MAX_DIMS);
cann_copy(ctx, acl_src.get(), acl_src_contig.get());
GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_contig.get(), acl_cos_reshape_tensor.get(),
acl_sin_reshape_tensor.get(), acl_mode, acl_dst_contig.get());
cann_copy(ctx, acl_dst_contig.get(), acl_dst.get());
} else {
// Rotate full tensor (no tail), using original tensors
GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(),
@ -3599,6 +3626,44 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
acl_k_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne, src1_bsnd_nb, GGML_MAX_DIMS);
acl_v_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne, src2_bsnd_nb, GGML_MAX_DIMS);
// Step 2.5: Pad Q, K, V along head dimension if D is not a multiple of 16
// (required by FusedInferAttentionScoreV2)
const int64_t D = src0->ne[0];
const int64_t D_padded = GGML_PAD(D, 16);
const bool needs_padding = (D != D_padded);
ggml_cann_pool_alloc q_pad_allocator(ctx.pool());
ggml_cann_pool_alloc k_pad_allocator(ctx.pool());
ggml_cann_pool_alloc v_pad_allocator(ctx.pool());
if (needs_padding) {
int64_t paddings[] = { 0, D_padded - D, 0, 0, 0, 0, 0, 0 };
auto pad_fa_tensor = [&](acl_tensor_ptr & tensor, const int64_t * bsnd_ne,
ggml_cann_pool_alloc & allocator) {
int64_t pad_ne[GGML_MAX_DIMS] = { D_padded, bsnd_ne[1], bsnd_ne[2], bsnd_ne[3] };
size_t pad_nb[GGML_MAX_DIMS];
pad_nb[0] = faElemSize;
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
pad_nb[i] = pad_nb[i - 1] * pad_ne[i - 1];
}
int64_t nelements = pad_ne[0] * pad_ne[1] * pad_ne[2] * pad_ne[3];
void * buffer = allocator.alloc(nelements * faElemSize);
acl_tensor_ptr padded =
ggml_cann_create_tensor(buffer, faDataType, faElemSize, pad_ne, pad_nb, GGML_MAX_DIMS);
aclnn_pad(ctx, tensor.get(), padded.get(), paddings);
tensor = std::move(padded);
};
pad_fa_tensor(acl_q_tensor, src0_bsnd_ne, q_pad_allocator);
pad_fa_tensor(acl_k_tensor, src1_bsnd_ne, k_pad_allocator);
pad_fa_tensor(acl_v_tensor, src2_bsnd_ne, v_pad_allocator);
src0_bsnd_ne[0] = D_padded;
src1_bsnd_ne[0] = D_padded;
src2_bsnd_ne[0] = D_padded;
}
// Step 3: create the PSEShift tensor if needed
// this tensor is considered as mask (f16) in the llama.cpp
acl_tensor_ptr bcast_pse_tensor;
@ -3688,17 +3753,16 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
acl_tensor_ptr fa_dst_tensor;
acl_tensor_ptr acl_dst_tensor;
ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
if (dst->type == GGML_TYPE_F32) {
void * out_f16_buffer = out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
if (dst->type == GGML_TYPE_F32 || needs_padding) {
int64_t * out_f16_ne = src0_bsnd_ne;
size_t out_f16_nb[GGML_MAX_DIMS];
out_f16_nb[0] = faElemSize;
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
}
int64_t out_nelements = out_f16_ne[0] * out_f16_ne[1] * out_f16_ne[2] * out_f16_ne[3];
void * out_f16_buffer = out_f16_allocator.alloc(out_nelements * faElemSize);
fa_dst_tensor =
ggml_cann_create_tensor(out_f16_buffer, faDataType, faElemSize, out_f16_ne, out_f16_nb, GGML_MAX_DIMS);
@ -3730,8 +3794,33 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
nullptr // softmaxLse
);
if (dst->type == GGML_TYPE_F32) {
// Step 6: post-processing, permute and cast to f32
// Step 6: post-processing — slice padded output and/or cast to f32
if (needs_padding) {
ggml_cann_pool_alloc sliced_f16_allocator(ctx.pool());
if (dst->type == GGML_TYPE_F32) {
int64_t sliced_ne[GGML_MAX_DIMS] = { D, src0_bsnd_ne[1], src0_bsnd_ne[2], src0_bsnd_ne[3] };
size_t sliced_nb[GGML_MAX_DIMS];
sliced_nb[0] = faElemSize;
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
sliced_nb[i] = sliced_nb[i - 1] * sliced_ne[i - 1];
}
int64_t sliced_nelements = sliced_ne[0] * sliced_ne[1] * sliced_ne[2] * sliced_ne[3];
void * sliced_buffer = sliced_f16_allocator.alloc(sliced_nelements * faElemSize);
acl_tensor_ptr sliced_f16_tensor = ggml_cann_create_tensor(sliced_buffer, faDataType, faElemSize,
sliced_ne, sliced_nb, GGML_MAX_DIMS);
GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(),
(int64_t) -1, (int64_t) 0, D, (int64_t) 1, sliced_f16_tensor.get());
acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
aclnn_cast(ctx, sliced_f16_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type));
} else {
acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(),
(int64_t) -1, (int64_t) 0, D, (int64_t) 1, acl_dst_tensor.get());
}
} else if (dst->type == GGML_TYPE_F32) {
acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
aclnn_cast(ctx, fa_dst_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type));
}

View File

@ -1234,7 +1234,8 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer,
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
if (!need_transform(tensor->type)) {
ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
if (weight_to_nz && tensor->type != GGML_TYPE_BF16
&& is_matmul_weight((const ggml_tensor *) tensor)) {
GGML_ASSERT(tensor->ne[2] == 1);
GGML_ASSERT(tensor->ne[3] == 1);
weight_format_to_nz(tensor, offset, ctx->device);
@ -1443,7 +1444,8 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_t
if (ne0 % MATRIX_ROW_PADDING != 0) {
size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
}
} else if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
} else if (weight_to_nz && tensor->type != GGML_TYPE_BF16
&& is_matmul_weight((const ggml_tensor *) tensor)) {
// NZ format weight are not support quantized yet.
// If ND tensor transform to NZ, size may changed.
int64_t shape[] = { tensor->ne[1], tensor->ne[0] };
@ -2283,6 +2285,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
case GGML_OP_MUL_MAT:
{
switch (op->src[0]->type) {
#ifndef ASCEND_310P
case GGML_TYPE_BF16:
#endif
case GGML_TYPE_F16:
case GGML_TYPE_F32:
return true;
@ -2320,6 +2325,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
switch (op->src[0]->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
#ifndef ASCEND_310P
case GGML_TYPE_BF16:
#endif
case GGML_TYPE_Q8_0:
return true;
default:
@ -2332,6 +2340,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
switch (op->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
#ifndef ASCEND_310P
case GGML_TYPE_BF16:
#endif
return true;
default:
return false;
@ -2341,20 +2352,30 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
case GGML_OP_CPY:
{
ggml_tensor * src = op->src[0];
#ifdef ASCEND_310P
if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
(src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16)) {
// only support F32 and F16.
// only support F32 and F16 on 310P.
return false;
}
#else
if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_BF16) ||
(src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16 && src->type != GGML_TYPE_BF16)) {
// only support F32, F16 and BF16.
return false;
}
#endif
return true;
}
break;
case GGML_OP_CONT:
{
// TODO: support GGML_TYPE_BF16
switch (op->src[0]->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
#ifndef ASCEND_310P
case GGML_TYPE_BF16:
#endif
return true;
default:
return false;
@ -2503,10 +2524,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
// different head sizes of K and V are not supported yet
return false;
}
if (op->src[0]->ne[0] % 16 != 0) {
// TODO: padding to support
return false;
}
float logitSoftcap = 0.0f;
memcpy(&logitSoftcap, (const float *) (op->op_params) + 2, sizeof(float));
if (logitSoftcap != 0.0f) {

View File

@ -570,24 +570,36 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
set(KLEIDIAI_ARCHIVE_MD5 "54049037570ab0ee0a0d126b2ba5ece1")
if (POLICY CMP0135)
cmake_policy(SET CMP0135 NEW)
set(KLEIDIAI_FETCH_ARGS
URL ${KLEIDIAI_DOWNLOAD_URL}
URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5}
)
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
list(APPEND KLEIDIAI_FETCH_ARGS DOWNLOAD_EXTRACT_TIMESTAMP NEW)
endif()
# TODO: Use FetchContent_MakeAvailable with EXCLUDE_FROM_ALL after bumping minimum CMake version to 3.28+
# Using FetchContent_Populate instead to avoid EXCLUDE_FROM_ALL which requires CMake 3.28
FetchContent_Declare(KleidiAI_Download
URL ${KLEIDIAI_DOWNLOAD_URL}
DOWNLOAD_EXTRACT_TIMESTAMP NEW
URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.28")
FetchContent_Declare(KleidiAI_Download
${KLEIDIAI_FETCH_ARGS}
EXCLUDE_FROM_ALL
)
FetchContent_GetProperties(KleidiAI_Download
SOURCE_DIR KLEIDIAI_SRC
POPULATED KLEIDIAI_POPULATED)
if (NOT KLEIDIAI_POPULATED)
FetchContent_Populate(KleidiAI_Download)
FetchContent_MakeAvailable(KleidiAI_Download)
FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC)
else()
FetchContent_Declare(KleidiAI_Download
${KLEIDIAI_FETCH_ARGS}
)
FetchContent_GetProperties(KleidiAI_Download
SOURCE_DIR KLEIDIAI_SRC
POPULATED KLEIDIAI_POPULATED
)
if (NOT KLEIDIAI_POPULATED)
FetchContent_Populate(KleidiAI_Download)
FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC)
endif()
endif()
add_compile_definitions(GGML_USE_CPU_KLEIDIAI)

View File

@ -3194,6 +3194,7 @@ class tinyBLAS_PPC {
private:
__attribute__((always_inline))
inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
vec_t vec_C[4];
__builtin_mma_disassemble_acc(vec_C, ACC);
@ -3204,6 +3205,7 @@ class tinyBLAS_PPC {
}
}
__attribute__((always_inline))
inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
vec_t vec_C[4];
__builtin_mma_disassemble_acc(vec_C, ACC);

View File

@ -45,6 +45,7 @@ static int opt_verbose = 0;
static int opt_profile = 0;
static int opt_hostbuf = 1; // hostbuf ON by default
static int opt_experimental = 0;
static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only
// Enable all stages by default
static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE;
@ -1693,7 +1694,7 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
// Start the DSP-side service. We need to pass the queue ID to the
// DSP in a FastRPC call; the DSP side will import the queue and start
// listening for packets in a callback.
err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx);
err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx);
if (err != 0) {
GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err);
throw std::runtime_error("ggml-hex: iface start failed (see log for details)");
@ -3372,6 +3373,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
const char * str_etm = getenv("GGML_HEXAGON_ETM");
const char * str_nhvx = getenv("GGML_HEXAGON_NHVX");
const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX");
const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
const char * str_arch = getenv("GGML_HEXAGON_ARCH");
@ -3381,8 +3383,9 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask;
opt_opsync = str_opsync ? atoi(str_opsync) : 0;
opt_profile = str_profile ? atoi(str_profile) : 0;
opt_etm = str_etm ? atoi(str_etm) : 0;
opt_etm = str_etm ? atoi(str_etm) : 0;
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx;
opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev;
if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) {

View File

@ -40,6 +40,24 @@ target_compile_definitions(${HTP_LIB} PRIVATE
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
# HMX acceleration: available on v73+ architectures
set(HTP_HMX_VERSIONS v73 v75 v79 v81)
list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
if (_hmx_idx GREATER_EQUAL 0)
target_sources(${HTP_LIB} PRIVATE
hmx-matmul-ops.c
)
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
set_source_files_properties(
hmx-matmul-ops.c
PROPERTIES COMPILE_OPTIONS "-mhmx"
)
target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1)
endif()
build_idl(htp_iface.idl ${HTP_LIB})
set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON)

View File

@ -175,6 +175,86 @@ static inline uint32_t dma_queue_capacity(dma_queue * q) {
return q->capacity;
}
// ---------------------------------------------------------------------------
// Overflow-safe DMA push: all UDMA type1 descriptor fields (roiwidth,
// roiheight, srcstride, dststride) are 16-bit, max 65535. This helper
// transparently handles values that exceed the 16-bit limit and submits
// chained DMA transtions.
//
// Case 1 (fast path): all params fit in 16 bits -> direct dma_queue_push.
// Case 2 (contiguous block): width == srcstride == dststride. Reshape the
// flat transfer into a 2D descriptor with sub_width <= 65535. Produces a
// single descriptor, preserving async DMA behavior.
// Case 3 (stride overflow): srcstride or dststride > 65535. Issue rows
// one at a time. The first N-1 rows are pushed+popped synchronously;
// the last row is left async so the caller can pop it.
// ---------------------------------------------------------------------------
#define UDMA_MAX_FIELD_VAL 65535u
static inline bool dma_queue_push_chained(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t width, size_t nrows) {
// Fast path: everything fits in 16 bits.
if (__builtin_expect(
width <= UDMA_MAX_FIELD_VAL &&
nrows <= UDMA_MAX_FIELD_VAL &&
src_stride <= UDMA_MAX_FIELD_VAL &&
dst_stride <= UDMA_MAX_FIELD_VAL, 1)) {
return dma_queue_push(q, dptr, dst_stride, src_stride, width, nrows);
}
// Case 2: contiguous block (width == src_stride == dst_stride).
// Reshape total bytes into sub_width * sub_nrows where sub_width <= 65535.
if (width == src_stride && width == dst_stride) {
size_t total = width * nrows;
// Pick the largest 128-byte-aligned sub_width that divides total evenly.
size_t sub_width = UDMA_MAX_FIELD_VAL & ~(size_t)127; // 65408
while (sub_width > 0 && total % sub_width != 0) {
sub_width -= 128;
}
if (sub_width == 0) {
// Fallback: use original width (must fit) with adjusted nrows.
// This shouldn't happen for 128-aligned DMA sizes.
sub_width = width;
}
size_t sub_nrows = total / sub_width;
// Handle sub_nrows > 65535 by issuing chunked descriptors.
const uint8_t *src = (const uint8_t *)dptr.src;
uint8_t *dst = (uint8_t *)dptr.dst;
size_t rows_done = 0;
while (rows_done < sub_nrows) {
size_t chunk = sub_nrows - rows_done;
if (chunk > UDMA_MAX_FIELD_VAL) chunk = UDMA_MAX_FIELD_VAL;
dma_ptr p = dma_make_ptr(dst + rows_done * sub_width, src + rows_done * sub_width);
if (!dma_queue_push(q, p, sub_width, sub_width, sub_width, chunk))
return false;
rows_done += chunk;
// Complete all chunks without waiting except the last one, so the
// caller's single dma_queue_pop drains the final descriptor.
if (rows_done < sub_nrows)
dma_queue_pop_nowait(q);
}
return true;
}
// Case 3: stride overflow — fall back to row-by-row.
{
const uint8_t *src = (const uint8_t *)dptr.src;
uint8_t *dst = (uint8_t *)dptr.dst;
for (size_t r = 0; r < nrows; ++r) {
dma_ptr p = dma_make_ptr(dst + r * dst_stride,
src + r * src_stride);
if (!dma_queue_push(q, p, 0, 0, width, 1))
return false;
if (r + 1 < nrows)
dma_queue_pop_nowait(q);
}
return true;
}
}
#ifdef __cplusplus
} // extern "C"
#endif

View File

@ -29,10 +29,22 @@ static inline uint64_t hex_get_pktcnt() {
return pktcnt;
}
static inline int32_t hex_is_aligned(void * addr, uint32_t align) {
static inline size_t hmx_ceil_div(size_t num, size_t den) {
return (num + den - 1) / den;
}
static inline int32_t hex_is_aligned(const void * addr, uint32_t align) {
return ((size_t) addr & (align - 1)) == 0;
}
static inline size_t hex_align_up(size_t v, size_t align) {
return hmx_ceil_div(v, align) * align;
}
static inline size_t hex_align_down(size_t v, size_t align) {
return (v / align) * align;
}
static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
uint32_t left_off = (size_t) addr & (chunk_size - 1);
uint32_t right_off = left_off + n;
@ -43,6 +55,14 @@ static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
return m * ((n + m - 1) / m);
}
static inline size_t hex_smin(size_t a, size_t b) {
return a < b ? a : b;
}
static inline size_t hex_smax(size_t a, size_t b) {
return a > b ? a : b;
}
static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
Q6_l2fetch_AP((void *) p, control);

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,72 @@
// HMX operation entry-point declarations.
// Ported from htp-ops-lib/include/dsp/ops.h (renamed, benchmark kernels removed). (https://github.com/haozixu/htp-ops-lib)
#ifndef HMX_OPS_H
#define HMX_OPS_H
#include <stddef.h>
#include <stdint.h>
#ifndef restrict
# define restrict __restrict
#endif
#ifdef __cplusplus
extern "C" {
#endif
struct htp_context; // forward declaration
typedef struct {
float *dst;
const float *activation;
const __fp16 *permuted_weight;
int m;
int k;
int n;
int act_stride;
int weight_stride;
int dst_stride;
int ne02;
int ne03;
int ne12;
int ne13;
size_t src0_nb2;
size_t src0_nb3;
size_t src1_nb2;
size_t src1_nb3;
size_t dst_nb2;
size_t dst_nb3;
} hmx_matmul_w16a32_batched_params_t;
// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output
// act_stride: activation row stride in elements (= k for contiguous, or
// nb[1]/sizeof(float) for permuted tensors like attention Q).
// weight_stride: weight row stride in elements (= k for compact weights, or
// nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK).
int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx,
float *restrict dst,
const float *activation,
const __fp16 *permuted_weight,
int m, int k, int n,
int act_stride,
int weight_stride);
// Batched F16 wrapper over hmx_mat_mul_permuted_w16a32.
// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3.
int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx,
const hmx_matmul_w16a32_batched_params_t *params);
// HMX matrix multiplication — tile-permuted quantised weights (Q4_0/Q8_0/IQ4_NL)
int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx,
float *restrict dst,
const float *activation,
const uint8_t *permuted_weight,
int m, int k, int n,
int weight_type);
#ifdef __cplusplus
}
#endif
#endif // HMX_OPS_H

View File

@ -0,0 +1,34 @@
// Conditional fine-grained profiling macros for HMX operations.
//
// Define ENABLE_PROFILE_TIMERS (via compiler flag or before including this
// header) to instrument sub-operation latencies with HAP qtimer. When the
// macro is not defined the TIMER_* helpers expand to nothing so there is zero
// overhead.
//
// Usage:
// TIMER_DEFINE(my_phase); // declare accumulator variable
// TIMER_START(my_phase); // snapshot start time
// ... work ...
// TIMER_STOP(my_phase); // accumulate elapsed ticks
// FARF(ALWAYS, "my_phase: %lld us", TIMER_US(my_phase));
#ifndef HMX_PROFILE_H
#define HMX_PROFILE_H
#include <HAP_perf.h>
// #define ENABLE_PROFILE_TIMERS
#if defined(ENABLE_PROFILE_TIMERS)
# define TIMER_DEFINE(name) int64_t name##_ticks = 0
# define TIMER_START(name) int64_t name##_t0 = HAP_perf_get_qtimer_count()
# define TIMER_STOP(name) name##_ticks += HAP_perf_get_qtimer_count() - name##_t0
# define TIMER_US(name) HAP_perf_qtimer_count_to_us(name##_ticks)
#else
# define TIMER_DEFINE(name)
# define TIMER_START(name)
# define TIMER_STOP(name)
# define TIMER_US(name) 0LL
#endif
#endif // HMX_PROFILE_H

View File

@ -0,0 +1,88 @@
// HMX tile-level inline helpers (FP16 32x32 tile operations).
// Ported from htp-ops-lib/include/dsp/hmx_utils.h. (https://github.com/haozixu/htp-ops-lib)
#ifndef HMX_UTILS_H
#define HMX_UTILS_H
#include <hexagon_types.h>
#include <stddef.h>
#define HMX_FP16_TILE_N_ROWS 32
#define HMX_FP16_TILE_N_COLS 32
#define HMX_FP16_TILE_N_ELMS 1024
#define HMX_FP16_TILE_SIZE 2048
#define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline))
static HMX_INLINE_ALWAYS void hmx_set_output_scales(const void *scales) {
asm volatile("bias = mxmem2(%0)" :: "r"(scales));
}
// Initialise aligned 256-byte area with scale vector + zero padding.
static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) {
HVX_Vector *pv = (HVX_Vector *)out_scales;
*pv++ = v_scale;
*pv = Q6_V_vzero();
}
// Load multiple contiguous tiles with :deep streaming.
// Rt = total region size - 1; the hardware streams through [Rs, Rs + Rt].
// IMPORTANT: the tile region [Rs, Rs + Rt] must NOT cross a VTCM 4 MB bank
// boundary, otherwise the mxmem instruction will raise a precise bus error.
// Callers must ensure their VTCM layout satisfies this constraint.
static HMX_INLINE_ALWAYS void hmx_load_tiles_fp16(const __fp16 *row_tiles,
const __fp16 *col_tiles,
size_t n_tiles) {
size_t limit = n_tiles * HMX_FP16_TILE_SIZE - 1;
asm volatile(
"{ activation.hf = mxmem(%0, %1):deep\n"
"weight.hf = mxmem(%2, %3) }\n"
:: "r"(row_tiles), "r"(limit), "r"(col_tiles), "r"(limit)
: "memory");
}
// Load a single activation+weight tile pair (no :deep streaming).
// Rt defines the accessible region [Rs, Rs+Rt]. Following the reference formula
// (limit = n_tiles * HMX_FP16_TILE_SIZE - 1), for a single tile Rt = 2047.
// The original code used Rt=0x7FFF (32 KB region); when dynamic VTCM allocation
// places a tile near a 4 MB bank boundary, the oversized region crosses it and
// triggers a precise bus error (0x2601). Rt=2047 confines accesses to exactly
// one 2048-byte tile while covering all 16 HVX vectors (offsets 0..2047).
static HMX_INLINE_ALWAYS void hmx_load_tile_pair_fp16(const __fp16 *act_tile,
const __fp16 *wt_tile) {
asm volatile(
"{ activation.hf = mxmem(%0, %1)\n"
"weight.hf = mxmem(%2, %3) }\n"
:: "r"(act_tile), "r"(2047),
"r"(wt_tile), "r"(2047)
: "memory");
}
static HMX_INLINE_ALWAYS void hmx_consume_accumulator_fp16(__fp16 *out) {
// Use the combined convert-and-store instruction (matches the reference
// Q6_mxmem_AR_after_hf intrinsic). The previous two-instruction sequence
// "cvt.hf = acc(2); mxmem = cvt" used an undocumented Rs=2 parameter.
asm volatile(
"mxmem(%0, %1):after.hf = acc\n"
:: "r"(out), "r"(0)
: "memory");
}
// Compute inner product of two vectors of tiles and store result.
static HMX_INLINE_ALWAYS void hmx_dot_fp16(__fp16 *out,
const __fp16 *row_tiles,
const __fp16 *col_tiles,
size_t n_tiles) {
hmx_load_tiles_fp16(row_tiles, col_tiles, n_tiles);
hmx_consume_accumulator_fp16(out);
}
// --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) ---
static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {
uint8_t *p = *vtcm_ptr;
*vtcm_ptr += size;
return p;
}
#endif // HMX_UTILS_H

View File

@ -30,6 +30,12 @@ struct htp_context {
atomic_bool vtcm_needs_release;
uint32_t opmask;
// HMX acceleration fields (v73+, enabled by compile-time HTP_HAS_HMX)
#ifdef HTP_HAS_HMX
int hmx_enabled; // Runtime flag: HMX initialisation succeeded
size_t vtcm_scratch_size; // Usable dynamic scratch (vtcm_size minus tail reservation)
#endif
};
#endif /* HTP_CTX_H */

View File

@ -32,13 +32,14 @@ enum htp_status {
// Duplicated here because we can't include full ggml.h in the htp build.
// We have some static_asserts in the cpp code to ensure things are in sync.
enum htp_data_type {
HTP_TYPE_F32 = 0,
HTP_TYPE_F16 = 1,
HTP_TYPE_Q4_0 = 2,
HTP_TYPE_Q8_0 = 8,
HTP_TYPE_I32 = 26,
HTP_TYPE_I64 = 27,
HTP_TYPE_MXFP4 = 39,
HTP_TYPE_F32 = 0,
HTP_TYPE_F16 = 1,
HTP_TYPE_Q4_0 = 2,
HTP_TYPE_Q8_0 = 8,
HTP_TYPE_IQ4_NL = 20,
HTP_TYPE_I32 = 26,
HTP_TYPE_I64 = 27,
HTP_TYPE_MXFP4 = 39,
HTP_TYPE_COUNT
};
@ -87,6 +88,8 @@ static inline size_t htp_t_block_size(uint32_t t) {
return QK4_0;
case HTP_TYPE_Q8_0:
return QK8_0;
case HTP_TYPE_IQ4_NL:
return QK4_NL;
case HTP_TYPE_MXFP4:
return QK_MXFP4;
default:
@ -105,6 +108,8 @@ static inline size_t htp_type_nbytes(uint32_t t) {
return sizeof(block_q4_0);
case HTP_TYPE_Q8_0:
return sizeof(block_q8_0);
case HTP_TYPE_IQ4_NL:
return sizeof(block_iq4_nl);
case HTP_TYPE_MXFP4:
return sizeof(block_mxfp4);
default:

View File

@ -7,7 +7,7 @@
#include "remote.idl"
interface htp_iface : remote_handle64 {
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx);
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx);
AEEResult stop();
AEEResult enable_etm();
AEEResult disable_etm();

View File

@ -9,6 +9,9 @@
#include "hex-utils.h"
#include "hvx-types.h"
#define hvx_vmem(A) *((HVX_Vector *)(A))
#define hvx_vmemu(A) *((HVX_UVector *)(A))
static inline void hvx_vec_store_u(void * restrict dst, uint32_t n, HVX_Vector v) {
// Rotate as needed.
v = Q6_V_vlalign_VVR(v, v, (size_t) dst);
@ -112,11 +115,15 @@ static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) {
return Q6_Q_and_QQ(p_exp, p_frac);
}
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
const HVX_Vector zero = Q6_V_vsplat_R(0);
static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1) {
const HVX_Vector zero = Q6_V_vzero();
HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero);
HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero);
HVX_Vector v = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0)));
return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0));
}
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
HVX_Vector v = Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1));
#if __HVX_ARCH__ < 79
// replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0)
@ -128,6 +135,30 @@ static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
return v;
}
#if __HVX_ARCH__ >= 79
static inline HVX_VectorPair hvx_vec_f16_to_f32_shuff(HVX_Vector v) {
const HVX_Vector one = hvx_vec_splat_f16(1.0);
HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(v, one);
return Q6_W_vcombine_VV(Q6_V_hi_W(p), Q6_V_lo_W(p));
}
static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
const HVX_Vector one = hvx_vec_splat_f16(1.0);
HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(v), one);
return Q6_W_vcombine_VV(Q6_V_hi_W(p), Q6_V_lo_W(p));
}
#else
static inline HVX_VectorPair hvx_vec_f16_to_f32_shuff(HVX_Vector v) {
const HVX_Vector one = hvx_vec_splat_f16(1.0);
HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(v, one);
return Q6_W_vcombine_VV(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)));
}
static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
const HVX_Vector one = hvx_vec_splat_f16(1.0);
HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(v), one);
return Q6_W_vcombine_VV(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)));
}
#endif
/* Q6_Vsf_equals_Vw is only available on v73+.*/
#if __HVX_ARCH__ < 73
static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in)

View File

@ -25,6 +25,10 @@
#include "htp-ops.h"
#include "worker-pool.h"
#ifdef HTP_HAS_HMX
#include "hmx-ops.h"
#endif // HTP_HAS_HMX
AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
struct htp_context * ctx;
int err = 0;
@ -163,6 +167,9 @@ static int vtcm_acquire(struct htp_context * ctx) {
}
ctx->vtcm_inuse = true;
return 0;
}
@ -246,7 +253,7 @@ static void vtcm_free(struct htp_context * ctx) {
static void htp_packet_callback(dspqueue_t queue, int error, void * context);
static void htp_error_callback(dspqueue_t queue, int error, void * context);
AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx) {
AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx) {
struct htp_context * ctx = (struct htp_context *) handle;
if (!ctx) {
@ -280,6 +287,21 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
return AEE_ENOMEMORY;
}
#ifdef HTP_HAS_HMX
if (use_hmx) {
ctx->vtcm_scratch_size = ctx->vtcm_size;
ctx->hmx_enabled = 1;
FARF(HIGH, "HMX enabled: vtcm-scratch %zu", ctx->vtcm_scratch_size);
} else {
// HMX disabled: skip HMX initialisation so the
// dispatch loop falls through to the HVX compute paths.
ctx->hmx_enabled = 0;
ctx->vtcm_scratch_size = ctx->vtcm_size;
FARF(HIGH, "HMX disabled (use_hmx=0): vtcm-scratch %zu", ctx->vtcm_scratch_size);
}
#endif
qurt_sysenv_max_hthreads_t hw_threads;
qurt_sysenv_get_max_hw_threads(&hw_threads);
uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF;
@ -340,6 +362,12 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
for (int i = 0; i < ctx->n_threads; i++) {
dma_queue_delete(ctx->dma[i]);
}
#ifdef HTP_HAS_HMX
if (ctx->hmx_enabled) {
ctx->hmx_enabled = 0;
}
#endif
vtcm_free(ctx);
@ -375,8 +403,9 @@ static int send_htp_rsp(struct htp_context * c,
struct dspqueue_buffer * bufs,
size_t n_bufs,
struct profile_data * prof) {
// Prep response struct
// Prep response struct (zero-init to clear cmp/unused union)
struct htp_general_rsp rsp;
memset(&rsp, 0, sizeof(rsp));
rsp.op = op;
rsp.status = status;
rsp.prof_usecs = prof->usecs;
@ -1037,6 +1066,210 @@ static void proc_flash_attn_ext_req(struct htp_context * ctx,
send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof);
}
#ifdef HTP_HAS_HMX
// ---------------------------------------------------------------------------
// HMX operation wrappers — self-contained, bypass htp_ops_context / htp_spad.
// VTCM, DMA and thread dispatch are managed inside the HMX kernels.
// ---------------------------------------------------------------------------
static void proc_hmx_matmul_req(struct htp_context * ctx,
struct htp_general_req * req,
struct dspqueue_buffer * bufs,
size_t n_bufs) {
// HMX weight tile requires N to be 32-aligned.
if (req->src0.ne[1] % 32 != 0) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
const bool is_batched = (req->src0.ne[2] * req->src0.ne[3] > 1 ||
req->src1.ne[2] * req->src1.ne[3] > 1);
// Quantised HMX kernels only handle flat 2D matmul (host already rejects
// batched quantised, but guard here too). F16 batched matmul is handled
// by the dedicated wrapper in hmx-matmul-ops.c.
if (is_batched &&
req->src0.type != HTP_TYPE_F16) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
// HMX assumes contiguous row-major layout. Fall back for permuted
// tensors where strides are non-monotonic (e.g. transposed KV cache).
if (req->src0.nb[0] > req->src0.nb[1] ||
req->src1.nb[0] > req->src1.nb[1]) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
// M alignment: when M > 32 but not 32-aligned, we split into
// HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows).
// When M <= 32 and not 32-aligned, fall back entirely to HVX.
const int m_total = (int) req->src1.ne[1];
const int m_tail = m_total % 32;
const int m_hmx = m_total - m_tail;
if (m_hmx == 0) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
// HMX only supports F16, Q4_0, Q8_0, IQ4_NL weights.
// Other types (e.g. MXFP4) fall back to HVX.
{
uint32_t wtype = req->src0.type;
if (wtype != HTP_TYPE_F16 &&
wtype != HTP_TYPE_Q4_0 &&
wtype != HTP_TYPE_Q8_0 &&
wtype != HTP_TYPE_IQ4_NL) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
// Quantised HMX path requires K aligned to 256 (x4x2 super-block).
// F16 HMX path requires K aligned to 32 (tile width).
if (wtype != HTP_TYPE_F16 && req->src0.ne[0] % 256 != 0) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
if (wtype == HTP_TYPE_F16 && req->src0.ne[0] % 32 != 0) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
}
(void) n_bufs;
struct dspqueue_buffer rsp_bufs[1];
rsp_bufs[0].fd = bufs[2].fd;
rsp_bufs[0].ptr = bufs[2].ptr;
rsp_bufs[0].size = bufs[2].size;
rsp_bufs[0].offset = bufs[2].offset;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);
// src0 = weights, src1 = activation, dst = output
void * wgt = (void *) bufs[0].ptr;
float * act = (float *) bufs[1].ptr;
float * dst = (float *) bufs[2].ptr;
int k = (int) req->src0.ne[0]; // inner dimension
int n = (int) req->src0.ne[1]; // weight columns
struct profile_data prof;
profile_start(&prof);
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
// --- Phase 1: HMX on the first m_hmx (32-aligned) rows ---
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
int ret = -1;
const int ne02 = (int) req->src0.ne[2];
const int ne03 = (int) req->src0.ne[3];
const int ne12 = (int) req->src1.ne[2];
const int ne13 = (int) req->src1.ne[3];
// Row strides in elements. For compact tensors these equal k; for
// permuted attention views they can be larger, so pass the real stride.
const int act_stride = (int)(req->src1.nb[1] / sizeof(float));
const int weight_stride = (int)(req->src0.nb[1] / sizeof(__fp16));
switch (req->src0.type) {
case HTP_TYPE_F16:
if (is_batched) {
hmx_matmul_w16a32_batched_params_t batch_params = {
.dst = dst,
.activation = act,
.permuted_weight = (const __fp16 *) wgt,
.m = m_hmx,
.k = k,
.n = n,
.act_stride = act_stride,
.weight_stride = weight_stride,
.dst_stride = (int)(req->dst.nb[1] / sizeof(float)),
.ne02 = ne02,
.ne03 = ne03,
.ne12 = ne12,
.ne13 = ne13,
.src0_nb2 = req->src0.nb[2],
.src0_nb3 = req->src0.nb[3],
.src1_nb2 = req->src1.nb[2],
.src1_nb3 = req->src1.nb[3],
.dst_nb2 = req->dst.nb[2],
.dst_nb3 = req->dst.nb[3],
};
ret = hmx_mat_mul_permuted_w16a32_batched(ctx, &batch_params);
} else {
ret = hmx_mat_mul_permuted_w16a32(ctx, dst, act,
(const __fp16 *) wgt,
m_hmx, k, n,
act_stride,
weight_stride);
}
break;
default:
ret = hmx_mat_mul_permuted_qk_0_d16a32(ctx, dst, act,
(const uint8_t *) wgt,
m_hmx, k, n, (int) req->src0.type);
break;
}
if (ret == 0) {
rsp_status = HTP_STATUS_OK;
} else {
FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret);
vtcm_release(ctx);
req->flags &= ~HTP_OPFLAGS_SKIP_QUANTIZE;
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
vtcm_release(ctx);
}
// --- Phase 2: HVX on the remaining m_tail rows ---
if (m_tail > 0 && rsp_status == HTP_STATUS_OK) {
struct htp_ops_context octx = { 0 };
octx.ctx = ctx;
octx.src0 = req->src0; // weights: unchanged
octx.src1 = req->src1;
octx.src1.ne[1] = m_tail; // only tail rows
octx.dst = req->dst;
octx.dst.ne[1] = m_tail; // only tail rows
// Always re-quantize tail src1: HMX Phase 1 overwrites VTCM,
// so any previously cached quantized data (SKIP_QUANTIZE pipeline)
// is invalid.
octx.flags = req->flags & ~HTP_OPFLAGS_SKIP_QUANTIZE;
octx.op = req->op;
octx.n_threads = ctx->n_threads;
// Offset activation and dst pointers past the HMX-processed rows.
// Use nb[1] (row stride in bytes) to compute the byte offset.
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.src1.data = (uint32_t)((uint8_t *) bufs[1].ptr + (size_t) m_hmx * req->src1.nb[1]);
octx.dst.data = (uint32_t)((uint8_t *) bufs[2].ptr + (size_t) m_hmx * req->dst.nb[1]);
FARF(HIGH, "proc_hmx_matmul: HVX tail m_tail=%d act=%p dst=%p",
m_tail, (void *)(uintptr_t) octx.src1.data, (void *)(uintptr_t) octx.dst.data);
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
uint32_t hvx_ret = op_matmul(&octx);
vtcm_release(ctx);
if (hvx_ret != HTP_STATUS_OK) {
FARF(ERROR, "HVX tail matmul failed (ret=%u)", hvx_ret);
rsp_status = HTP_STATUS_INTERNAL_ERR;
}
} else {
rsp_status = HTP_STATUS_INTERNAL_ERR;
}
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
#endif // HTP_HAS_HMX
static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
struct htp_context * ctx = (struct htp_context *) context;
@ -1089,7 +1322,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
FARF(ERROR, "Bad matmul-req buffer list");
continue;
}
proc_matmul_req(ctx, &req, bufs, n_bufs);
#ifdef HTP_HAS_HMX
if (ctx->hmx_enabled) {
proc_hmx_matmul_req(ctx, &req, bufs, n_bufs);
} else
#endif
{
proc_matmul_req(ctx, &req, bufs, n_bufs);
}
break;
case HTP_OP_MUL_MAT_ID:

View File

@ -53,9 +53,6 @@ endif()
message(STATUS "HIP and hipBLAS found")
# Workaround old compilers
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} --gpu-max-threads-per-block=1024")
file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
@ -132,6 +129,11 @@ endif()
if (CXX_IS_HIPCC)
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
if (WIN32 AND CMAKE_BUILD_TYPE STREQUAL "Debug")
# CMake on Windows doesn't support the HIP language yet.
# Therefore we workaround debug build's failure on HIP backend this way.
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES COMPILE_FLAGS "-O2 -g")
endif()
target_link_libraries(ggml-hip PRIVATE hip::device)
else()
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE HIP)

View File

@ -4604,12 +4604,42 @@ static void ggml_vk_load_shaders(vk_device& device) {
{"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"},
{"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"},
};
const bool use_subgroup_reduce = device->subgroup_arithmetic;
for (uint32_t si = 0; si < 3; si++) {
const uint32_t S_V = gdn_sizes[si];
GGML_ASSERT(is_pow2(S_V));
uint32_t lanes_per_column;
if (S_V >= 128u && device->subgroup_clustered) {
lanes_per_column = 8u;
} else {
// Use largest power-of-two that divides both S_V and subgroup_size so that
// (1) S_V % lanes_per_column == 0 and (2) S_V % (subgroup_size / lanes_per_column) == 0.
// This means we don't need extra bounds checking logic in the shader.
lanes_per_column = std::min(S_V, device->subgroup_size);
}
const bool need_clustered_shader = lanes_per_column != 1 && (lanes_per_column < device->subgroup_size);
size_t gdn_len;
const void * gdn_data;
if (use_subgroup_reduce && need_clustered_shader) {
gdn_len = gated_delta_net_f32_len;
gdn_data = (const void *)gated_delta_net_f32_data;
} else if (use_subgroup_reduce) {
gdn_len = gated_delta_net_f32_nocluster_len;
gdn_data = (const void *)gated_delta_net_f32_nocluster_data;
} else {
gdn_len = gated_delta_net_f32_shmem_len;
gdn_data = (const void *)gated_delta_net_f32_shmem_data;
}
const uint32_t cols_per_wg = device->subgroup_size / lanes_per_column;
const std::array<uint32_t, 3> wg_denoms = {1u, 1u, cols_per_wg};
for (uint32_t kda = 0; kda < 2; kda++) {
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data,
"main", 7, sizeof(vk_op_gated_delta_net_push_constants),
{1, 1, 1}, {gdn_sizes[si], kda}, 1);
gdn_names[si][kda], gdn_len, gdn_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_reduce, device->subgroup_size);
}
}
}
@ -10438,7 +10468,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
pc, { H, n_seqs, 1u });
pc, { H, n_seqs, S_v });
}
static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {

View File

@ -1,11 +1,25 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_basic : enable
#if USE_SUBGROUP_CLUSTERED
#extension GL_KHR_shader_subgroup_clustered : enable
#endif
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif
// Caller guarantees valid spec constants: S_V % COLS_PER_WG == 0 and S_V % LANES_PER_COLUMN == 0,
// so no bounds checking is needed.
layout(constant_id = 0) const uint S_V = 128;
layout(constant_id = 1) const uint KDA = 0;
layout(constant_id = 2) const uint SUBGROUP_SIZE = 32;
layout(constant_id = 3) const uint LANES_PER_COLUMN = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
const uint COLS_PER_WG = SUBGROUP_SIZE / LANES_PER_COLUMN;
const uint ROWS_PER_LANE = S_V / LANES_PER_COLUMN;
layout(local_size_x_id = 2, local_size_y = 1, local_size_z = 1) in;
layout(push_constant) uniform Parameters {
uint H;
@ -27,14 +41,61 @@ layout(binding = 4) readonly buffer BetaBuf { FLOAT_TYPE data_beta[]; };
layout(binding = 5) readonly buffer StateBuf { FLOAT_TYPE data_state[]; };
layout(binding = 6) buffer DstBuf { FLOAT_TYPE data_dst[]; };
shared FLOAT_TYPE s_k[S_V];
shared FLOAT_TYPE s_q[S_V];
shared FLOAT_TYPE s_g[S_V]; // KDA only: cached exp(g[i])
#if !USE_SUBGROUP_ADD && !USE_SUBGROUP_CLUSTERED
shared FLOAT_TYPE temp[SUBGROUP_SIZE];
// This does a reduction across groups of LANES_PER_COLUMN
FLOAT_TYPE reduce_add_shmem(FLOAT_TYPE partial) {
const uint lane = gl_SubgroupInvocationID;
temp[lane] = partial;
barrier();
[[unroll]] for (uint s = LANES_PER_COLUMN / 2u; s > 0; s >>= 1u) {
FLOAT_TYPE other = temp[lane ^ s];
barrier();
temp[lane] += other;
barrier();
}
const FLOAT_TYPE result = temp[lane];
barrier();
return result;
}
#endif
// clusterSize for subgroupClusteredAdd must be a compile-time constant; branch on spec constant
FLOAT_TYPE reduce_partial(FLOAT_TYPE partial) {
switch (LANES_PER_COLUMN) {
case 1u:
return partial;
#if USE_SUBGROUP_CLUSTERED
// Workaround for GLSL requiring a literal constant for the cluster size.
// The branches should all fold away.
case 2u:
return subgroupClusteredAdd(partial, 2u);
case 4u:
return subgroupClusteredAdd(partial, 4u);
case 8u:
return subgroupClusteredAdd(partial, 8u);
case 16u:
return subgroupClusteredAdd(partial, 16u);
case 32u:
return subgroupClusteredAdd(partial, 32u);
case 64u:
return subgroupClusteredAdd(partial, 64u);
#endif
default:
#if USE_SUBGROUP_ADD
return subgroupAdd(partial);
#else
return reduce_add_shmem(partial);
#endif
}
}
void main() {
const uint head_id = gl_WorkGroupID.x;
const uint seq_id = gl_WorkGroupID.y;
const uint col = gl_LocalInvocationID.x;
const uint seq_id = gl_WorkGroupID.y;
const uint lane = gl_SubgroupInvocationID % LANES_PER_COLUMN;
const uint col = gl_WorkGroupID.z * COLS_PER_WG + (gl_SubgroupInvocationID / LANES_PER_COLUMN);
const uint iq1 = head_id % neq1;
const uint iq3 = seq_id / rq3;
@ -42,9 +103,9 @@ void main() {
const uint state_size = S_V * S_V;
const uint state_base = (seq_id * H + head_id) * state_size;
FLOAT_TYPE state[S_V];
[[unroll]] for (uint i = 0; i < S_V; i++) {
state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]);
FLOAT_TYPE s_shard[ROWS_PER_LANE];
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]);
}
uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
@ -53,76 +114,56 @@ void main() {
const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;
const uint k_off = q_off;
const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;
s_q[col] = FLOAT_TYPE(data_q[q_off + col]);
s_k[col] = FLOAT_TYPE(data_k[k_off + col]);
const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1;
if (KDA != 0) {
const uint g_base = gb_off * S_V;
s_g[col] = exp(FLOAT_TYPE(data_g[g_base + col]));
}
barrier();
const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]);
FLOAT_TYPE k_reg[ROWS_PER_LANE];
FLOAT_TYPE q_reg[ROWS_PER_LANE];
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
const uint i = r * LANES_PER_COLUMN + lane;
k_reg[r] = FLOAT_TYPE(data_k[k_off + i]);
q_reg[r] = FLOAT_TYPE(data_q[q_off + i]);
}
FLOAT_TYPE g_exp[ROWS_PER_LANE];
if (KDA == 0) {
const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off]));
FLOAT_TYPE kv_col = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
kv_col += dot(
vec4(state[i], state[i+1], state[i+2], state[i+3]),
vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3])
);
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
g_exp[r] = g_val;
}
FLOAT_TYPE delta_col = (v_val - g_val * kv_col) * beta_val;
FLOAT_TYPE attn_col = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
sv = g_val * sv + kv * delta_col;
state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
}
data_dst[attn_off + col] = attn_col * scale;
} else {
FLOAT_TYPE kv_col = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
kv_col += dot(gv * sv, kv);
const uint g_base = gb_off * S_V;
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
const uint i = r * LANES_PER_COLUMN + lane;
g_exp[r] = exp(FLOAT_TYPE(data_g[g_base + i]));
}
}
FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
FLOAT_TYPE attn_col = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
sv = gv * sv + kv * delta_col;
state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
FLOAT_TYPE kv_shard = 0.0;
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
kv_shard += g_exp[r] * s_shard[r] * k_reg[r];
}
FLOAT_TYPE kv_col = reduce_partial(kv_shard);
attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
}
FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
FLOAT_TYPE attn_partial = 0.0;
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
s_shard[r] = g_exp[r] * s_shard[r] + k_reg[r] * delta_col;
attn_partial += s_shard[r] * q_reg[r];
}
FLOAT_TYPE attn_col = reduce_partial(attn_partial);
if (lane == 0) {
data_dst[attn_off + col] = attn_col * scale;
}
attn_off += S_V * H;
barrier();
}
[[unroll]] for (uint i = 0; i < S_V; i++) {
data_dst[s_off + state_base + col * S_V + i] = state[i];
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r];
}
}

View File

@ -444,19 +444,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
const uint ib = idx / 128; // 2 values per idx
const uint ib32 = (idx % 128) / 16; // 0..7
const uint iq = 16 * ib32 + 2 * (idx % 8);
const uint ib = idx / 64; // 4 values per idx
const uint ib32 = (idx % 64) / 8; // 0..7
const uint iq = 4 * ib32 + (idx % 4);
const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;
const uint qshift = (idx & 8) >> 1;
u8vec2 qs = unpack8((uint(data_a_packed16[ib].qs[iq/2]) >> qshift) & 0x0F0F).xy;
const uint qshift = idx & 4;
u8vec4 qs = unpack8((uint(data_a_packed32[ib].qs[iq]) >> qshift) & 0x0F0F0F0F);
const float d = float(data_a[ib].d);
const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
const vec4 v = d * float(int(sl | (sh << 4)) - 32) * vec4(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y], kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]);
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
#elif defined(DATA_A_IQ4_NL)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;

View File

@ -554,7 +554,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
std::string load_vec_quant = "2";
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
load_vec_quant = "8";
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4"))
load_vec_quant = "4";
if (tname == "bf16") {
@ -987,7 +987,9 @@ void process_shaders() {
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}}));
string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "1"}}));
string_to_spv("gated_delta_net_f32_nocluster", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "0"}}));
string_to_spv("gated_delta_net_f32_shmem", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "0"}, {"USE_SUBGROUP_CLUSTERED", "0"}}));
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

View File

@ -95,6 +95,11 @@ struct ggml_webgpu_generic_shader_decisions {
uint32_t wg_size = 0;
};
struct ggml_webgpu_ssm_conv_shader_decisions {
uint32_t block_size;
uint32_t tokens_per_wg;
};
/** Argsort **/
struct ggml_webgpu_argsort_shader_lib_context {
@ -131,6 +136,26 @@ struct ggml_webgpu_set_rows_shader_decisions {
uint32_t wg_size;
};
/** Set **/
struct ggml_webgpu_set_pipeline_key {
ggml_type type;
bool inplace;
bool operator==(const ggml_webgpu_set_pipeline_key & other) const {
return type == other.type && inplace == other.inplace;
}
};
struct ggml_webgpu_set_pipeline_key_hash {
size_t operator()(const ggml_webgpu_set_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.inplace);
return seed;
}
};
/** Get Rows **/
struct ggml_webgpu_get_rows_pipeline_key {
@ -151,6 +176,26 @@ struct ggml_webgpu_get_rows_pipeline_key_hash {
}
};
/** Row Norm **/
struct ggml_webgpu_row_norm_pipeline_key {
ggml_op op;
bool inplace;
bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const {
return op == other.op && inplace == other.inplace;
}
};
struct ggml_webgpu_row_norm_pipeline_key_hash {
size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.op);
ggml_webgpu_hash_combine(seed, key.inplace);
return seed;
}
};
/** Pad **/
struct ggml_webgpu_pad_pipeline_key {
bool circular;
@ -166,6 +211,67 @@ struct ggml_webgpu_pad_pipeline_key_hash {
}
};
/** Solve Tri **/
struct ggml_webgpu_solve_tri_pipeline_key {
int type;
int n;
int k;
bool operator==(const ggml_webgpu_solve_tri_pipeline_key & other) const {
return type == other.type && n == other.n && k == other.k;
}
};
struct ggml_webgpu_solve_tri_pipeline_key_hash {
size_t operator()(const ggml_webgpu_solve_tri_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.n);
ggml_webgpu_hash_combine(seed, key.k);
return seed;
}
};
/** SSM Conv **/
struct ggml_webgpu_ssm_conv_pipeline_key {
int type;
int vectorized;
bool operator==(const ggml_webgpu_ssm_conv_pipeline_key & other) const {
return type == other.type && vectorized == other.vectorized;
}
};
/** Gated Delta Net **/
struct ggml_webgpu_gated_delta_net_pipeline_key {
int type;
int s_v;
int kda;
bool operator==(const ggml_webgpu_gated_delta_net_pipeline_key & other) const {
return type == other.type && s_v == other.s_v && kda == other.kda;
}
};
struct ggml_webgpu_gated_delta_net_pipeline_key_hash {
size_t operator()(const ggml_webgpu_gated_delta_net_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.s_v);
ggml_webgpu_hash_combine(seed, key.kda);
return seed;
}
};
struct ggml_webgpu_ssm_conv_pipeline_key_hash {
size_t operator()(const ggml_webgpu_ssm_conv_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.vectorized);
return seed;
}
};
/** Scale **/
struct ggml_webgpu_scale_pipeline_key {
@ -244,13 +350,15 @@ struct ggml_webgpu_binary_pipeline_key_hash {
/** Unary **/
struct ggml_webgpu_unary_pipeline_key {
int type;
int op;
bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
bool inplace;
int type;
int op;
bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
bool inplace;
ggml_tri_type ttype; // only used for GGML_OP_TRI
bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace &&
ttype == other.ttype;
}
};
@ -261,6 +369,7 @@ struct ggml_webgpu_unary_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.op);
ggml_webgpu_hash_combine(seed, key.is_unary);
ggml_webgpu_hash_combine(seed, key.inplace);
ggml_webgpu_hash_combine(seed, key.ttype);
return seed;
}
};
@ -435,20 +544,30 @@ class ggml_webgpu_shader_lib {
std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order
std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
std::unordered_map<ggml_webgpu_row_norm_pipeline_key, webgpu_pipeline, ggml_webgpu_row_norm_pipeline_key_hash>
row_norm_pipelines; // op/inplace
std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
get_rows_pipelines; // src_type, vectorized
std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
unary_pipelines; // type/op/inplace
std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
scale_pipelines; // inplace
std::unordered_map<ggml_webgpu_solve_tri_pipeline_key, webgpu_pipeline, ggml_webgpu_solve_tri_pipeline_key_hash>
solve_tri_pipelines; // type
std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
ssm_conv_pipelines; // type/vectorized
std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key,
webgpu_pipeline,
ggml_webgpu_gated_delta_net_pipeline_key_hash>
gated_delta_net_pipelines; // type/S_v/kda
std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
pad_pipelines; // circular/non-circular
pad_pipelines; // circular/non-circular
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
binary_pipelines; // type/op/inplace/overlap
binary_pipelines; // type/op/inplace/overlap
std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
concat_pipelines; // type
concat_pipelines; // type
std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
repeat_pipelines; // type
repeat_pipelines; // type
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
flash_attn_pipelines;
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
@ -462,6 +581,7 @@ class ggml_webgpu_shader_lib {
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
set_rows_pipelines;
std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines;
public:
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
@ -479,6 +599,45 @@ class ggml_webgpu_shader_lib {
return sum_rows_pipelines[1];
}
webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_row_norm_pipeline_key key = {
.op = context.dst->op,
.inplace = context.inplace,
};
auto it = row_norm_pipelines.find(key);
if (it != row_norm_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant;
switch (key.op) {
case GGML_OP_RMS_NORM:
defines.push_back("RMS_NORM");
variant = "rms_norm";
break;
case GGML_OP_L2_NORM:
defines.push_back("L2_NORM");
variant = "l2_norm";
break;
default:
GGML_ABORT("Unsupported op for row_norm shader");
}
if (key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
}
const uint32_t row_norm_wg_size = 128u;
uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size);
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
return row_norm_pipelines[key];
}
webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) {
bool vec4 = context.src0->ne[0] % 4 == 0;
@ -546,6 +705,46 @@ class ggml_webgpu_shader_lib {
return set_rows_pipelines[key];
}
webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_set_pipeline_key key = { .type = context.dst->type, .inplace = context.inplace };
auto it = set_pipelines.find(key);
if (it != set_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "set";
switch (key.type) {
case GGML_TYPE_F32:
defines.push_back("TYPE_F32");
variant += "_f32";
break;
case GGML_TYPE_I32:
defines.push_back("TYPE_I32");
variant += "_i32";
break;
default:
GGML_ABORT("Unsupported type for set shader");
}
if (key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_set, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
set_pipelines[key] = pipeline;
return set_pipelines[key];
}
webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) {
auto it = cumsum_pipelines.find(1);
if (it != cumsum_pipelines.end()) {
@ -632,6 +831,7 @@ class ggml_webgpu_shader_lib {
switch (key.src_type) {
case GGML_TYPE_F32:
defines.push_back("FLOAT_PARALLEL");
if (key.vectorized) {
defines.push_back("F32_VEC");
defines.push_back("SRC_TYPE=vec4<f32>");
@ -646,6 +846,7 @@ class ggml_webgpu_shader_lib {
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("FLOAT_PARALLEL");
defines.push_back("F16");
defines.push_back("SRC_TYPE=f16");
defines.push_back("DST_TYPE=f32");
@ -653,6 +854,7 @@ class ggml_webgpu_shader_lib {
variant += "_f16";
break;
case GGML_TYPE_I32:
defines.push_back("FLOAT_PARALLEL");
defines.push_back("I32");
defines.push_back("SRC_TYPE=i32");
defines.push_back("DST_TYPE=i32");
@ -731,6 +933,128 @@ class ggml_webgpu_shader_lib {
return scale_pipelines[key];
}
webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_solve_tri_pipeline_key key = {
.type = context.dst->type,
.n = (int) context.src0->ne[0],
.k = (int) context.src1->ne[0],
};
auto it = solve_tri_pipelines.find(key);
if (it != solve_tri_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "solve_tri";
switch (key.type) {
case GGML_TYPE_F32:
variant += "_f32";
break;
default:
GGML_ABORT("Unsupported type for solve_tri shader");
}
const uint32_t wg_size = std::min((uint32_t) key.n, context.max_wg_size);
const uint32_t k_tile = wg_size;
const uint32_t bytes_per_row = ((uint32_t) key.n + wg_size) * GGML_WEBGPU_F32_SIZE_BYTES;
const uint32_t batch_n = (uint32_t) (context.wg_mem_limit_bytes / bytes_per_row);
defines.push_back(std::string("N=") + std::to_string(key.n));
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
defines.push_back(std::string("K_TILE=") + std::to_string(k_tile));
defines.push_back(std::string("BATCH_N=") + std::to_string(batch_n));
auto processed = preprocessor.preprocess(wgsl_solve_tri, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
solve_tri_pipelines[key] = pipeline;
return solve_tri_pipelines[key];
}
webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_ssm_conv_pipeline_key key = {
.type = context.dst->type,
.vectorized = context.src1->ne[0] == 4,
};
auto it = ssm_conv_pipelines.find(key);
if (it != ssm_conv_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "ssm_conv";
switch (key.type) {
case GGML_TYPE_F32:
variant += "_f32";
break;
default:
GGML_ABORT("Unsupported type for ssm_conv shader");
}
if (key.vectorized) {
defines.push_back("VECTORIZED");
variant += "_vec4";
}
constexpr uint32_t block_size = 32u;
constexpr uint32_t tokens_per_wg = 8u;
defines.push_back("BLOCK_SIZE=" + std::to_string(block_size) + "u");
defines.push_back("TOKENS_PER_WG=" + std::to_string(tokens_per_wg) + "u");
auto processed = preprocessor.preprocess(wgsl_ssm_conv, defines);
auto decisions = std::make_shared<ggml_webgpu_ssm_conv_shader_decisions>();
decisions->block_size = block_size;
decisions->tokens_per_wg = tokens_per_wg;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
ssm_conv_pipelines[key] = pipeline;
return ssm_conv_pipelines[key];
}
webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_gated_delta_net_pipeline_key key = {
.type = context.dst->type,
.s_v = (int) context.src2->ne[0],
.kda = context.src3->ne[0] == context.src2->ne[0],
};
auto it = gated_delta_net_pipelines.find(key);
if (it != gated_delta_net_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "gated_delta_net";
switch (key.type) {
case GGML_TYPE_F32:
variant += "_f32";
break;
default:
GGML_ABORT("Unsupported type for gated_delta_net shader");
}
if (key.kda) {
defines.push_back("KDA");
variant += "_kda";
}
defines.push_back("S_V=" + std::to_string(key.s_v) + "u");
defines.push_back("WG_SIZE=" + std::to_string(key.s_v) + "u");
auto processed = preprocessor.preprocess(wgsl_gated_delta_net, defines);
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
gated_delta_net_pipelines[key] = pipeline;
return gated_delta_net_pipelines[key];
}
webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 };
@ -1058,6 +1382,7 @@ class ggml_webgpu_shader_lib {
.op = op,
.is_unary = is_unary,
.inplace = context.inplace,
.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0),
};
auto it = unary_pipelines.find(key);
@ -1088,6 +1413,29 @@ class ggml_webgpu_shader_lib {
variant += "_inplace";
}
if (op == GGML_OP_TRI) {
switch (key.ttype) {
case GGML_TRI_TYPE_LOWER:
defines.push_back("TRI_TYPE_LOWER");
variant += "_tri_type_lower";
break;
case GGML_TRI_TYPE_LOWER_DIAG:
defines.push_back("TRI_TYPE_LOWER_DIAG");
variant += "_tri_type_lower_diag";
break;
case GGML_TRI_TYPE_UPPER:
defines.push_back("TRI_TYPE_UPPER");
variant += "_tri_type_upper";
break;
case GGML_TRI_TYPE_UPPER_DIAG:
defines.push_back("TRI_TYPE_UPPER_DIAG");
variant += "_tri_upper_diag";
break;
default:
GGML_ABORT("Unsupported ggml_tri_type for unary shader");
}
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_unary, defines);

View File

@ -366,7 +366,6 @@ struct webgpu_context_struct {
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
@ -881,6 +880,68 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g
params, entries, wg_x);
}
static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
const bool inplace = ggml_webgpu_tensor_equal(src0, dst);
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.inplace = inplace,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
const uint32_t ne = inplace ? (uint32_t) ggml_nelements(src1) : (uint32_t) ggml_nelements(dst);
const uint32_t dst_type_size = (uint32_t) ggml_type_size(dst->type);
std::vector<uint32_t> params = {
ne,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (((const int32_t *) dst->op_params)[3] / dst_type_size),
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
1u,
(uint32_t) (((const int32_t *) dst->op_params)[0] / dst_type_size),
(uint32_t) (((const int32_t *) dst->op_params)[1] / dst_type_size),
(uint32_t) (((const int32_t *) dst->op_params)[2] / dst_type_size),
(uint32_t) src1->ne[0],
(uint32_t) src1->ne[1],
(uint32_t) src1->ne[2],
(uint32_t) src1->ne[3],
};
std::vector<wgpu::BindGroupEntry> entries;
uint32_t binding_index = 0;
if (!inplace) {
entries.push_back({ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) });
binding_index++;
}
entries.push_back({ .binding = binding_index,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) });
entries.push_back({ .binding = binding_index + 1,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
@ -936,6 +997,208 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_solve_tri(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_solve_tri_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
(uint32_t) src1->ne[0],
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size);
const uint32_t wg_y = (uint32_t) (dst->ne[2] * dst->ne[3]);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_command ggml_webgpu_ssm_conv(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_conv_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_ssm_conv_shader_decisions *>(pipeline.context.get());
const uint32_t token_tiles = CEIL_DIV((uint32_t) dst->ne[1], decisions->tokens_per_wg);
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) src1->ne[0],
(uint32_t) src0->ne[1],
(uint32_t) dst->ne[1],
(uint32_t) dst->ne[2],
token_tiles,
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size);
const uint32_t wg_y = token_tiles * (uint32_t) dst->ne[2];
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_command ggml_webgpu_gated_delta_net(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * src2,
ggml_tensor * src3,
ggml_tensor * src4,
ggml_tensor * src5,
ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.src2 = src2,
.src3 = src3,
.src4 = src4,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_gated_delta_net_pipeline(shader_lib_ctx);
const uint32_t s_v = (uint32_t) src2->ne[0];
const uint32_t h = (uint32_t) src2->ne[1];
const uint32_t n_tokens = (uint32_t) src2->ne[2];
const uint32_t n_seqs = (uint32_t) src2->ne[3];
const float scale = 1.0f / sqrtf((float) s_v);
uint32_t scale_u32;
memcpy(&scale_u32, &scale, sizeof(scale_u32));
std::vector<uint32_t> params = {
h,
n_tokens,
n_seqs,
s_v * h * n_tokens * n_seqs,
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (src2->nb[1] / ggml_type_size(src2->type)),
(uint32_t) (src2->nb[2] / ggml_type_size(src2->type)),
(uint32_t) (src2->nb[3] / ggml_type_size(src2->type)),
(uint32_t) (src4->nb[1] / ggml_type_size(src4->type)),
(uint32_t) (src4->nb[2] / ggml_type_size(src4->type)),
(uint32_t) (src4->nb[3] / ggml_type_size(src4->type)),
(uint32_t) src0->ne[1],
(uint32_t) (src2->ne[3] / src0->ne[3]),
scale_u32,
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(src2),
.offset = ggml_webgpu_tensor_align_offset(ctx, src2),
.size = ggml_webgpu_tensor_binding_size(ctx, src2) },
{ .binding = 3,
.buffer = ggml_webgpu_tensor_buf(src3),
.offset = ggml_webgpu_tensor_align_offset(ctx, src3),
.size = ggml_webgpu_tensor_binding_size(ctx, src3) },
{ .binding = 4,
.buffer = ggml_webgpu_tensor_buf(src4),
.offset = ggml_webgpu_tensor_align_offset(ctx, src4),
.size = ggml_webgpu_tensor_binding_size(ctx, src4) },
{ .binding = 5,
.buffer = ggml_webgpu_tensor_buf(src5),
.offset = ggml_webgpu_tensor_align_offset(ctx, src5),
.size = ggml_webgpu_tensor_binding_size(ctx, src5) },
{ .binding = 6,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, h, n_seqs);
}
static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
ggml_tensor * src,
ggml_tensor * idx,
@ -1017,6 +1280,8 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
ggml_tensor * src,
ggml_tensor * idx,
ggml_tensor * dst) {
const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32;
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.src1 = nullptr,
@ -1061,7 +1326,10 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size);
uint32_t blocks_per_row = (uint32_t) (dst->ne[0] / (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0 ? 4 : 1));
uint32_t total_rows = (uint32_t) (dst->ne[1] * dst->ne[2] * dst->ne[3]);
uint32_t total_threads = float_parallel ? blocks_per_row * total_rows : total_rows;
uint32_t wg_x = CEIL_DIV(total_threads, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
@ -1598,8 +1866,8 @@ static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
int inplace = ggml_webgpu_tensor_equal(src, dst);
static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
bool inplace = ggml_webgpu_tensor_equal(src, dst);
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
@ -1630,8 +1898,15 @@ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * s
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params,
entries, ggml_nrows(src));
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.inplace = inplace,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(src));
}
static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
@ -2170,6 +2445,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
case GGML_OP_CPY:
case GGML_OP_CONT:
return ggml_webgpu_cpy(ctx, src0, node);
case GGML_OP_SET:
return ggml_webgpu_set(ctx, src0, src1, node);
case GGML_OP_SET_ROWS:
return ggml_webgpu_set_rows(ctx, src0, src1, node);
case GGML_OP_GET_ROWS:
@ -2192,7 +2469,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
case GGML_OP_REPEAT:
return ggml_webgpu_repeat(ctx, src0, node);
case GGML_OP_RMS_NORM:
return ggml_webgpu_rms_norm(ctx, src0, node);
case GGML_OP_L2_NORM:
return ggml_webgpu_row_norm(ctx, src0, node);
case GGML_OP_ROPE:
return ggml_webgpu_rope(ctx, src0, src1, src2, node);
case GGML_OP_GLU:
@ -2209,7 +2487,15 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_DIAG:
case GGML_OP_TRI:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_SOLVE_TRI:
return ggml_webgpu_solve_tri(ctx, src0, src1, node);
case GGML_OP_SSM_CONV:
return ggml_webgpu_ssm_conv(ctx, src0, src1, node);
case GGML_OP_GATED_DELTA_NET:
return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node);
case GGML_OP_PAD:
return ggml_webgpu_pad(ctx, src0, node);
case GGML_OP_ARGMAX:
@ -2614,15 +2900,6 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
}
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
webgpu_ctx->rms_norm_pipelines[0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants);
webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
}
static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
@ -2907,7 +3184,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
ggml_webgpu_init_rope_pipeline(webgpu_ctx);
ggml_webgpu_init_glu_pipeline(webgpu_ctx);
ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
@ -2958,7 +3234,7 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
/* .is_host = */ NULL, // defaults to false
},
/* .device = */
dev,
dev,
/* .context = */ NULL
};
@ -3041,6 +3317,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) ||
(op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32);
break;
case GGML_OP_SET:
supports_op = src0->type == src1->type && src0->type == op->type &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32);
break;
case GGML_OP_SET_ROWS:
supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 &&
(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
@ -3118,6 +3398,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
break;
}
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;
case GGML_OP_ROPE:
@ -3180,6 +3461,27 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
}
}
break;
case GGML_OP_TRI:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;
case GGML_OP_DIAG:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;
case GGML_OP_SOLVE_TRI:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
break;
case GGML_OP_SSM_CONV:
supports_op = op->type == GGML_TYPE_F32;
break;
case GGML_OP_GATED_DELTA_NET:
{
const uint32_t s_v = (uint32_t) src2->ne[0];
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 &&
src2->type == GGML_TYPE_F32 && op->src[3]->type == GGML_TYPE_F32 &&
op->src[4]->type == GGML_TYPE_F32 && op->src[5]->type == GGML_TYPE_F32 &&
s_v <= ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
}
break;
case GGML_OP_CLAMP:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
break;

View File

@ -0,0 +1,132 @@
@group(0) @binding(0)
var<storage, read_write> src_q: array<f32>;
@group(0) @binding(1)
var<storage, read_write> src_k: array<f32>;
@group(0) @binding(2)
var<storage, read_write> src_v: array<f32>;
@group(0) @binding(3)
var<storage, read_write> src_g: array<f32>;
@group(0) @binding(4)
var<storage, read_write> src_beta: array<f32>;
@group(0) @binding(5)
var<storage, read_write> src_state: array<f32>;
@group(0) @binding(6)
var<storage, read_write> dst: array<f32>;
struct Params {
h: u32,
n_tokens: u32,
n_seqs: u32,
s_off: u32,
sq1: u32,
sq2: u32,
sq3: u32,
sv1: u32,
sv2: u32,
sv3: u32,
sb1: u32,
sb2: u32,
sb3: u32,
neq1: u32,
rq3: u32,
scale: f32,
};
@group(0) @binding(7)
var<uniform> params: Params;
var<workgroup> sh_k: array<f32, S_V>;
var<workgroup> sh_q: array<f32, S_V>;
var<workgroup> sh_g: array<f32, S_V>;
@compute @workgroup_size(WG_SIZE)
fn main(
@builtin(workgroup_id) workgroup_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>
) {
let head_id = workgroup_id.x;
let seq_id = workgroup_id.y;
let col = local_id.x;
let iq1 = head_id % params.neq1;
let iq3 = seq_id / params.rq3;
let state_size = S_V * S_V;
let state_base = (seq_id * params.h + head_id) * state_size;
var state: array<f32, S_V>;
for (var i = 0u; i < S_V; i++) {
state[i] = src_state[state_base + col * S_V + i];
}
var attn_off = (seq_id * params.n_tokens * params.h + head_id) * S_V;
for (var t = 0u; t < params.n_tokens; t++) {
let q_off = iq3 * params.sq3 + t * params.sq2 + iq1 * params.sq1;
let k_off = q_off;
let v_off = seq_id * params.sv3 + t * params.sv2 + head_id * params.sv1;
let gb_off = seq_id * params.sb3 + t * params.sb2 + head_id * params.sb1;
sh_q[col] = src_q[q_off + col];
sh_k[col] = src_k[k_off + col];
#ifdef KDA
let g_base = gb_off * S_V;
sh_g[col] = exp(src_g[g_base + col]);
#endif
workgroupBarrier();
let v_val = src_v[v_off + col];
let beta_val = src_beta[gb_off];
var kv_col = 0.0;
var delta_col = 0.0;
var attn_col = 0.0;
#ifdef KDA
for (var i = 0u; i < S_V; i++) {
kv_col += (sh_g[i] * state[i]) * sh_k[i];
}
delta_col = (v_val - kv_col) * beta_val;
for (var i = 0u; i < S_V; i++) {
state[i] = sh_g[i] * state[i] + sh_k[i] * delta_col;
attn_col += state[i] * sh_q[i];
}
#else
let g_val = exp(src_g[gb_off]);
for (var i = 0u; i < S_V; i++) {
kv_col += state[i] * sh_k[i];
}
delta_col = (v_val - g_val * kv_col) * beta_val;
for (var i = 0u; i < S_V; i++) {
state[i] = g_val * state[i] + sh_k[i] * delta_col;
attn_col += state[i] * sh_q[i];
}
#endif
dst[attn_off + col] = attn_col * params.scale;
attn_off += S_V * params.h;
workgroupBarrier();
}
for (var i = 0u; i < S_V; i++) {
dst[params.s_off + state_base + col * S_V + i] = state[i];
}
}

View File

@ -640,6 +640,35 @@ var<uniform> params: Params;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
#ifdef FLOAT_PARALLEL
let blocks_per_row = params.ne0 / BLOCK_SIZE;
let row_count = params.n_rows * params.ne2 * params.ne3;
if (gid.x >= blocks_per_row * row_count) {
return;
}
let block_idx = gid.x % blocks_per_row;
var row_idx = gid.x / blocks_per_row;
let i_dst3 = row_idx / (params.ne2 * params.n_rows);
row_idx = row_idx % (params.ne2 * params.n_rows);
let i_dst2 = row_idx / params.n_rows;
let i_dst1 = row_idx % params.n_rows;
let i_idx2 = i_dst3 % params.idx2;
let i_idx1 = i_dst2 % params.idx1;
let i_idx0 = i_dst1;
let i_idx = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;
let idx_val = u32(idx[i_idx]);
let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3;
let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3;
copy_elements(i_src_row, i_dst_row, block_idx);
#else
if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
return;
}
@ -664,5 +693,5 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) {
copy_elements(i_src_row, i_dst_row, i);
}
#endif
}

View File

@ -1,21 +1,11 @@
#define(VARIANTS)
[
{
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_SUFFIX": "inplace",
"DECLS": ["INPLACE"]
},
]
#end(VARIANTS)
#define(DECLS)
#decl(NOT_INPLACE)
#ifdef INPLACE
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
src[dst_offset] = scale * src[src_offset];
}
@group(0) @binding(1)
var<uniform> params: Params;
#else
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
dst[dst_offset] = scale * src[src_offset];
}
@ -25,23 +15,7 @@ var<storage, read_write> dst: array<f32>;
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(NOT_INPLACE)
#decl(INPLACE)
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
src[dst_offset] = scale * src[src_offset];
}
@group(0) @binding(1)
var<uniform> params: Params;
#enddecl(INPLACE)
#end(DECLS)
#define(SHADER)
#endif
struct Params {
offset_src: u32, // in elements
@ -68,12 +42,9 @@ struct Params {
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
DECLS
var<workgroup> scratch: array<f32, WG_SIZE>;
override wg_size: u32;
var<workgroup> scratch: array<f32, wg_size>;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
@ -86,7 +57,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
let elems = (params.ne0 + wg_size - 1) / wg_size;
let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
var sum = 0.0f;
var col = lid.x;
@ -95,12 +66,12 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
break;
}
sum += pow(src[i_src_row + col], 2.0);
col += wg_size;
col += WG_SIZE;
}
scratch[lid.x] = sum;
workgroupBarrier();
var offset = wg_size / 2;
var offset: u32 = WG_SIZE / 2;
while (offset > 0) {
if (lid.x < offset) {
scratch[lid.x] += scratch[lid.x + offset];
@ -110,14 +81,18 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
}
sum = scratch[0];
#ifdef RMS_NORM
let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
#elif defined(L2_NORM)
let scale = 1.0/max(sqrt(sum), params.eps);
#endif
col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
if (col >= params.ne0) {
break;
}
update(i_src_row + col, i_dst_row + col, scale);
col += wg_size;
col += WG_SIZE;
}
}
#end(SHADER)

View File

@ -0,0 +1,109 @@
#ifdef TYPE_I32
#define TYPE i32
#else
#define TYPE f32
#endif
#ifndef INPLACE
@group(0) @binding(0)
var<storage, read_write> src0: array<TYPE>;
#define SRC1_BINDING 1
#else
#define SRC1_BINDING 0
#endif
#define DST_BINDING SRC1_BINDING + 1
#define PARAMS_BINDING SRC1_BINDING + 2
@group(0) @binding(SRC1_BINDING)
var<storage, read_write> src1: array<TYPE>;
@group(0) @binding(DST_BINDING)
var<storage, read_write> dst: array<TYPE>;
struct Params {
ne: u32,
offset_src0: u32,
offset_src1: u32,
offset_view: u32,
stride_src10: u32,
stride_src11: u32,
stride_src12: u32,
stride_src13: u32,
stride_dst10: u32,
stride_dst11: u32,
stride_dst12: u32,
stride_dst13: u32,
src1_ne0: u32,
src1_ne1: u32,
src1_ne2: u32,
src1_ne3: u32,
};
@group(0) @binding(PARAMS_BINDING)
var<uniform> params: Params;
fn decode_src1_coords(idx: u32) -> vec4<u32> {
var i = idx;
let plane = params.src1_ne2 * params.src1_ne1 * params.src1_ne0;
let i3 = i / plane;
i = i % plane;
let row = params.src1_ne1 * params.src1_ne0;
let i2 = i / row;
i = i % row;
let i1 = i / params.src1_ne0;
let i0 = i % params.src1_ne0;
return vec4<u32>(i0, i1, i2, i3);
}
fn decode_view_coords(rel: u32) -> vec4<u32> {
let i3 = rel / params.stride_dst13;
let rem3 = rel % params.stride_dst13;
let i2 = rem3 / params.stride_dst12;
let rem2 = rem3 % params.stride_dst12;
let i1 = rem2 / params.stride_dst11;
let i0 = rem2 % params.stride_dst11;
return vec4<u32>(i0, i1, i2, i3);
}
fn view_rel_from_coords(coords: vec4<u32>) -> u32 {
return coords.x * params.stride_dst10 + coords.y * params.stride_dst11 +
coords.z * params.stride_dst12 + coords.w * params.stride_dst13;
}
fn src1_idx_from_coords(coords: vec4<u32>) -> u32 {
return coords.x * params.stride_src10 + coords.y * params.stride_src11 +
coords.z * params.stride_src12 + coords.w * params.stride_src13;
}
fn in_set_view(rel: u32, coords: vec4<u32>) -> bool {
return view_rel_from_coords(coords) == rel;
}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}
#ifdef INPLACE
let coords = decode_src1_coords(gid.x);
let src1_idx = params.offset_src1 + src1_idx_from_coords(coords);
let dst_idx = params.offset_view + view_rel_from_coords(coords);
dst[dst_idx] = src1[src1_idx];
#else
let rel = select(params.ne, gid.x - params.offset_view, gid.x >= params.offset_view);
let coords = decode_view_coords(rel);
if (rel < params.stride_dst13 * params.src1_ne3 && in_set_view(rel, coords)) {
dst[gid.x] = src1[params.offset_src1 + src1_idx_from_coords(coords)];
} else {
dst[gid.x] = src0[params.offset_src0 + gid.x];
}
#endif
}

View File

@ -0,0 +1,121 @@
@group(0) @binding(0)
var<storage, read_write> src0: array<f32>;
@group(0) @binding(1)
var<storage, read_write> src1: array<f32>;
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
struct Params {
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
stride_src00: u32,
stride_src01: u32,
stride_src02: u32,
stride_src03: u32,
stride_src10: u32,
stride_src11: u32,
stride_src12: u32,
stride_src13: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
k: u32,
ne2: u32,
ne3: u32,
};
@group(0) @binding(3)
var<uniform> params: Params;
var<workgroup> shA: array<f32, BATCH_N * N>;
var<workgroup> shB: array<f32, BATCH_N * K_TILE>;
fn src0_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
return params.offset_src0 +
col * params.stride_src00 +
row * params.stride_src01 +
i2 * params.stride_src02 +
i3 * params.stride_src03;
}
fn src1_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
return params.offset_src1 +
col * params.stride_src10 +
row * params.stride_src11 +
i2 * params.stride_src12 +
i3 * params.stride_src13;
}
fn dst_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
return params.offset_dst +
col * params.stride_dst0 +
row * params.stride_dst1 +
i2 * params.stride_dst2 +
i3 * params.stride_dst3;
}
@compute @workgroup_size(WG_SIZE)
fn main(
@builtin(workgroup_id) workgroup_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>
) {
let batch = workgroup_id.y;
let col = workgroup_id.x * WG_SIZE + local_id.x;
let i3 = batch / params.ne2;
let i2 = batch % params.ne2;
let active_lane = local_id.x < K_TILE;
let active_col = active_lane && col < params.k;
var X: array<f32, N>;
for (var row_base = 0u; row_base < N; row_base += BATCH_N) {
let cur_n = min(BATCH_N, N - row_base);
for (var i = local_id.x; i < cur_n * N; i += WG_SIZE) {
let tile_row = i / N;
let tile_col = i % N;
shA[i] = src0[src0_idx(row_base + tile_row, tile_col, i2, i3)];
}
for (var i = local_id.x; i < cur_n * K_TILE; i += WG_SIZE) {
let tile_row = i / K_TILE;
let tile_col = i % K_TILE;
let global_col = workgroup_id.x * WG_SIZE + tile_col;
let sh_idx = tile_row * K_TILE + tile_col;
if (global_col < params.k) {
shB[sh_idx] = src1[src1_idx(row_base + tile_row, global_col, i2, i3)];
} else {
shB[sh_idx] = 0.0;
}
}
workgroupBarrier();
if (active_col) {
for (var row_offset = 0u; row_offset < cur_n; row_offset++) {
let r = row_base + row_offset;
var b = shB[row_offset * K_TILE + local_id.x];
let a_row = row_offset * N;
for (var t = 0u; t < r; t++) {
b -= shA[a_row + t] * X[t];
}
let x = b / shA[a_row + r];
X[r] = x;
dst[dst_idx(r, col, i2, i3)] = x;
}
}
workgroupBarrier();
}
}

View File

@ -0,0 +1,65 @@
@group(0) @binding(0)
var<storage, read_write> src0: array<f32>;
@group(0) @binding(1)
var<storage, read_write> src1: array<f32>;
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
struct Params {
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
stride_src01: u32,
stride_src02: u32,
stride_src11: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
nc: u32,
nr: u32,
n_t: u32,
n_s: u32,
token_tiles: u32,
};
@group(0) @binding(3)
var<uniform> params: Params;
@compute @workgroup_size(BLOCK_SIZE, TOKENS_PER_WG)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i1 = gid.x;
let tile_y = gid.y / TOKENS_PER_WG;
let local_token = gid.y % TOKENS_PER_WG;
let i3 = tile_y / params.token_tiles;
let token_tile = tile_y % params.token_tiles;
let i2 = token_tile * TOKENS_PER_WG + local_token;
if (i1 >= params.nr || i2 >= params.n_t || i3 >= params.n_s) {
return;
}
let src0_base = params.offset_src0 + i3 * params.stride_src02 + i2 + i1 * params.stride_src01;
let src1_base = params.offset_src1 + i1 * params.stride_src11;
var sum = 0.0;
#ifdef VECTORIZED
sum =
src0[src0_base + 0u] * src1[src1_base + 0u] +
src0[src0_base + 1u] * src1[src1_base + 1u] +
src0[src0_base + 2u] * src1[src1_base + 2u] +
src0[src0_base + 3u] * src1[src1_base + 3u];
#else
for (var i0 = 0u; i0 < params.nc; i0++) {
sum += src0[src0_base + i0] * src1[src1_base + i0];
}
#endif
let dst_idx = params.offset_dst + i3 * params.stride_dst2 + i2 * params.stride_dst1 + i1 * params.stride_dst0;
dst[dst_idx] = sum;
}

View File

@ -5,7 +5,6 @@ enable f16;
#define TYPE f32
#endif
@group(0) @binding(0)
var<storage, read_write> src: array<TYPE>;
@ -57,12 +56,20 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
return;
}
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
i = i % (params.ne1 * params.ne0);
let i1 = i / params.ne0;
let i0 = i % params.ne0;
let ne2 = params.ne2;
#ifdef DIAG
let ne1 = params.ne0;
#else
let ne1 = params.ne1;
#endif
let ne0 = params.ne0;
let i3 = i / (ne2 * ne1 * ne0);
i = i % (ne2 * ne1 * ne0);
let i2 = i / (ne1 * ne0);
i = i % (ne1 * ne0);
let i1 = i / ne0;
let i0 = i % ne0;
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
i2 * params.stride_src2 + i3 * params.stride_src3;
@ -184,6 +191,20 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let res_f32 = cos(f32(src[params.offset_src + src_idx]));
let res = TYPE(res_f32);
#endif
#ifdef DIAG
let res = select(0.0, src[params.offset_src + i0 + i2 * params.stride_src2 + i3 * params.stride_src3], i0 == i1);
#endif
#ifdef TRI
#ifdef TRI_TYPE_LOWER
let res = select(0.0, src[params.offset_src + src_idx], i0 < i1);
#elif TRI_TYPE_LOWER_DIAG
let res = select(0.0, src[params.offset_src + src_idx], i0 <= i1);
#elif TRI_TYPE_UPPER
let res = select(0.0, src[params.offset_src + src_idx], i0 > i1);
#elif TRI_TYPE_UPPER_DIAG
let res = select(0.0, src[params.offset_src + src_idx], i0 >= i1);
#endif
#endif
#ifdef INPLACE
src[params.offset_src + src_idx] = res;

View File

@ -425,8 +425,7 @@ class GGUFWriter:
fout = self.fout[file_id]
# pop the first tensor info
# TODO: cleaner way to get the first key
first_tensor_name = [name for name, _ in zip(self.tensors[file_id].keys(), range(1))][0]
first_tensor_name = next(iter(self.tensors[file_id]))
ti = self.tensors[file_id].pop(first_tensor_name)
assert ti.nbytes == tensor.nbytes

View File

@ -7,7 +7,6 @@
{%- set available_tool_string = '' -%}
{%- set add_tool_id = true -%}
{%- set add_thoughts = true -%} {# whether to include <thinking> reasoning blocks #}
{%- set add_generation_prompt = true -%} {# whether to emit reasoning starter before assistant response #}
{# Optional token placeholders (safe defaults) #}
{%- set bos_token = bos_token or '' -%}
{%- set eos_token = eos_token or '' -%}

View File

@ -15,10 +15,10 @@
{%- set ns.is_tool = false -%}
{%- for tool in message['tool_calls']-%}
{%- if not ns.is_first -%}
{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<tool▁call▁end>'}}
{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] | tojson + '\n' + '```' + '<tool▁call▁end>'}}
{%- set ns.is_first = true -%}
{%- else -%}
{{'\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<tool▁call▁end>'}}
{{'\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] | tojson + '\n' + '```' + '<tool▁call▁end>'}}
{%- endif -%}
{%- endfor -%}
{{'<tool▁calls▁end><end▁of▁sentence>'}}

View File

@ -28,25 +28,25 @@
{%- set ns.is_last_user = true -%}{{'<User>' + message['content']}}
{%- endif -%}
{%- if message['role'] == 'assistant' and message['tool_calls'] -%}
{%- if ns.is_last_user -%}{{'<Assistant></think>'}}
{%- if ns.is_last_user -%}{{'<Assistant><think></think>'}}
{%- endif -%}
{%- set ns.is_last_user = false -%}
{%- set ns.is_first = false -%}
{%- set ns.is_tool = false -%}
{%- for tool in message['tool_calls'] -%}
{%- if not ns.is_first -%}
{%- if not message['content'] -%}{{'<tool▁calls▁begin><tool▁call▁begin>'+ tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] + '<tool▁call▁end>'}}
{%- else -%}{{message['content'] + '<tool▁calls▁begin><tool▁call▁begin>' + tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] + '<tool▁call▁end>'}}
{%- if not message['content'] -%}{{'<tool▁calls▁begin><tool▁call▁begin>'+ tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] | tojson + '<tool▁call▁end>'}}
{%- else -%}{{message['content'] + '<tool▁calls▁begin><tool▁call▁begin>' + tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] | tojson + '<tool▁call▁end>'}}
{%- endif -%}
{%- set ns.is_first = true -%}
{%- else -%}{{'<tool▁call▁begin>'+ tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] + '<tool▁call▁end>'}}
{%- else -%}{{'<tool▁call▁begin>'+ tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] | tojson + '<tool▁call▁end>'}}
{%- endif -%}
{%- endfor -%}{{'<tool▁calls▁end><end▁of▁sentence>'}}
{%- endif -%}
{%- if message['role'] == 'assistant' and not message['tool_calls'] -%}
{%- if ns.is_last_user -%}{{'<Assistant>'}}
{%- if message['prefix'] is defined and message['prefix'] and thinking -%}{{'<think>'}}
{%- else -%}{{'</think>'}}
{%- else -%}{{'<think></think>'}}
{%- endif -%}
{%- endif -%}
{%- set ns.is_last_user = false -%}
@ -65,7 +65,7 @@
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool -%}{{'<Assistant>'}}
{%- if not thinking -%}{{'</think>'}}
{%- else -%}{{'<think>'}}
{%- if not thinking -%}{{'<think></think>'}}
{%- else -%}{{'<think>'}}
{%- endif -%}
{%- endif %}

View File

@ -49,7 +49,7 @@ Example function tool call syntax:
{%- endif -%}
{%- set tool_name = tc['function']['name'] -%}
{%- set tool_args = tc['function']['arguments'] -%}
{{- '<tool▁call▁begin>' + tc['type'] + '<tool▁sep>' + tool_name + '\n' + '```json' + '\n' + tool_args + '\n' + '```' + '<tool▁call▁end>' -}}
{{- '<tool▁call▁begin>' + tc['type'] + '<tool▁sep>' + tool_name + '\n' + '```json' + '\n' + tool_args | tojson + '\n' + '```' + '<tool▁call▁end>' -}}
{%- endfor -%}
{{- '<tool▁calls▁end><end▁of▁sentence>' -}}
{%- endif -%}

View File

@ -42,9 +42,9 @@
{%- if 'tool_calls' in message and message['tool_calls'] -%}
{%- for tool_call in message['tool_calls'] -%}
{%- if tool_call["function"]["name"] == "python" -%}
{{ '<|python_tag|>' + tool_call['function']['arguments'] }}
{{ '<|python_tag|>' + tool_call['function']['arguments'] | tojson }}
{%- else -%}
{{ '<function=' + tool_call['function']['name'] + '>' + tool_call['function']['arguments'] + '</function>' }}
{{ '<function=' + tool_call['function']['name'] + '>' + tool_call['function']['arguments'] | tojson + '</function>' }}
{%- endif -%}
{%- endfor -%}
{{ '<|eom_id|>' }}

View File

@ -95,9 +95,9 @@ if __name__ == '__main__':
'-p', 'Hey',
'--no-warmup',
'--log-disable',
'-no-cnv']
'-st']
if m.hf_file != 'tinyllamas/stories260K.gguf' and 'Mistral-Nemo' not in m.hf_repo:
cmd.append('-fa')
cmd += ('-fa', 'on')
try:
subprocess.check_call(cmd)
except subprocess.CalledProcessError:

View File

@ -0,0 +1,157 @@
#!/usr/bin/env python3
import sys
from collections import defaultdict
def parse_log_file(filepath):
"""Parse log file and extract function VGPR usage."""
import re
functions = defaultdict(lambda: {'vgprs': 0, 'spill': 0, 'location': ''})
try:
with open(filepath, 'r') as f:
content = f.read()
# Find all function entries with VGPR usage including location
pattern = r'([^:]+:\d+):.*?Function Name: (\S+).*?VGPRs: (\d+).*?VGPRs Spill: (\d+)'
matches = re.findall(pattern, content, re.DOTALL)
for location, func_name, vgprs, spill in matches:
functions[func_name]['vgprs'] = int(vgprs)
functions[func_name]['spill'] = int(spill)
# Extract just the filename and line number
parts = location.split('/')
if len(parts) > 0:
short_location = parts[-1] # Get last part (filename)
# Check if there's a line number after filename
if ':' in short_location:
functions[func_name]['location'] = short_location
else:
functions[func_name]['location'] = location
else:
functions[func_name]['location'] = location
except FileNotFoundError:
print(f"Error: File {filepath} not found", file=sys.stderr) # noqa: NP100
sys.exit(1)
return functions
def main():
if len(sys.argv) < 2:
print("Usage: ./vgpr_check.py <log_file>", file=sys.stderr) # noqa: NP100
sys.exit(1)
log_file = sys.argv[1]
ignored = {
'_ZL21gated_linear_attn_f32ILi128EEviiiifPKfS1_S1_S1_S1_Pf',
'_ZL18flash_attn_ext_f16ILi64ELi64ELi16ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi80ELi80ELi16ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi96ELi96ELi16ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi64ELi64ELi32ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL13rwkv_wkv7_f32ILi128EEviiiiPKfS1_S1_S1_S1_S1_S1_Pf',
'_ZL18flash_attn_ext_f16ILi80ELi80ELi16ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi16ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi80ELi80ELi32ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi96ELi96ELi16ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi2ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi96ELi96ELi32ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi16ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi32ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi1ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi80ELi80ELi2ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi96ELi96ELi2ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi2ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi2ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi2ELi8ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi16ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi4ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi32ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi96ELi96ELi4ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi4ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi4ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi4ELi4ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi80ELi80ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi96ELi96ELi64ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi64ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi64ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi64ELi1ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi64ELi64ELi8ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi80ELi80ELi8ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi96ELi96ELi8ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi8ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi80ELi80ELi8ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi4ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi96ELi96ELi8ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi8ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi2ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi8ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi8ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL24mul_mat_q_stream_k_fixupIL9ggml_type22ELi8ELb1EEvPKiS2_PfPKfiiimimimi',
'_ZL9mul_mat_qIL9ggml_type3ELi32ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type3ELi48ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type20ELi32ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type17ELi64ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL18flash_attn_ext_f16ILi80ELi80ELi4ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL15flash_attn_tileILi256ELi256ELi32ELi1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL9mul_mat_qIL9ggml_type19ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type17ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type22ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type19ELi128ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type19ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type7ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type3ELi128ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type3ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type7ELi128ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type7ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type11ELi112ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type11ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL24mul_mat_q_stream_k_fixupIL9ggml_type11ELi128ELb0EEvPKiS2_PfPKfiiimimimi',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi32ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL9mul_mat_qIL9ggml_type2ELi112ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi32ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi112ELi112ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi32ELi1ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi32ELi2ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi128ELi128ELi4ELi8ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_f16ILi96ELi96ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
}
functions = parse_log_file(log_file)
found_issues = False
# First print all ignored functions (deduplicated)
printed_ignored = set()
for func_name, data in sorted(functions.items()):
total_vgprs = int(data['vgprs']) + int(data['spill'])
if total_vgprs > 256 and func_name in ignored and func_name not in printed_ignored:
location = data.get('location', log_file)
print(f"{location}: {func_name} - Total VGPRs: {total_vgprs} ({data['vgprs']} + {data['spill']}) [IGNORED]") # noqa: NP100
printed_ignored.add(func_name)
# Then print new functions with issues in red
for func_name, data in sorted(functions.items()):
total_vgprs = int(data['vgprs']) + int(data['spill'])
if total_vgprs > 256 and func_name not in ignored:
status = "[IGNORED]" if func_name in ignored else ""
location = data.get('location', log_file)
# Print in red if not ignored
color_code = "\033[91m" if func_name not in ignored else ""
reset_code = "\033[0m" if func_name not in ignored else ""
print(f"{color_code}{location}: {func_name} - Total VGPRs: {total_vgprs} ({data['vgprs']} + {data['spill']}) {status}{reset_code}") # noqa: NP100
if func_name not in ignored:
found_issues = True
sys.exit(1 if found_issues else 0)
if __name__ == "__main__":
main()

View File

@ -39,6 +39,9 @@ opmask=
nhvx=
[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX"
hmx=
[ "$HMX" != "" ] && hmx="GGML_HEXAGON_USE_HMX=$HMX"
ndev=
[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV"
@ -51,7 +54,7 @@ adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $experimental $sched $opmask $profile $nhvx $ndev $hb \
$verbose $experimental $sched $opmask $profile $nhvx $hmx $ndev $hb \
./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--ctx-size 8192 --ubatch-size 256 -fa on \

View File

@ -39,6 +39,9 @@ opmask=
nhvx=
[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX"
hmx=
[ "$HMX" != "" ] && hmx="GGML_HEXAGON_USE_HMX=$HMX"
ndev=
[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV"
@ -51,7 +54,7 @@ adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $experimental $sched $opmask $profile $nhvx $ndev $hb \
$verbose $experimental $sched $opmask $profile $nhvx $hmx $ndev $hb \
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--ctx-size 8192 --ubatch-size 256 -fa on \

View File

@ -45,6 +45,9 @@ opmask=
nhvx=
[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX"
hmx=
[ "$HMX" != "" ] && hmx="GGML_HEXAGON_USE_HMX=$HMX"
ndev=
[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV"
@ -58,7 +61,7 @@ adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $experimental $sched $opmask $profile $nhvx $ndev $mtmd_backend \
$verbose $experimental $sched $opmask $profile $hmx $nhvx $ndev $mtmd_backend \
./$branch/bin/llama-mtmd-cli --no-mmap -m $basedir/../gguf/$model \
--mmproj $basedir/../gguf/$mmproj \
--image $basedir/../gguf/$image \

View File

@ -36,6 +36,9 @@ opmask=
nhvx=
[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX"
hmx=
[ "$HMX" != "" ] && hmx="GGML_HEXAGON_USE_HMX=$HMX"
ndev=
[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV"
@ -50,5 +53,5 @@ adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $experimental $sched $opmask $profile $nhvx $ndev $hb ./$branch/bin/$tool $@ \
$verbose $experimental $sched $opmask $profile $nhvx $hmx $ndev $hb ./$branch/bin/$tool $@ \
"

View File

@ -1946,6 +1946,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
return 0;
}
ggml_backend_buffer_clear(buf_output.get(), 0);
}
float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());

View File

@ -1673,6 +1673,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// NextN/MTP parameters (GLM-OCR)
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
// TODO: when MTP is implemented, this should probably be updated if needed
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
@ -1706,6 +1707,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// NextN/MTP parameters
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
// TODO: when MTP is implemented, this should probably be updated if needed
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
@ -1752,6 +1754,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// NextN/MTP parameters
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
// TODO: when MTP is implemented, this should probably be updated if needed
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
@ -1926,6 +1929,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false);
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
switch (hparams.n_layer) {
case 32: type = LLM_TYPE_30B_A3B; break;
@ -2054,7 +2058,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
switch (hparams.n_embd) {
case 768: type = LLM_TYPE_350M; break;
case 1536: type = (hparams.n_embd == 2048 ? LLM_TYPE_7B_A1B : LLM_TYPE_1B); break;
case 1536: type = (hparams.n_ff() == 512 ? LLM_TYPE_7B_A1B : LLM_TYPE_1B); break;
case 2048: case 2560: type = LLM_TYPE_3B; break;
case 4096: type = LLM_TYPE_32B; break;
default: type = LLM_TYPE_UNKNOWN;
@ -2108,6 +2112,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func);
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
// TODO: when MTP is implemented, this should probably be updated if needed
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;

View File

@ -2129,19 +2129,28 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
throw std::runtime_error("cannot find tokenizer vocab in model file\n");
}
const uint32_t n_tokens = gguf_get_arr_n(ctx, token_idx);
const float * scores = nullptr;
const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
if (score_idx != -1) {
const uint32_t n_scores = gguf_get_arr_n(ctx, score_idx);
if (n_scores < n_tokens) {
throw std::runtime_error("Index out of array bounds for scores (" + std::to_string(n_scores) + " < " + std::to_string(n_tokens) + ")\n");
}
scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
}
const int * toktypes = nullptr;
const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
if (toktype_idx != -1) {
const uint32_t n_toktypes = gguf_get_arr_n(ctx, toktype_idx);
if (n_toktypes < n_tokens) {
throw std::runtime_error("Index out of array bounds for toktypes (" + std::to_string(n_toktypes) + " < " + std::to_string(n_tokens) + ")\n");
}
toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
}
uint32_t n_tokens = gguf_get_arr_n(ctx, token_idx);
id_to_token.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {

View File

@ -121,6 +121,9 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "l_out", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}

View File

@ -111,8 +111,13 @@ llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_
}
inpL = ggml_add(ctx0, cur, ffn_inp);
cb(inpL, "l_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = build_norm(inpL,

View File

@ -86,6 +86,10 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}

View File

@ -82,6 +82,7 @@ llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_grap
cur = ggml_add(ctx0, cur, ffn_inp);
// input for next layer
inpL = cur;
}
cur = inpL;

View File

@ -66,8 +66,14 @@ llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
}
inpL = ggml_add(ctx0, cur, ffn_inp);
cb(inpL, "l_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = build_norm(inpL,
model.output_norm,

View File

@ -362,6 +362,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;

View File

@ -177,6 +177,9 @@ llm_build_lfm2<iswa>::llm_build_lfm2(const llama_model & model, const llm_graph_
cb(ffn_norm_out, "model.layers.{}.ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_out);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
}
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);

View File

@ -71,6 +71,7 @@ llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_pa
cur = ggml_add(ctx0, cur, residual);
cb(cur, "ffn_residual", il);
// input for next layer
inpL = cur;
}

View File

@ -109,6 +109,8 @@ llm_build_plamo3<iswa>::llm_build_plamo3(const llama_model & model, const llm_gr
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}

View File

@ -64,6 +64,9 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa
cur = ggml_add(ctx0, cur, ffn_residual);
cb(cur, "post_ffn", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// Input for next layer
inpL = cur;
}

View File

@ -64,6 +64,9 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr
cur = ggml_add(ctx0, cur, ffn_residual);
cb(cur, "post_moe", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// Input for next layer
inpL = cur;
}

View File

@ -56,6 +56,9 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
cur = ggml_add(ctx0, cur, ffn_residual);
cb(cur, "post_moe", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// Input for next layer
inpL = cur;
}

View File

@ -101,6 +101,7 @@ llm_build_smallthinker<iswa>::llm_build_smallthinker(const llama_model & model,
cur = ffn_out;
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);

View File

@ -145,9 +145,11 @@ llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const ll
cb(cur, "ffn_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}

View File

@ -22,6 +22,7 @@ static void test_calculate_diff_split_no_common(testing & t);
static void test_calculate_diff_split_single_char(testing & t);
static void test_calculate_diff_split_overlaps(testing & t);
static void test_calculate_diff_split_tag_boundaries(testing & t);
static void test_calculate_diff_split_generation_prompt(testing & t);
static void test_calculate_diff_split(testing & t);
static void test_until_common_prefix_basic(testing & t);
@ -179,6 +180,7 @@ static void test_calculate_diff_split(testing & t) {
t.test("calculate_diff_split single char", test_calculate_diff_split_single_char);
t.test("calculate_diff_split overlaps", test_calculate_diff_split_overlaps);
t.test("calculate_diff_split tag boundaries", test_calculate_diff_split_tag_boundaries);
t.test("calculate_diff_split generation prompt", test_calculate_diff_split_generation_prompt);
}
static void test_calculate_diff_split_basic(testing & t) {
@ -502,6 +504,39 @@ static void test_calculate_diff_split_tag_boundaries(testing & t) {
}
}
static void test_calculate_diff_split_generation_prompt(testing & t) {
// ChatML thinking template: left is a prefix of right, generation_prompt is the appended part.
// The trailing \n in left matches the trailing \n in the generation_prompt, causing
// the suffix matcher to steal it and rotate the diff result.
{
// Simplified reproduction: left ends with \n, right = left + "<|im_start|>assistant\n<think>\n"
std::string left = "<|im_start|>user\nHello<|im_end|>\n";
std::string right = left + "<|im_start|>assistant\n<think>\n";
diff_split result = calculate_diff_split(left, right);
t.assert_equal("chatml prefix", left, result.prefix);
t.assert_equal("chatml left", "", result.left);
t.assert_equal("chatml right should be generation prompt",
"<|im_start|>assistant\n<think>\n", result.right);
t.assert_equal("chatml suffix", "", result.suffix);
}
{
// More realistic: longer conversation ending with tool_response
std::string common =
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\nSearch for files<|im_end|>\n"
"<|im_start|>assistant\n<think>\nLet me search.\n</think>\n\n"
"<tool_call>\n<function=search>\n</function>\n</tool_call><|im_end|>\n"
"<|im_start|>user\n<tool_response>\nNo files found\n</tool_response><|im_end|>\n";
std::string left = common;
std::string right = common + "<|im_start|>assistant\n<think>\n";
diff_split result = calculate_diff_split(left, right);
t.assert_equal("tool_response left", "", result.left);
t.assert_equal("tool_response right should be generation prompt",
"<|im_start|>assistant\n<think>\n", result.right);
}
}
static void test_until_common_prefix(testing & t) {
t.test("until_common_prefix basic", test_until_common_prefix_basic);
}
@ -1292,11 +1327,11 @@ static void test_nemotron_reasoning_detection(testing & t) {
// Check reasoning markers
t.assert_equal("reasoning_start should be '<think>'", "<think>", analysis.reasoning.start);
t.assert_equal("reasoning_end should be '</think>\\n'", "</think>\n", analysis.reasoning.end);
t.assert_equal("reasoning_end should be '</think>'", "</think>", analysis.reasoning.end);
// Check reasoning mode detection
// Nemotron uses forced closed reasoning with add_generation_prompt
t.assert_equal("reasoning should be FORCED_CLOSED", reasoning_mode::FORCED_CLOSED, analysis.reasoning.mode);
// Nemotron uses tag-based reasoning; prefill handles the template's forced markers
t.assert_equal("reasoning should be TAG_BASED", reasoning_mode::TAG_BASED, analysis.reasoning.mode);
// Make sure reasoning markers don't spill over to content markers
t.assert_equal("content start should be empty", "", analysis.content.start);

View File

@ -145,7 +145,7 @@ static void test_example_native(testing & t) {
common_reasoning_format reasoning_format;
json json_schema;
bool parallel_tool_calls;
bool thinking_forced_open;
std::string generation_prompt;
std::string input;
// Expect
@ -157,14 +157,8 @@ static void test_example_native(testing & t) {
auto build_parser = [](const test_case & tc) {
return build_chat_peg_parser([&](common_chat_peg_builder & p) {
auto reasoning_in_content = (tc.reasoning_format == COMMON_REASONING_FORMAT_NONE);
auto reasoning = p.eps();
if (tc.thinking_forced_open) {
// If thinking is forced open, expect a closing tag
reasoning = p.reasoning(p.until("</think>")) + "</think>" + p.space();
} else {
// Otherwise, optionally accept thinking wrapped in tags
reasoning = p.optional("<think>" + p.reasoning(p.until("</think>")) + "</think>" + p.space());
}
// Always use optional TAG_BASED pattern; generation_prompt is prepended to input
auto reasoning = p.optional("<think>" + p.reasoning(p.until("</think>")) + "</think>" + p.space());
// tool calling parser
if (tc.tools.is_array() && !tc.tools.empty()) {
@ -190,78 +184,91 @@ static void test_example_native(testing & t) {
std::vector<test_case> test_cases = std::vector<test_case>{
{
/* .name = */ "content with thinking_forced_open = false",
/* .name = */ "content with reasoning (no generation_prompt)",
/* .tools = */ {},
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
/* .json_schema = */ {},
/* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ false,
/* .generation_prompt = */ "",
/* .input = */ ("<think>The user said hello, I must say hello back</think>\nHello"),
/* .expect_reasoning = */ "The user said hello, I must say hello back",
/* .expect_content = */ "Hello",
/* .expect_tool_calls = */ {},
},
{
/* .name = */ "content with thinking_forced_open = false and no reasoning",
/* .name = */ "content without reasoning (no generation_prompt)",
/* .tools = */ {},
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
/* .json_schema = */ {},
/* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ false,
/* .generation_prompt = */ "",
/* .input = */ ("Hello"),
/* .expect_reasoning = */ "",
/* .expect_content = */ "Hello",
/* .expect_tool_calls = */ {},
},
{
/* .name = */ "content with thinking_forced_open = false and reasoning_format = none",
/* .name = */ "content with reasoning_format = none (tags appear in content)",
/* .tools = */ {},
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
/* .json_schema = */ {},
/* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ true,
/* .generation_prompt = */ "",
/* .input = */ ("<think>The user said hello, I must say hello back</think>\nHello"),
/* .expect_reasoning = */ "",
/* .expect_content = */ "<think>The user said hello, I must say hello back</think>\nHello",
/* .expect_tool_calls = */ {},
},
{
/* .name = */ "content with thinking_forced_open = true",
/* .name = */ "content with reasoning generation_prompt",
/* .tools = */ {},
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
/* .json_schema = */ {},
/* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ true,
/* .generation_prompt = */ "<think>",
/* .input = */ ("The user said hello, I must say hello back</think>\nHello"),
/* .expect_reasoning = */ "The user said hello, I must say hello back",
/* .expect_content = */ "Hello",
/* .expect_tool_calls = */ {},
},
{
/* .name = */ "content with thinking_forced_open = true and reasoning_format = none",
/* .name = */ "content with reasoning generation_prompt and reasoning_format = none",
/* .tools = */ {},
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
/* .json_schema = */ {},
/* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ true,
/* .generation_prompt = */ "",
/* .input = */ ("The user said hello, I must say hello back</think>\nHello"),
/* .expect_reasoning = */ "",
/* .expect_content = */ "The user said hello, I must say hello back</think>\nHello",
/* .expect_tool_calls = */ {},
},
{
/* .name = */ "tools with tool_choice = auto and no parallel_tool_calls",
/* .name = */ "content with closed reasoning generation_prompt (empty reasoning discarded)",
/* .tools = */ {},
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
/* .json_schema = */ {},
/* .parallel_tool_calls = */ false,
/* .generation_prompt = */ "<think></think>",
/* .input = */ ("Hello"),
/* .expect_reasoning = */ "",
/* .expect_content = */ "Hello",
/* .expect_tool_calls = */ {},
},
{
/* .name = */ "tools with reasoning generation_prompt",
/* .tools = */ create_tools(),
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
/* .json_schema = */ {},
/* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ true,
/* .generation_prompt = */ "<think>",
/* .input = */
("I must get the weather in New York</think>\n"
"<tool_call>["
@ -277,13 +284,13 @@ static void test_example_native(testing & t) {
} },
},
{
/* .name = */ "tools with tool_choice = auto and parallel_tool_calls",
/* .name = */ "parallel tools with reasoning generation_prompt",
/* .tools = */ create_tools(),
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
/* .json_schema = */ {},
/* .parallel_tool_calls = */ true,
/* .thinking_forced_open = */ true,
/* .generation_prompt = */ "<think>",
/* .input = */
("I must get the weather in New York and San Francisco and a 3 day forecast of each.</think>\nLet me "
"search that for you."
@ -321,7 +328,7 @@ static void test_example_native(testing & t) {
} },
},
{
/* .name = */ "response_format with thinking_forced_open = true",
/* .name = */ "response_format with reasoning generation_prompt",
/* .tools = */ {},
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
@ -333,7 +340,7 @@ static void test_example_native(testing & t) {
{ "due_date", { { "type", "string" } } } } },
{ "required", { "invoice_number", "amount", "due_date" } } },
/* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ true,
/* .generation_prompt = */ "<think>",
/* .input = */
("I must produce the invoice in the requested format</think>\n"
R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})"),
@ -361,7 +368,8 @@ static void test_example_native(testing & t) {
t.log(line);
}
common_peg_parse_context ctx(tc.input);
std::string effective_input = tc.generation_prompt + tc.input;
common_peg_parse_context ctx(effective_input);
auto result = parser.parse(ctx);
t.assert_true("success", result.success());

View File

@ -822,8 +822,7 @@ struct make_peg_parser {
}
common_chat_msg parse(const std::string & msg, bool is_partial) const {
common_chat_parser_params parser_params;
parser_params.format = params_.format;
common_chat_parser_params parser_params(params_);
parser_params.debug = detailed_debug_;
return common_chat_peg_parse(arena_, msg, is_partial, parser_params);
}
@ -996,6 +995,16 @@ static void test_peg_parser(common_chat_templates * tmpls,
grammar_triggered = true;
}
// For non-lazy grammars, prepend reasoning prefill to grammar input, just like
// PEG parsing does. The grammar includes the full reasoning pattern (e.g. optional
// <think>...</think>), but the model output may start mid-reasoning if the template
// already placed the opening tag in the prompt.
// For lazy grammars, the grammar only activates from the trigger position, so the
// reasoning prefill is irrelevant — reasoning is handled by the PEG parser.
if (!parser.params_.generation_prompt.empty() && earliest_trigger_pos == std::string::npos) {
constrained = parser.params_.generation_prompt + constrained;
}
// Test the constrained portion against the grammar
if (grammar_triggered && !tc.is_partial) {
auto result = match_string_detailed(constrained, grammar.get());
@ -1271,11 +1280,13 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
tst.test("[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
.expect(message_assist_thoughts)
.run();
tst.test(R"([TOOL_CALLS]special_function[ARGS]{"arg1":1})")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
.tools({ special_function_tool })
.expect(message_assist_call)
.run();
@ -1284,6 +1295,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
"[THINK]I'm\nthinking[/THINK]"
R"([TOOL_CALLS]special_function[ARGS]{"arg1":1})")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
.tools({ special_function_tool })
.expect(message_assist_call_thoughts)
.run();
@ -1317,12 +1329,15 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// NVIDIA Nemotron-3 Nano
auto tst = peg_tester("models/templates/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").enable_thinking(false).expect(message_assist).run();
tst.test("Hello, world!\nWhat's up?").
enable_thinking(false).
reasoning_format(COMMON_REASONING_FORMAT_AUTO).
expect(message_assist).run();
tst.test("I'm\nthinking\n</think>\nHello, world!\nWhat's up?")
.enable_thinking(false)
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_NONE)
.expect_content("I'm\nthinking\n</think>\nHello, world!\nWhat's up?")
.expect_content("<think>\nI'm\nthinking\n</think>\nHello, world!\nWhat's up?")
.run();
tst.test("I'm\nthinking\n</think>\nHello, world!\nWhat's up?")
@ -1482,7 +1497,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.expect(simple_assist_msg("The answer is 42.", "Let me think about this..."))
.run();
tst.test("Hello, world!").expect(simple_assist_msg("Hello, world!")).run();
tst.test("</think>Hello, world!").reasoning_format(COMMON_REASONING_FORMAT_AUTO).expect(simple_assist_msg("Hello, world!")).run();
}
{
// NousResearch-Hermes-2-Pro and Hermes-3 (tool calling models)
@ -1798,6 +1813,8 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
"<tool▁calls▁begin><tool▁call▁begin>get_time<tool▁sep>{\"city\": "
"\"XYZCITY\"}<tool▁call▁end><tool▁calls▁end>")
.tools({ get_time_tool })
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.expect(message_with_tool_calls("get_time", "{\"city\":\"XYZCITY\"}"))
.run();
}
@ -1843,7 +1860,8 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
{
auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-V3.1.jinja", detailed_debug);
tst.test("CONTENT").expect(simple_assist_msg("CONTENT", "")).run();
tst.test("CONTENT").enable_thinking(false).reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK).
expect(simple_assist_msg("CONTENT", "")).run();
}
// GLM-4.6 tests - format: <tool_call>function_name\n<arg_key>...</arg_key>\n<arg_value>...</arg_value>\n</tool_call>
@ -1906,6 +1924,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
"<arg_key>arg1</arg_key><arg_value>1</arg_value>"
"<arg_key>arg2</arg_key><arg_value>2</arg_value>"
"</tool_call>")
.enable_thinking(false)
.parallel_tool_calls(true)
.tools({
special_function_tool, special_function_tool_with_optional_param
@ -1915,6 +1934,79 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
})
.run();
// #20650: tool with no required args, model emits <tool_call>name</tool_call> with no arg tags.
{
static common_chat_tool no_args_tool{
"read_file_diff_md", "Reads a file diff",
R"({"type":"object","properties":{"review_id":{"type":"string"},"file_id":{"type":"string"}}})",
};
tst.test(
"Let me read the diff content."
"</think>"
"<tool_call>read_file_diff_md</tool_call>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.tools({ no_args_tool })
.expect_reasoning("Let me read the diff content.")
.expect_tool_calls({{ "read_file_diff_md", "{}", {} }})
.run();
}
}
// Verify the throw path produces a readable error message, not std::out_of_range.
// #20424 introduced effective_input = generation_prompt + input, but the throw
// uses input.substr(result.end) where result.end is in effective_input space.
{
auto tmpls = common_chat_templates_ptr(
common_chat_templates_init(nullptr, read_file("models/templates/GLM-4.7-Flash.jinja")));
static common_chat_tool weather_tool{
"get_weather", "Get weather",
R"({"type":"object","properties":{"city":{"type":"string"}},"required":["city"]})",
};
common_chat_templates_inputs inputs;
inputs.tools = { weather_tool };
inputs.enable_thinking = true;
inputs.reasoning_format = COMMON_REASONING_FORMAT_AUTO;
inputs.add_generation_prompt = true;
inputs.use_jinja = true;
common_chat_msg msg;
msg.role = "user";
msg.content = "get_weather";
inputs.messages = { msg };
auto params = common_chat_templates_apply(tmpls.get(), inputs);
common_peg_arena arena;
arena.load(params.parser);
common_chat_parser_params pp(params);
// generation_prompt is non-empty for thinking models, so result.end
// will be offset by generation_prompt.size() into effective_input space.
assert(!pp.generation_prompt.empty());
std::string bad_input =
"Thinking.\n"
"</think>"
"<tool_call>get_weather"
"<arg_key>city</arg_key><arg_value>Tokyo</arg_value>"
"</tool_call>\n";
bool got_runtime_error = false;
bool got_out_of_range = false;
std::string error_msg;
try {
common_chat_peg_parse(arena, bad_input, /*is_partial=*/false, pp);
} catch (const std::out_of_range & e) {
got_out_of_range = true;
error_msg = e.what();
} catch (const std::runtime_error & e) {
got_runtime_error = true;
error_msg = e.what();
}
GGML_ASSERT(!got_out_of_range && "throw path crashed with out_of_range (input.substr in effective_input space)");
GGML_ASSERT(got_runtime_error && "throw path should produce std::runtime_error with parse position");
}
// Kimi-K2-Thinking tests - custom parser
@ -2222,10 +2314,11 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
{
auto tst = peg_tester("models/templates/MiniMax-M2.jinja", detailed_debug);
tst.test(
"<minimax:tool_call>\n<invoke name=\"special_function\">\n<parameter "
"</think><minimax:tool_call>\n<invoke name=\"special_function\">\n<parameter "
"name=\"arg1\">1</parameter>\n</invoke>\n</minimax:tool_call>")
.tools({ special_function_tool })
.expect(message_assist_call)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.run();
}
@ -2288,8 +2381,8 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// Functionary v3.2 - recipient-based format: >>>recipient\n{content}
{
auto tst = peg_tester("models/templates/meetkai-functionary-medium-v3.2.jinja", detailed_debug);
tst.test(">>>all\nHello, world!\nWhat's up?").expect(message_assist).run();
tst.test(">>>special_function\n{\"arg1\": 1}")
tst.test("all\nHello, world!\nWhat's up?").expect(message_assist).run();
tst.test("special_function\n{\"arg1\": 1}")
.tools({ special_function_tool })
.expect(message_assist_call)
.run();
@ -2309,8 +2402,8 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// Note: Template uses forced-open mode (prompt ends with <think>), so input shouldn't include opening tag
{
auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?")
.enable_thinking(true) // Forced open
tst.test("</think>Hello, world!\nWhat's up?")
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.expect(message_assist)
.run();
tst.test("I'm\nthinking</think>Hello, world!\nWhat's up?")
@ -2322,14 +2415,15 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// llama-cpp DeepSeek R1 template (always forced-open thinking)
{
auto tst = peg_tester("models/templates/llama-cpp-deepseek-r1.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
tst.test("</think>Hello, world!\nWhat's up?").expect(message_assist).reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK).run();
tst.test("I'm\nthinking</think>Hello, world!\nWhat's up?")
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.expect(message_assist_thoughts)
.run();
tst.test(
"<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>special_function\n"
"</think><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>special_function\n"
"```json\n{\"arg1\": 1}```<tool▁call▁end><tool▁calls▁end>")
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.tools({ special_function_tool })
.parallel_tool_calls(true)
.expect(message_assist_call)
@ -2339,7 +2433,9 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// Note: Template uses forced-open mode (prompt ends with <think>), so input shouldn't include opening tag
{
auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").enable_thinking(true).expect(message_assist).run();
tst.test("</think>Hello, world!\nWhat's up?").enable_thinking(true).
reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK).
expect(message_assist).run();
tst.test("I'm\nthinking</think>Hello, world!\nWhat's up?")
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.expect(message_assist_thoughts)
@ -2348,6 +2444,8 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
"<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>special_function\n"
"```json\n{\"arg1\": 1}```<tool▁call▁end><tool▁calls▁end>")
.tools({ special_function_tool })
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.expect(message_assist_call)
.run();
}
@ -2377,12 +2475,12 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// Apriel 1.6 Thinker (reasoning-only support)
{
auto tst = peg_tester("models/templates/Apriel-1.6-15b-Thinker-fixed.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
// Implicit reasoning start (forced open)
tst.test("I'm\nthinking\n[BEGIN FINAL RESPONSE]\nHello, world!\nWhat's up?")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.expect(message_assist_thoughts)
.enable_thinking(true)
.expect(simple_assist_msg("Hello, world!\nWhat's up?", "Here are my reasoning steps:\nI'm\nthinking"))
.run();
// Reasoning + Tool calls
@ -2390,8 +2488,9 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
"I'm\nthinking\n[BEGIN FINAL RESPONSE]\n<tool_calls>[{\"name\": \"special_function\", \"arguments\": "
"{\"arg1\": 1}}]</tool_calls>")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
.tools({ special_function_tool })
.expect(message_assist_call_thoughts)
.expect(simple_assist_msg("", "Here are my reasoning steps:\nI'm\nthinking", "special_function", "{\"arg1\":1}"))
.run();
}

View File

@ -105,7 +105,7 @@ struct cli_context {
llama_get_model(ctx_server.get_llama_context()));
task.params.sampling.reasoning_budget_tokens = reasoning_budget;
task.params.sampling.reasoning_budget_activate_immediately = chat_params.thinking_forced_open;
task.params.sampling.generation_prompt = chat_params.generation_prompt;
if (!chat_params.thinking_start_tag.empty()) {
task.params.sampling.reasoning_budget_start =

View File

@ -41,6 +41,11 @@ struct clip_graph {
virtual ~clip_graph() = default;
virtual ggml_cgraph * build() = 0;
// wrapper around ggml_mul_mat, allow hooking (e.g. LoRA, clamping) depending on the model
// tensor w should be the weight matrix, and tensor x should be the input
virtual ggml_tensor * build_mm(ggml_tensor * w, ggml_tensor * x) const;
// TODO: build_mm(w, b, x) to support bias
//
// utility functions
//

View File

@ -255,6 +255,10 @@ clip_graph::clip_graph(clip_ctx * ctx, const clip_image_f32 & img) :
gf = ggml_new_graph_custom(ctx0, ctx->max_nodes, false);
}
ggml_tensor * clip_graph::build_mm(ggml_tensor * w, ggml_tensor * x) const {
return ggml_mul_mat(ctx0, w, x);
}
void clip_graph::cb(ggml_tensor * cur, const char * name, int il) const {
if (il >= 0) {
ggml_format_name(cur, "%s-%d", name, il);
@ -326,7 +330,7 @@ ggml_tensor * clip_graph::build_vit(
ggml_tensor * Vcur = nullptr;
if (layer.qkv_w != nullptr) {
// fused qkv
cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
cur = build_mm(layer.qkv_w, cur);
if (layer.qkv_b != nullptr) {
cur = ggml_add(ctx0, cur, layer.qkv_b);
}
@ -360,17 +364,17 @@ ggml_tensor * clip_graph::build_vit(
} else {
// separate q, k, v
Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
Qcur = build_mm(layer.q_w, cur);
if (layer.q_b) {
Qcur = ggml_add(ctx0, Qcur, layer.q_b);
}
Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
Kcur = build_mm(layer.k_w, cur);
if (layer.k_b) {
Kcur = ggml_add(ctx0, Kcur, layer.k_b);
}
Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
Vcur = build_mm(layer.v_w, cur);
if (layer.v_b) {
Vcur = ggml_add(ctx0, Vcur, layer.v_b);
}
@ -517,7 +521,7 @@ ggml_tensor * clip_graph::build_ffn(
ffn_op_type type_op,
int il) const {
ggml_tensor * tmp = up ? ggml_mul_mat(ctx0, up, cur) : cur;
ggml_tensor * tmp = up ? build_mm(up, cur) : cur;
cb(tmp, "ffn_up", il);
if (up_b) {
@ -526,7 +530,7 @@ ggml_tensor * clip_graph::build_ffn(
}
if (gate) {
cur = ggml_mul_mat(ctx0, gate, cur);
cur = build_mm(gate, cur);
cb(cur, "ffn_gate", il);
if (gate_b) {
@ -580,7 +584,7 @@ ggml_tensor * clip_graph::build_ffn(
}
if (down) {
cur = ggml_mul_mat(ctx0, down, cur);
cur = build_mm(down, cur);
}
if (down_b) {
@ -646,7 +650,7 @@ ggml_tensor * clip_graph::build_attn(
cb(cur, "kqv_out", il);
if (wo) {
cur = ggml_mul_mat(ctx0, wo, cur);
cur = build_mm(wo, cur);
}
if (wo_b) {

View File

@ -19,7 +19,7 @@ ggml_cgraph * clip_graph_cogvlm::build() {
auto & layer = model.layers[il];
ggml_tensor * cur = inpL;
cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
cur = build_mm(layer.qkv_w, cur);
cur = ggml_add(ctx0, cur, layer.qkv_b);
@ -67,7 +67,7 @@ ggml_cgraph * clip_graph_cogvlm::build() {
ggml_row_size(inpL->type, n_embd), 0);
// Multiply with mm_model_proj
cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur);
cur = build_mm(model.mm_model_proj, cur);
// Apply layernorm, weight, bias
cur = build_norm(cur, model.mm_post_fc_norm_w, model.mm_post_fc_norm_b, NORM_TYPE_NORMAL, 1e-5, -1);
@ -76,16 +76,16 @@ ggml_cgraph * clip_graph_cogvlm::build() {
cur = ggml_gelu_inplace(ctx0, cur);
// Branch 1: multiply with mm_h_to_4h_w
ggml_tensor * h_to_4h = ggml_mul_mat(ctx0, model.mm_h_to_4h_w, cur);
ggml_tensor * h_to_4h = build_mm(model.mm_h_to_4h_w, cur);
// Branch 2: multiply with mm_gate_w
ggml_tensor * gate = ggml_mul_mat(ctx0, model.mm_gate_w, cur);
ggml_tensor * gate = build_mm(model.mm_gate_w, cur);
// Apply silu
gate = ggml_swiglu_split(ctx0, gate, h_to_4h);
// Apply mm_4h_to_h_w
cur = ggml_mul_mat(ctx0, model.mm_4h_to_h_w, gate);
cur = build_mm(model.mm_4h_to_h_w, gate);
// Concatenate with boi and eoi
cur = ggml_concat(ctx0, model.mm_boi, cur, 1);

View File

@ -56,7 +56,7 @@ ggml_cgraph * clip_graph_conformer::build() {
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2]);
// calculate out
cur = ggml_mul_mat(ctx0, model.pre_encode_out_w, cur);
cur = build_mm(model.pre_encode_out_w, cur);
cur = ggml_add(ctx0, cur, model.pre_encode_out_b);
cb(cur, "conformer.pre_encode.out", -1);
}
@ -87,7 +87,7 @@ ggml_cgraph * clip_graph_conformer::build() {
cur = build_norm(residual, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, 1e-5, il);
cb(cur, "conformer.layers.{}.norm_self_att", il);
ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
ggml_tensor * Qcur = build_mm(layer.q_w, cur);
Qcur = ggml_add(ctx0, Qcur, layer.q_b);
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, Qcur->ne[1]);
ggml_tensor * Q_bias_u = ggml_add(ctx0, Qcur, layer.pos_bias_u);
@ -96,12 +96,12 @@ ggml_cgraph * clip_graph_conformer::build() {
Q_bias_v = ggml_permute(ctx0, Q_bias_v, 0, 2, 1, 3);
// TODO @ngxson : some cont can/should be removed when ggml_mul_mat support these cases
ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
ggml_tensor * Kcur = build_mm(layer.k_w, cur);
Kcur = ggml_add(ctx0, Kcur, layer.k_b);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, Kcur->ne[1]);
Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
ggml_tensor * Vcur = build_mm(layer.v_w, cur);
Vcur = ggml_add(ctx0, Vcur, layer.v_b);
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, Vcur->ne[1]);
Vcur = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 1, 2, 0, 3));
@ -111,7 +111,7 @@ ggml_cgraph * clip_graph_conformer::build() {
matrix_ac = ggml_cont(ctx0, ggml_permute(ctx0, matrix_ac, 1, 0, 2, 3));
cb(matrix_ac, "conformer.layers.{}.self_attn.id3", il);
auto * p = ggml_mul_mat(ctx0, layer.linear_pos_w, pos_emb);
auto * p = build_mm(layer.linear_pos_w, pos_emb);
cb(p, "conformer.layers.{}.self_attn.linear_pos", il);
p = ggml_reshape_3d(ctx0, p, d_head, n_head, p->ne[1]);
p = ggml_permute(ctx0, p, 0, 2, 1, 3);
@ -143,7 +143,7 @@ ggml_cgraph * clip_graph_conformer::build() {
x = ggml_permute(ctx0, x, 2, 0, 1, 3);
x = ggml_cont_2d(ctx0, x, x->ne[0] * x->ne[1], x->ne[2]);
ggml_tensor * out = ggml_mul_mat(ctx0, layer.o_w, x);
ggml_tensor * out = build_mm(layer.o_w, x);
out = ggml_add(ctx0, out, layer.o_b);
cb(out, "conformer.layers.{}.self_attn.linear_out", il);
@ -157,7 +157,7 @@ ggml_cgraph * clip_graph_conformer::build() {
// conv
{
auto * x = cur;
x = ggml_mul_mat(ctx0, layer.conv_pw1_w, x);
x = build_mm(layer.conv_pw1_w, x);
x = ggml_add(ctx0, x, layer.conv_pw1_b);
cb(x, "conformer.layers.{}.conv.pointwise_conv1", il);
@ -181,7 +181,7 @@ ggml_cgraph * clip_graph_conformer::build() {
x = ggml_silu(ctx0, x);
// pointwise_conv2
x = ggml_mul_mat(ctx0, layer.conv_pw2_w, x);
x = build_mm(layer.conv_pw2_w, x);
x = ggml_add(ctx0, x, layer.conv_pw2_b);
cur = x;

View File

@ -97,7 +97,7 @@ ggml_cgraph * clip_graph_glm4v::build() {
// FC projector
{
cur = ggml_mul_mat(ctx0, model.projection, cur);
cur = build_mm(model.projection, cur);
// default LayerNorm (post_projection_norm)
cur = build_norm(cur, model.mm_post_norm_w, model.mm_post_norm_b, NORM_TYPE_NORMAL, 1e-5, -1);
cur = ggml_gelu_erf(ctx0, cur);

View File

@ -22,7 +22,7 @@ ggml_cgraph * clip_graph_llama4::build() {
ggml_tensor * kernel = ggml_reshape_4d(ctx0, model.patch_embeddings_0,
patch_size, patch_size, 3, n_embd);
inp = ggml_im2col(ctx0, kernel, inp, patch_size, patch_size, 0, 0, 1, 1, true, inp->type);
inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);
inp = build_mm(model.patch_embeddings_0, inp);
inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches);
cb(inp, "patch_conv", -1);
}
@ -78,15 +78,15 @@ ggml_cgraph * clip_graph_llama4::build() {
// based on Llama4VisionMLP2 (always uses GELU activation, no bias)
{
cur = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, cur);
cur = build_mm(model.mm_model_mlp_1_w, cur);
cur = ggml_gelu(ctx0, cur);
cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur);
cur = build_mm(model.mm_model_mlp_2_w, cur);
cur = ggml_gelu(ctx0, cur);
cb(cur, "adapter_mlp", -1);
}
// Llama4MultiModalProjector
cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur);
cur = build_mm(model.mm_model_proj, cur);
cb(cur, "projected", -1);
// build the graph

View File

@ -70,17 +70,17 @@ ggml_cgraph * clip_graph_llava::build() {
// self-attention
{
ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
ggml_tensor * Qcur = build_mm(layer.q_w, cur);
if (layer.q_b) {
Qcur = ggml_add(ctx0, Qcur, layer.q_b);
}
ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
ggml_tensor * Kcur = build_mm(layer.k_w, cur);
if (layer.k_b) {
Kcur = ggml_add(ctx0, Kcur, layer.k_b);
}
ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
ggml_tensor * Vcur = build_mm(layer.v_w, cur);
if (layer.v_b) {
Vcur = ggml_add(ctx0, Vcur, layer.v_b);
}
@ -164,17 +164,17 @@ ggml_cgraph * clip_graph_llava::build() {
// llava projector
if (proj_type == PROJECTOR_TYPE_MLP) {
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
embeddings = build_mm(model.mm_0_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
embeddings = ggml_gelu(ctx0, embeddings);
if (model.mm_2_w) {
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
embeddings = build_mm(model.mm_2_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
}
}
else if (proj_type == PROJECTOR_TYPE_MLP_NORM) {
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
embeddings = build_mm(model.mm_0_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
// ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
// First LayerNorm
@ -186,7 +186,7 @@ ggml_cgraph * clip_graph_llava::build() {
embeddings = ggml_gelu(ctx0, embeddings);
// Second linear layer
embeddings = ggml_mul_mat(ctx0, model.mm_3_w, embeddings);
embeddings = build_mm(model.mm_3_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_3_b);
// Second LayerNorm
@ -197,10 +197,10 @@ ggml_cgraph * clip_graph_llava::build() {
else if (proj_type == PROJECTOR_TYPE_LDP) {
// MobileVLM projector
int n_patch = 24;
ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings);
ggml_tensor * mlp_1 = build_mm(model.mm_model_mlp_1_w, embeddings);
mlp_1 = ggml_add(ctx0, mlp_1, model.mm_model_mlp_1_b);
mlp_1 = ggml_gelu(ctx0, mlp_1);
ggml_tensor * mlp_3 = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, mlp_1);
ggml_tensor * mlp_3 = build_mm(model.mm_model_mlp_3_w, mlp_1);
mlp_3 = ggml_add(ctx0, mlp_3, model.mm_model_mlp_3_b);
// mlp_3 shape = [1, 576, 2048], ne = [2048, 576, 1, 1]
@ -229,10 +229,10 @@ ggml_cgraph * clip_graph_llava::build() {
// block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
// pointwise conv
block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc1_w, block_1);
block_1 = build_mm(model.mm_model_block_1_block_1_fc1_w, block_1);
block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc1_b);
block_1 = ggml_relu(ctx0, block_1);
block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc2_w, block_1);
block_1 = build_mm(model.mm_model_block_1_block_1_fc2_w, block_1);
block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc2_b);
block_1 = ggml_hardsigmoid(ctx0, block_1);
// block_1_hw shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1], block_1 shape = [1, 2048], ne = [2048, 1, 1, 1]
@ -244,7 +244,7 @@ ggml_cgraph * clip_graph_llava::build() {
block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
// block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_2_0_w, block_1);
block_1 = build_mm(model.mm_model_block_1_block_2_0_w, block_1);
block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
// block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
@ -277,10 +277,10 @@ ggml_cgraph * clip_graph_llava::build() {
// block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
// pointwise conv
block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc1_w, block_1);
block_1 = build_mm(model.mm_model_block_2_block_1_fc1_w, block_1);
block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc1_b);
block_1 = ggml_relu(ctx0, block_1);
block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc2_w, block_1);
block_1 = build_mm(model.mm_model_block_2_block_1_fc2_w, block_1);
block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc2_b);
block_1 = ggml_hardsigmoid(ctx0, block_1);
@ -292,7 +292,7 @@ ggml_cgraph * clip_graph_llava::build() {
block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
// block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_2_0_w, block_1);
block_1 = build_mm(model.mm_model_block_2_block_2_0_w, block_1);
block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
@ -307,10 +307,10 @@ ggml_cgraph * clip_graph_llava::build() {
else if (proj_type == PROJECTOR_TYPE_LDPV2)
{
int n_patch = 24;
ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
ggml_tensor * mlp_0 = build_mm(model.mm_model_mlp_0_w, embeddings);
mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b);
mlp_0 = ggml_gelu(ctx0, mlp_0);
ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0);
ggml_tensor * mlp_2 = build_mm(model.mm_model_mlp_2_w, mlp_0);
mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b);
// mlp_2 ne = [2048, 576, 1, 1]
// // AVG Pool Layer 2*2, strides = 2
@ -344,15 +344,15 @@ ggml_cgraph * clip_graph_llava::build() {
embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
// GLU
{
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
embeddings = build_mm(model.mm_model_mlp_0_w, embeddings);
embeddings = ggml_norm(ctx0, embeddings, eps);
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
embeddings = ggml_gelu_inplace(ctx0, embeddings);
ggml_tensor * x = embeddings;
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
embeddings = build_mm(model.mm_model_mlp_2_w, embeddings);
x = build_mm(model.mm_model_mlp_1_w,x);
embeddings = ggml_swiglu_split(ctx0, embeddings, x);
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
embeddings = build_mm(model.mm_model_mlp_3_w, embeddings);
}
// arrangement of BOI/EOI token embeddings
// note: these embeddings are not present in text model, hence we cannot process them as text tokens

View File

@ -38,7 +38,7 @@ ggml_cgraph * clip_graph_minicpmv::build() {
// resampler projector (it is just another transformer)
ggml_tensor * q = model.mm_model_query;
ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
ggml_tensor * v = build_mm(model.mm_model_kv_proj, embeddings);
// norm
q = build_norm(q, model.mm_model_ln_q_w, model.mm_model_ln_q_b, NORM_TYPE_NORMAL, eps, -1);
@ -77,13 +77,13 @@ ggml_cgraph * clip_graph_minicpmv::build() {
// Use actual config value if available, otherwise fall back to hardcoded values
int num_query = hparams.minicpmv_query_num;
ggml_tensor * Q = ggml_add(ctx0,
ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q),
build_mm(model.mm_model_attn_q_w, q),
model.mm_model_attn_q_b);
ggml_tensor * K = ggml_add(ctx0,
ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k),
build_mm(model.mm_model_attn_k_w, k),
model.mm_model_attn_k_b);
ggml_tensor * V = ggml_add(ctx0,
ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v),
build_mm(model.mm_model_attn_v_w, v),
model.mm_model_attn_v_b);
Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_query);
@ -105,7 +105,7 @@ ggml_cgraph * clip_graph_minicpmv::build() {
embeddings = build_norm(embeddings, model.mm_model_ln_post_w, model.mm_model_ln_post_b, NORM_TYPE_NORMAL, eps, -1);
// projection
embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
embeddings = build_mm(model.mm_model_proj, embeddings);
// build the graph
ggml_build_forward_expand(gf, embeddings);

View File

@ -429,7 +429,7 @@ ggml_cgraph * clip_graph_mobilenetv5::build() {
// PyTorch: embedding_projection = nn.Linear(vision_hidden, text_hidden, bias=False)
// Weight stored as [out_features, in_features] = [text_hidden_size, vision_hidden_size]
if (model.mm_input_proj_w) {
cur = ggml_mul_mat(ctx0, model.mm_input_proj_w, cur);
cur = build_mm(model.mm_input_proj_w, cur);
}
// 5. POST PROJECTION NORM

View File

@ -43,7 +43,7 @@ ggml_cgraph * clip_graph_pixtral::build() {
// project to n_embd
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
cur = build_mm(model.mm_patch_merger_w, cur);
}
// LlavaMultiModalProjector (always using GELU activation)

View File

@ -90,11 +90,11 @@ ggml_cgraph * clip_graph_qwen2vl::build() {
// self-attention
{
ggml_tensor * Qcur = ggml_add(ctx0,
ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b);
build_mm(layer.q_w, cur), layer.q_b);
ggml_tensor * Kcur = ggml_add(ctx0,
ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b);
build_mm(layer.k_w, cur), layer.k_b);
ggml_tensor * Vcur = ggml_add(ctx0,
ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b);
build_mm(layer.v_w, cur), layer.v_b);
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches);

View File

@ -85,7 +85,7 @@ ggml_cgraph * clip_graph_qwen3vl::build() {
// self-attention
{
cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
cur = build_mm(layer.qkv_w, cur);
cur = ggml_add(ctx0, cur, layer.qkv_b);
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,

View File

@ -43,7 +43,7 @@ ggml_cgraph * clip_graph_siglip::build() {
// https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
const int scale_factor = model.hparams.n_merge;
cur = build_patch_merge_permute(cur, scale_factor);
cur = ggml_mul_mat(ctx0, model.projection, cur);
cur = build_mm(model.projection, cur);
} else if (proj_type == PROJECTOR_TYPE_LFM2) {
// pixel unshuffle block

Some files were not shown because too many files have changed in this diff Show More