This commit is contained in:
Nikhil Jain 2026-03-19 14:48:56 -07:00
commit 2c32cdc489
No known key found for this signature in database
119 changed files with 7009 additions and 7741 deletions

78
.github/workflows/ai-issues.yml vendored Normal file
View File

@ -0,0 +1,78 @@
name: AI review (issues)
on:
issues:
types: [opened]
jobs:
find-related:
if: github.event.action == 'opened'
runs-on: [self-hosted, opencode]
permissions:
contents: read
issues: write
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
fetch-depth: 1
- name: Find related
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
OPENCODE_PERMISSION: |
{
"bash": {
"*": "deny",
"gh issue*": "allow"
},
"webfetch": "deny"
}
run: |
rm AGENTS.md
rm CLAUDE.md
opencode run -m llama.cpp-dgx/ai-review-issues-find-similar --thinking "A new issue has been created:
Issue number: ${{ github.event.issue.number }}
Lookup the contents of the issue using the following command:
```bash
gh issue view ${{ github.event.issue.number }} --json title,body,url,number
```
Perform the following task and then post a SINGLE comment (if needed).
---
TASK : FIND RELATED ISSUES
Search through existing issues (excluding #${{ github.event.issue.number }}) to find related or similar issues.
Consider:
1. Similar titles or descriptions
2. Same error messages or symptoms
3. Related functionality or components
4. Similar feature requests
---
POSTING YOUR COMMENT:
Based on your findings, post a SINGLE comment on issue #${{ github.event.issue.number }}. Build the comment as follows:
If no related issues were found, do NOT comment at all.
If related issues were found, include a section listing them with links using the following format:
[comment]
This issue might be similar or related to:
- #[issue_number]: [brief description of how they are related]
_This comment was auto-generated locally using **$GA_ENGINE** on **$GA_MACHINE**_
[/comment]
Remember: Do not include the comment tags in your actual comment. Post at most ONE comment combining all findings. If everything is fine, post nothing.
"

80
.github/workflows/hip-quality-check.yml vendored Normal file
View File

@ -0,0 +1,80 @@
name: HIP quality check
on:
workflow_dispatch: # allows manual triggering
push:
branches:
- master
paths: [
'.github/workflows/hip-quality-check.yml',
'**/*.cu',
'**/*.cuh'
]
pull_request:
types: [opened, synchronize, reopened]
paths: [
'.github/workflows/hip-quality-check.yml',
'**/*.cu',
'**/*.cuh'
]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
env:
GGML_NLOOP: 3
GGML_N_THREADS: 1
LLAMA_LOG_COLORS: 1
LLAMA_LOG_PREFIX: 1
LLAMA_LOG_TIMESTAMPS: 1
jobs:
ubuntu-22-hip-quality-check:
runs-on: ubuntu-22.04
container: rocm/dev-ubuntu-22.04:7.2
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: Dependencies
id: depends
run: |
sudo apt-get update
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libssl-dev python3
- name: ccache
uses: ggml-org/ccache-action@v1.2.21
with:
key: ubuntu-22-hip-quality-check
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build with Werror
id: cmake_build
run: |
cmake -B build -S . \
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
-DGPU_TARGETS=gfx908 \
-DGGML_HIP=ON \
-DGGML_HIP_EXPORT_METRICS=Off \
-DCMAKE_HIP_FLAGS="-Werror -Wno-tautological-compare" \
-DCMAKE_BUILD_TYPE=Release
cd build
make -j $(nproc)
- name: Check for major VGPR spills
id: vgpr_check
run: |
cmake -B build -S . \
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
-DGPU_TARGETS=gfx908 \
-DGGML_HIP=ON \
-DGGML_HIP_EXPORT_METRICS=On \
-DCMAKE_HIP_FLAGS="" \
-DCMAKE_BUILD_TYPE=Release
cd build
make -j $(nproc) 2>&1 | tee metrics.log | grep -v 'Rpass-analysis=kernel-resource-usage\|remark:\|^$'
python3 ../scripts/hip/gcn-cdna-vgpr-check.py metrics.log

View File

@ -1830,23 +1830,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
{"--grammar"}, "GRAMMAR", {"--grammar"}, "GRAMMAR",
string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()), "BNF-like grammar to constrain generations (see samples in grammars/ dir)",
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.sampling.grammar = value; params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, value};
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
{"--grammar-file"}, "FNAME", {"--grammar-file"}, "FNAME",
"file to read grammar from", "file to read grammar from",
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.sampling.grammar = read_file(value); params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, read_file(value)};
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
{"-j", "--json-schema"}, "SCHEMA", {"-j", "--json-schema"}, "SCHEMA",
"JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead", "JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead",
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.sampling.grammar = json_schema_to_grammar(json::parse(value)); params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, json_schema_to_grammar(json::parse(value))};
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
@ -1863,7 +1863,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
std::istreambuf_iterator<char>(), std::istreambuf_iterator<char>(),
std::back_inserter(schema) std::back_inserter(schema)
); );
params.sampling.grammar = json_schema_to_grammar(json::parse(schema)); params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, json_schema_to_grammar(json::parse(schema))};
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
@ -3494,7 +3494,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
throw std::invalid_argument("unknown speculative decoding type without draft model"); throw std::invalid_argument("unknown speculative decoding type without draft model");
} }
} }
).set_examples({LLAMA_EXAMPLE_SERVER})); ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SPEC_TYPE"));
add_opt(common_arg( add_opt(common_arg(
{"--spec-ngram-size-n"}, "N", {"--spec-ngram-size-n"}, "N",
string_format("ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)", params.speculative.ngram_size_n), string_format("ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)", params.speculative.ngram_size_n),

View File

@ -1,3 +1,4 @@
#include "chat-auto-parser-helpers.h"
#include "chat-auto-parser.h" #include "chat-auto-parser.h"
#include "chat-peg-parser.h" #include "chat-peg-parser.h"
#include "chat.h" #include "chat.h"
@ -23,13 +24,13 @@ static void foreach_function(const json & tools, const std::function<void(const
namespace autoparser { namespace autoparser {
parser_build_context::parser_build_context(common_chat_peg_builder & p, const templates_params & inputs) : parser_build_context::parser_build_context(common_chat_peg_builder & p, const generation_params & inputs) :
p(p), p(p),
inputs(inputs), inputs(inputs),
reasoning_parser(p.eps()) {} reasoning_parser(p.eps()) {}
common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl, common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl,
const struct templates_params & inputs) { const struct generation_params & inputs) {
// Run differential analysis to extract template structure // Run differential analysis to extract template structure
struct autoparser autoparser; struct autoparser autoparser;
autoparser.analyze_template(tmpl); autoparser.analyze_template(tmpl);
@ -37,17 +38,16 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
} }
common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl, common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl,
const struct templates_params & inputs, const struct generation_params & inputs,
const autoparser & autoparser) { const autoparser & autoparser) {
// Build the parser using the analysis results
auto parser = autoparser.build_parser(inputs);
// Create the result structure // Create the result structure
common_chat_params data; common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs); data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.preserved_tokens = autoparser.preserved_tokens; data.preserved_tokens = autoparser.preserved_tokens;
data.parser = parser.save();
auto parser = autoparser.build_parser(inputs);
data.parser = parser.save();
// Build grammar if tools are present // Build grammar if tools are present
bool has_tools = bool has_tools =
@ -82,44 +82,38 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
return data; return data;
} }
common_peg_arena autoparser::build_parser(const templates_params & inputs) const { common_peg_arena autoparser::build_parser(const generation_params & inputs) const {
if (!analysis_complete) { if (!analysis_complete) {
throw std::invalid_argument("Cannot call build_parser on autoparser without performing analysis first, call analyze_template(...)"); throw std::invalid_argument("Cannot call build_parser on autoparser without performing analysis first, call analyze_template(...)");
} }
return build_chat_peg_parser([&](common_chat_peg_builder & p) { return build_chat_peg_parser([&](common_chat_peg_builder & p) {
// If the template uses Python dict format (single-quoted strings in JSON structures),
// pre-register a json-string rule that accepts both quote styles. This must happen
// before any call to p.json() so that all JSON parsing inherits the flexible rule.
if (tools.format.uses_python_dicts) {
p.rule("json-string", p.quoted_string());
}
parser_build_context ctx(p, inputs); parser_build_context ctx(p, inputs);
bool extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; bool extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
bool enable_thinking = inputs.enable_thinking;
ctx.extracting_reasoning = extract_reasoning && enable_thinking && reasoning.mode != reasoning_mode::NONE; ctx.extracting_reasoning = extract_reasoning && reasoning.mode != reasoning_mode::NONE;
ctx.content = &content; ctx.content = &content;
// Build reasoning parser // Build reasoning parser
ctx.reasoning_parser = reasoning.build_parser(ctx); ctx.reasoning_parser = reasoning.build_parser(ctx);
auto parser = p.eps();
bool has_tools = inputs.tools.is_array() && !inputs.tools.empty(); bool has_tools = inputs.tools.is_array() && !inputs.tools.empty();
bool has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty(); bool has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
if (has_response_format) { if (has_response_format) {
auto response_format = p.rule("response-format", p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema))); auto response_format = p.rule("response-format", p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
return ctx.reasoning_parser + p.space() + p.choice({ parser = ctx.reasoning_parser + p.space() + p.choice({
p.literal("```json") + p.space() + response_format + p.space() + p.literal("```"), p.literal("```json") + p.space() + response_format + p.space() + p.literal("```"),
response_format response_format
}) + p.end(); }) + p.end();
} else if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) {
parser = tools.build_parser(ctx);
} else {
parser = content.build_parser(ctx);
} }
parser = wrap_for_generation_prompt(p, parser, inputs, reasoning.start);
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) { return parser;
return tools.build_parser(ctx);
}
return content.build_parser(ctx);
}); });
} }
@ -130,24 +124,15 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co
return p.eps(); return p.eps();
} }
bool thinking_forced_open = (mode == reasoning_mode::FORCED_OPEN);
bool thinking_forced_closed = (mode == reasoning_mode::FORCED_CLOSED);
if (thinking_forced_open || thinking_forced_closed) {
// Thinking is forced open OR forced closed with enable_thinking=true
// In both cases, expect only the closing tag (opening was in template)
// However, since we might have incorrectly detected the open/close pattern,
// we admit an optional starting marker
return p.optional(p.literal(start)) + p.reasoning(p.until(end)) + end;
}
if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) { if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) {
// Standard tag-based reasoning OR tools-only mode (reasoning appears with tools) if (!end.empty()) {
// Both use the same tag-based pattern if markers are available if (!start.empty()) {
if (!start.empty() && !end.empty()) { // Standard tag-based: optional(<think>reasoning</think>)
return p.optional(start + p.reasoning(p.until(end)) + end); return p.optional(start + p.reasoning(p.until(end)) + end + p.space());
}
// Delimiter-style (empty start)
return p.optional(p.reasoning(p.until(end)) + end + p.space());
} }
} else if (mode == reasoning_mode::DELIMITER) {
return p.optional(p.reasoning(p.until(end)) + end);
} }
return p.eps(); return p.eps();
@ -335,7 +320,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
"tool-" + name + "-arg-" + param_name + "-schema", "tool-" + name + "-arg-" + param_name + "-schema",
param_schema, true)) : param_schema, true)) :
p.tool_arg_json_value(p.schema( p.tool_arg_json_value(p.schema(
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, format.uses_python_dicts)) + p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
p.space()) + p.space()) +
p.tool_arg_close(p.literal(arguments.value_suffix))); p.tool_arg_close(p.literal(arguments.value_suffix)));
@ -384,7 +369,9 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) + func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
call_id_section) + p.space() + args_seq; call_id_section) + p.space() + args_seq;
matched_atomic = true; matched_atomic = true;
} else if (!arguments.name_prefix.empty() && properties.size() > 0) { } else if (!arguments.name_prefix.empty() && !required_parsers.empty()) {
// Only peek for an arg tag when there are required args that must follow.
// When all args are optional, the model may emit no arg tags at all (#20650).
func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) + func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
call_id_section + p.space() + p.peek(p.literal(arguments.name_prefix))) + args_seq; call_id_section + p.space() + p.peek(p.literal(arguments.name_prefix))) + args_seq;
matched_atomic = true; matched_atomic = true;

View File

@ -1,9 +1,11 @@
#include "chat-auto-parser-helpers.h" #include "chat-auto-parser-helpers.h"
#include "chat-auto-parser.h" #include "chat-auto-parser.h"
#include "chat-peg-parser.h"
#include "chat.h" #include "chat.h"
#include "log.h" #include "log.h"
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
#include "peg-parser.h"
#include <cctype> #include <cctype>
#include <numeric> #include <numeric>
@ -291,10 +293,26 @@ std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segm
return result; return result;
} }
common_peg_parser wrap_for_generation_prompt(common_chat_peg_builder & p,
const common_peg_parser & prs,
const autoparser::generation_params & inputs,
const std::string & reasoning_start) {
auto parser = prs;
if (!inputs.generation_prompt.empty()) {
size_t end_pos = inputs.generation_prompt.size();
if (!reasoning_start.empty() && inputs.generation_prompt.find(reasoning_start) != std::string::npos) {
end_pos = inputs.generation_prompt.find(reasoning_start);
}
std::string cut_genprompt = inputs.generation_prompt.substr(0, end_pos);
parser = p.literal(cut_genprompt) + parser;
}
return parser;
}
namespace autoparser { namespace autoparser {
std::string apply_template(const common_chat_template & tmpl, const template_params & params) { std::string apply_template(const common_chat_template & tmpl, const template_params & params) {
templates_params tmpl_params; generation_params tmpl_params;
tmpl_params.messages = params.messages; tmpl_params.messages = params.messages;
tmpl_params.tools = params.tools; tmpl_params.tools = params.tools;
tmpl_params.add_generation_prompt = params.add_generation_prompt; tmpl_params.add_generation_prompt = params.add_generation_prompt;

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include "chat-auto-parser.h" #include "chat-auto-parser.h"
#include "peg-parser.h"
#include <functional> #include <functional>
#include <optional> #include <optional>
#include <string> #include <string>
@ -57,6 +58,11 @@ std::vector<segment> segmentize_markers(const std::string & text);
// (MARKER, "</function>"), (MARKER, "</tool_call>") ] // (MARKER, "</function>"), (MARKER, "</tool_call>") ]
std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segments); std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segments);
// Wrap parser with generation prompt parser
common_peg_parser wrap_for_generation_prompt(common_chat_peg_builder & p,
const common_peg_parser & prs,
const autoparser::generation_params & inputs,
const std::string & reasoning_start = {});
namespace autoparser { namespace autoparser {
// Apply a template with the given parameters, returning the rendered string (empty on failure) // Apply a template with the given parameters, returning the rendered string (empty on failure)

View File

@ -50,7 +50,7 @@ namespace autoparser {
// High-level params for parser generation // High-level params for parser generation
// ============================================================================ // ============================================================================
struct templates_params { struct generation_params {
json messages; json messages;
json tools; json tools;
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
@ -62,6 +62,7 @@ struct templates_params {
bool add_generation_prompt = false; bool add_generation_prompt = false;
bool enable_thinking = true; bool enable_thinking = true;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
std::string generation_prompt;
json extra_context; json extra_context;
bool add_bos = false; bool add_bos = false;
bool add_eos = false; bool add_eos = false;
@ -77,11 +78,7 @@ struct templates_params {
// Reasoning handling mode (derived from R1-R3 comparisons) // Reasoning handling mode (derived from R1-R3 comparisons)
enum class reasoning_mode { enum class reasoning_mode {
NONE, // No reasoning markers detected NONE, // No reasoning markers detected
TAG_BASED, // Standard tag-based: <think>...</think> TAG_BASED, // Tag-based: <think>...</think> (start can be empty for delimiter-style)
DELIMITER, // Delimiter-based: [BEGIN FINAL RESPONSE] (reasoning ends at delimiter)
FORCED_OPEN, // Template ends with open reasoning tag (empty start, non-empty end)
FORCED_CLOSED, // Template ends with open reasoning tag on enabled thinking but
// with both opened and closed tag for disabled thinking
TOOLS_ONLY // Only reason on tool calls, not on normal content TOOLS_ONLY // Only reason on tool calls, not on normal content
}; };
@ -91,12 +88,6 @@ inline std::ostream & operator<<(std::ostream & os, const reasoning_mode & mode)
return os << "NONE"; return os << "NONE";
case reasoning_mode::TAG_BASED: case reasoning_mode::TAG_BASED:
return os << "TAG_BASED"; return os << "TAG_BASED";
case reasoning_mode::DELIMITER:
return os << "DELIMITER";
case reasoning_mode::FORCED_OPEN:
return os << "FORCED_OPEN";
case reasoning_mode::FORCED_CLOSED:
return os << "FORCED_CLOSED";
case reasoning_mode::TOOLS_ONLY: case reasoning_mode::TOOLS_ONLY:
return os << "TOOLS_ONLY"; return os << "TOOLS_ONLY";
default: default:
@ -184,7 +175,6 @@ struct tool_format_analysis {
bool fun_name_is_key = false; // In JSON format function name is JSON key, i.e. { "<funname>": { ... arguments ... } } bool fun_name_is_key = false; // In JSON format function name is JSON key, i.e. { "<funname>": { ... arguments ... } }
bool tools_array_wrapped = false; // Tool calls wrapped in JSON array [...] bool tools_array_wrapped = false; // Tool calls wrapped in JSON array [...]
bool uses_python_dicts = false; // Tool call args use Python dict format (single-quoted strings)
std::string function_field = "function"; std::string function_field = "function";
std::string name_field = "name"; std::string name_field = "name";
@ -225,12 +215,12 @@ struct analyze_content;
struct parser_build_context { struct parser_build_context {
common_chat_peg_builder & p; common_chat_peg_builder & p;
const templates_params & inputs; const generation_params & inputs;
common_peg_parser reasoning_parser; common_peg_parser reasoning_parser;
bool extracting_reasoning = false; bool extracting_reasoning = false;
const analyze_content * content = nullptr; const analyze_content * content = nullptr;
parser_build_context(common_chat_peg_builder & p, const templates_params & inputs); parser_build_context(common_chat_peg_builder & p, const generation_params & inputs);
}; };
// ============================================================================ // ============================================================================
@ -260,6 +250,7 @@ struct analyze_reasoning : analyze_base {
analyze_reasoning() = default; analyze_reasoning() = default;
analyze_reasoning(const common_chat_template & tmpl, bool supports_tools); analyze_reasoning(const common_chat_template & tmpl, bool supports_tools);
analyze_reasoning(std::string start_, std::string end_) : start(std::move(start_)), end(std::move(end_)) {}
common_peg_parser build_parser(parser_build_context & ctx) const override; common_peg_parser build_parser(parser_build_context & ctx) const override;
@ -381,7 +372,7 @@ struct autoparser {
void analyze_template(const common_chat_template & tmpl); void analyze_template(const common_chat_template & tmpl);
// Build the PEG parser for this template // Build the PEG parser for this template
common_peg_arena build_parser(const templates_params & inputs) const; common_peg_arena build_parser(const generation_params & inputs) const;
private: private:
// Collect tokens from entire analysis to preserve // Collect tokens from entire analysis to preserve
@ -395,10 +386,10 @@ struct autoparser {
class peg_generator { class peg_generator {
public: public:
static common_chat_params generate_parser(const common_chat_template & tmpl, static common_chat_params generate_parser(const common_chat_template & tmpl,
const struct templates_params & inputs); const struct generation_params & inputs);
static common_chat_params generate_parser(const common_chat_template & tmpl, static common_chat_params generate_parser(const common_chat_template & tmpl,
const struct templates_params & inputs, const struct generation_params & inputs,
const autoparser & autoparser); const autoparser & autoparser);
}; };

View File

@ -2,6 +2,7 @@
#include "chat-auto-parser-helpers.h" #include "chat-auto-parser-helpers.h"
#include "chat-peg-parser.h" #include "chat-peg-parser.h"
#include "chat.h" #include "chat.h"
#include "common.h"
#include "log.h" #include "log.h"
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
#include "peg-parser.h" #include "peg-parser.h"
@ -31,8 +32,9 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
[](const common_chat_template & tmpl, autoparser & analysis) -> void { [](const common_chat_template & tmpl, autoparser & analysis) -> void {
if (tmpl.src.find("content.split('</think>')") != std::string::npos && if (tmpl.src.find("content.split('</think>')") != std::string::npos &&
tmpl.src.find("reasoning_content") == std::string::npos && tmpl.src.find("reasoning_content") == std::string::npos &&
tmpl.src.find("<SPECIAL_12>") == std::string::npos &&
analysis.reasoning.mode == reasoning_mode::NONE) { analysis.reasoning.mode == reasoning_mode::NONE) {
analysis.reasoning.mode = reasoning_mode::FORCED_OPEN; analysis.reasoning.mode = reasoning_mode::TAG_BASED;
analysis.reasoning.start = "<think>"; analysis.reasoning.start = "<think>";
analysis.reasoning.end = "</think>"; analysis.reasoning.end = "</think>";
analysis.preserved_tokens.push_back("<think>"); analysis.preserved_tokens.push_back("<think>");
@ -185,7 +187,6 @@ void autoparser::analyze_template(const common_chat_template & tmpl) {
LOG_DBG("func_name_prefix: '%s'\n", tools.function.name_prefix.c_str()); LOG_DBG("func_name_prefix: '%s'\n", tools.function.name_prefix.c_str());
LOG_DBG("func_name_suffix: '%s'\n", tools.function.name_suffix.c_str()); LOG_DBG("func_name_suffix: '%s'\n", tools.function.name_suffix.c_str());
LOG_DBG("func_close: '%s'\n", tools.function.close.c_str()); LOG_DBG("func_close: '%s'\n", tools.function.close.c_str());
LOG_DBG("python_dict_format: %s\n", tools.format.uses_python_dicts ? "true" : "false");
LOG_DBG("arg_name_prefix: '%s'\n", tools.arguments.name_prefix.c_str()); LOG_DBG("arg_name_prefix: '%s'\n", tools.arguments.name_prefix.c_str());
LOG_DBG("arg_name_suffix: '%s'\n", tools.arguments.name_suffix.c_str()); LOG_DBG("arg_name_suffix: '%s'\n", tools.arguments.name_suffix.c_str());
LOG_DBG("arg_value_prefix: '%s'\n", tools.arguments.value_prefix.c_str()); LOG_DBG("arg_value_prefix: '%s'\n", tools.arguments.value_prefix.c_str());
@ -295,16 +296,12 @@ void analyze_reasoning::compare_reasoning_presence() {
} }
if (result.result.success()) { if (result.result.success()) {
if (!result.tags["pre"].empty() && !result.tags["post"].empty()) { if (!result.tags["pre"].empty() && !result.tags["post"].empty()) {
if (parser_wrapped.parse_anywhere_and_extract(diff.right).result.success()) { // both tags in the diff = no forced close mode = reasoning_mode::TAG_BASED;
mode = reasoning_mode::TAG_BASED;
} else {
mode = reasoning_mode::FORCED_CLOSED;
}
start = trim_whitespace(result.tags["pre"]); start = trim_whitespace(result.tags["pre"]);
end = result.tags["post"]; end = trim_trailing_whitespace(result.tags["post"]);
} else if (!result.tags["post"].empty()) { } else if (!result.tags["post"].empty()) {
mode = reasoning_mode::DELIMITER; mode = reasoning_mode::TAG_BASED;
end = result.tags["post"]; end = trim_trailing_whitespace(result.tags["post"]);
} }
} }
} }
@ -331,53 +328,30 @@ void analyze_reasoning::compare_thinking_enabled() {
const auto & diff = comparison->diff; const auto & diff = comparison->diff;
std::string left_trimmed = trim_whitespace(diff.left); std::string left_trimmed = trim_whitespace(diff.left);
std::string right_trimmed = trim_whitespace(diff.right);
if (left_trimmed.empty() && !diff.right.empty()) { if (left_trimmed.empty() && !diff.right.empty()) {
std::string right_trimmed = trim_whitespace(diff.right);
if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) { if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) {
if (start.empty()) { if (start.empty()) {
start = right_trimmed; start = right_trimmed;
mode = reasoning_mode::FORCED_OPEN; mode = reasoning_mode::TAG_BASED;
}
}
} else if (right_trimmed.empty() && !diff.left.empty()) {
if (!left_trimmed.empty() && string_ends_with(comparison->output_A, left_trimmed)) {
if (end.empty()) {
auto seg = prune_whitespace_segments(segmentize_markers(comparison->output_A));
if (seg.size() >= 2 && seg[seg.size() - 1].value == left_trimmed && seg[seg.size() - 2].type == segment_type::MARKER) {
start = seg[seg.size() - 2].value;
}
end = left_trimmed;
mode = reasoning_mode::TAG_BASED;
} }
} }
} }
if (start.empty() && !end.empty()) { if (mode == reasoning_mode::NONE && start.empty() && !end.empty()) {
mode = reasoning_mode::DELIMITER; mode = reasoning_mode::TAG_BASED;
}
// Check for FORCED_CLOSED: when enable_thinking=false produces both start and end markers,
// but enable_thinking=true produces only the start marker
if (!comparison->output_A.empty() && !comparison->output_B.empty()) {
auto parser_start = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
return p.literal(start) + p.space() + p.literal(end) + p.rest();
});
auto parser_start_end = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
return p.tag("pre", p.literal(start)) + p.space() + p.negate(p.literal(end)) + p.rest();
});
if (!start.empty() && parser_start_end.parse_anywhere_and_extract(comparison->output_A).result.success() &&
parser_start.parse_anywhere_and_extract(comparison->output_B).result.success()) {
mode = reasoning_mode::FORCED_CLOSED;
} else if (!end.empty()) { // we extract the starting marker now since we didn't get it earlier
auto result = parser_start_end.parse_anywhere_and_extract(comparison->output_A);
if (result.result.success()) {
start = result.tags["pre"];
mode = reasoning_mode::FORCED_CLOSED;
}
}
}
if (start.empty() && end.empty()) { // we might still have the case of "just open" and "just close"
if (!diff.left.empty() && !diff.right.empty()) {
auto seg_A = segmentize_markers(trim_trailing_whitespace(diff.left));
auto seg_B = segmentize_markers(trim_trailing_whitespace(diff.right));
if (seg_A.size() == 1 && seg_B.size() == 1) {
mode = reasoning_mode::FORCED_CLOSED;
start = seg_B[0].value;
end = seg_A[0].value;
}
}
} }
} }
@ -426,14 +400,14 @@ void analyze_reasoning::compare_reasoning_scope() {
auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B); auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B);
if (result.result.success()) { if (result.result.success()) {
start = result.tags["pre"]; start = result.tags["pre"];
end = result.tags["post"]; end = trim_trailing_whitespace(result.tags["post"]);
} else { } else {
auto parser_delimiter = build_tagged_peg_parser([&](common_peg_parser_builder &p) { auto parser_delimiter = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space()))); return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space())));
}); });
result = parser_delimiter.parse_anywhere_and_extract(comparison->output_B); result = parser_delimiter.parse_anywhere_and_extract(comparison->output_B);
if (result.result.success()) { if (result.result.success()) {
end = result.tags["post"]; end = trim_trailing_whitespace(result.tags["post"]);
} else { } else {
LOG_DBG(ANSI_ORANGE "%s: Unable to extracft reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__); LOG_DBG(ANSI_ORANGE "%s: Unable to extracft reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__);
mode = reasoning_mode::NONE; mode = reasoning_mode::NONE;
@ -600,33 +574,23 @@ void analyze_tools::analyze_tool_call_format(const std::string & haystack,
return; return;
} }
enum class json_quote_style { NONE, DOUBLE_QUOTES, SINGLE_QUOTES }; auto in_json_haystack = [&haystack](const std::string & needle) -> bool {
auto in_json_haystack = [&haystack](const std::string & needle) -> json_quote_style {
auto parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { auto parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
return p.choice({ p.literal("{"), p.literal(":") }) << p.choice({ return p.choice({ p.literal("{"), p.literal(":") }) << p.choice({
p.tag("sq", p.literal("'") + p.literal(needle) + p.literal("'")),
p.tag("dq", p.literal("\"") + p.literal(needle) + p.literal("\"")) }); p.tag("dq", p.literal("\"") + p.literal(needle) + p.literal("\"")) });
}); });
auto result = parser.parse_anywhere_and_extract(haystack); auto result = parser.parse_anywhere_and_extract(haystack);
if (!result.result.success()) { return result.result.success();
return json_quote_style::NONE;
}
return result.tags.count("sq") && !result.tags["sq"].empty()
? json_quote_style::SINGLE_QUOTES
: json_quote_style::DOUBLE_QUOTES;
}; };
auto fun_quote = in_json_haystack(fun_name_needle); auto fun_quote = in_json_haystack(fun_name_needle);
auto arg_quote = in_json_haystack(arg_name_needle); auto arg_quote = in_json_haystack(arg_name_needle);
if (fun_quote != json_quote_style::NONE) { if (fun_quote) {
// no need to check further, we're in JSON land // no need to check further, we're in JSON land
format.mode = tool_format::JSON_NATIVE; format.mode = tool_format::JSON_NATIVE;
format.uses_python_dicts = (fun_quote == json_quote_style::SINGLE_QUOTES); } else if (arg_quote) {
} else if (arg_quote != json_quote_style::NONE) {
format.mode = tool_format::TAG_WITH_JSON; format.mode = tool_format::TAG_WITH_JSON;
format.uses_python_dicts = (arg_quote == json_quote_style::SINGLE_QUOTES);
} else { } else {
format.mode = tool_format::TAG_WITH_TAGGED; format.mode = tool_format::TAG_WITH_TAGGED;
} }

View File

@ -229,6 +229,20 @@ void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena,
result.tool_calls.push_back(pending_tool_call.value()); result.tool_calls.push_back(pending_tool_call.value());
pending_tool_call.reset(); pending_tool_call.reset();
} }
// Discard whitespace-only reasoning content (e.g. from <think></think> prefill)
if (!result.reasoning_content.empty()) {
bool all_whitespace = true;
for (char c : result.reasoning_content) {
if (c != ' ' && c != '\n' && c != '\r' && c != '\t') {
all_whitespace = false;
break;
}
}
if (all_whitespace) {
result.reasoning_content.clear();
}
}
} }
void common_chat_peg_mapper::map(const common_peg_ast_node & node) { void common_chat_peg_mapper::map(const common_peg_ast_node & node) {

View File

@ -1,5 +1,6 @@
#include "chat.h" #include "chat.h"
#include "chat-auto-parser-helpers.h"
#include "chat-auto-parser.h" #include "chat-auto-parser.h"
#include "chat-peg-parser.h" #include "chat-peg-parser.h"
#include "common.h" #include "common.h"
@ -22,6 +23,7 @@
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
@ -760,7 +762,7 @@ static void foreach_parameter(const json &
std::string common_chat_template_direct_apply( std::string common_chat_template_direct_apply(
const common_chat_template & tmpl, const common_chat_template & tmpl,
const autoparser::templates_params & inputs, const autoparser::generation_params & inputs,
const std::optional<json> & messages_override, const std::optional<json> & messages_override,
const std::optional<json> & tools_override, const std::optional<json> & tools_override,
const std::optional<json> & additional_context) { const std::optional<json> & additional_context) {
@ -811,7 +813,7 @@ std::string common_chat_template_direct_apply(
} }
static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl, static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl,
const autoparser::templates_params & inputs) { const autoparser::generation_params & inputs) {
common_chat_params data; common_chat_params data;
// Build up messages to follow the format: https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512/blob/main/chat_template.jinja // Build up messages to follow the format: https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512/blob/main/chat_template.jinja
@ -876,8 +878,8 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
// Response format parser // Response format parser
if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) { if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
// Ministral wants to emit json surrounded by code fences // Ministral wants to emit json surrounded by code fences
return reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) return wrap_for_generation_prompt(p, reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```",
<< "```"; inputs, "[THINK]");
} }
// Tool call parser // Tool call parser
@ -897,12 +899,13 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
auto max_calls = inputs.parallel_tool_calls ? -1 : 1; auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls)); auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls));
return reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls; return wrap_for_generation_prompt(p, reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls,
inputs, "[THINK]");
} }
// Content only parser // Content only parser
include_grammar = false; include_grammar = false;
return reasoning << p.content(p.rest()); return wrap_for_generation_prompt(p, reasoning << p.content(p.rest()), inputs, "[THINK]");
}); });
data.parser = parser.save(); data.parser = parser.save();
@ -928,7 +931,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
} }
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl,
const autoparser::templates_params & inputs) { const autoparser::generation_params & inputs) {
common_chat_params data; common_chat_params data;
// Copy reasoning to the "thinking" field as expected by the gpt-oss template // Copy reasoning to the "thinking" field as expected by the gpt-oss template
@ -936,7 +939,9 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
for (auto msg : inputs.messages) { for (auto msg : inputs.messages) {
if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) { if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
msg["thinking"] = msg.at("reasoning_content"); msg["thinking"] = msg.at("reasoning_content");
msg.erase("content"); if (msg.contains("tool_calls") && msg.at("tool_calls").is_array() && !msg.at("tool_calls").empty()) {
msg.erase("content");
}
} }
adjusted_messages.push_back(msg); adjusted_messages.push_back(msg);
} }
@ -986,7 +991,8 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
p.literal("<|channel|>final") + constraint + p.literal("<|message|>") + p.literal("<|channel|>final") + constraint + p.literal("<|message|>") +
p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema))); p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
return response_format | (analysis + p.zero_or_more(start + analysis) + start + response_format); return wrap_for_generation_prompt(p, response_format | (analysis + p.zero_or_more(start + analysis) + start + response_format),
inputs, "<|channel|>");
} }
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
@ -1018,10 +1024,12 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
return tool_call | ( any + p.zero_or_more(start + any) + start + tool_call); return tool_call | ( any + p.zero_or_more(start + any) + start + tool_call);
} }
return tool_call | final_msg | (any + p.zero_or_more(start + any) + start + (tool_call | final_msg)); return wrap_for_generation_prompt(p, tool_call | final_msg | (any + p.zero_or_more(start + any) + start + (tool_call | final_msg)),
inputs, "<|channel|>");
} }
return final_msg | (any + p.zero_or_more(start + any) + start + final_msg); return wrap_for_generation_prompt(p, final_msg | (any + p.zero_or_more(start + any) + start + final_msg),
inputs, "<|channel|>");
}); });
data.parser = parser.save(); data.parser = parser.save();
@ -1049,7 +1057,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
// Functionary v3.2 - uses recipient-based format: >>>recipient\n{content} // Functionary v3.2 - uses recipient-based format: >>>recipient\n{content}
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl,
const autoparser::templates_params & inputs) { const autoparser::generation_params & inputs) {
common_chat_params data; common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs); data.prompt = common_chat_template_direct_apply(tmpl, inputs);
@ -1070,13 +1078,13 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
// Build content parser for >>>all\n{content} // Build content parser for >>>all\n{content}
// When tools are present, content stops before the next ">>>" (tool call) // When tools are present, content stops before the next ">>>" (tool call)
// When no tools, content goes until end // When no tools, content goes until end
auto content_until_tool = p.literal(">>>all\n") + p.content(p.until(">>>")); auto content_until_tool = p.literal("all\n") + p.content(p.until(">>>"));
auto content_until_end = p.literal(">>>all\n") + p.content(p.rest()); auto content_until_end = p.literal("all\n") + p.content(p.rest());
// If no tools or tool_choice is NONE, just parse content // If no tools or tool_choice is NONE, just parse content
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
// When no tools, just match the prefix and capture everything after // When no tools, just match the prefix and capture everything after
return content_until_end + p.end(); return wrap_for_generation_prompt(p, content_until_end + p.end(), inputs);
} }
// Build tool call parsers for each available function // Build tool call parsers for each available function
@ -1088,7 +1096,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
// Tool format: >>>function_name\n{json_args} // Tool format: >>>function_name\n{json_args}
auto tool_parser = p.tool( auto tool_parser = p.tool(
p.tool_open(p.literal(">>>") + p.tool_name(p.literal(name)) + p.literal("\n")) + p.tool_open(p.tool_name(p.literal(name)) + p.literal("\n")) +
p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)) p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema))
); );
@ -1099,17 +1107,20 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
auto tools_only = p.trigger_rule("tools", p.one_or_more(tool_choice)); auto tools_only = p.trigger_rule("tools", p.one_or_more(tool_choice));
auto content_and_tools = content_until_tool + tools_only; auto content_and_tools = content_until_tool + tools_only;
auto ret = p.eps();
if (inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED) { if (inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED) {
if (inputs.parallel_tool_calls) { if (inputs.parallel_tool_calls) {
return p.choice({ content_and_tools, tools_only }) + p.end(); ret = p.choice({ content_and_tools, tools_only }) + p.end();
} else {
ret = p.choice({ content_until_tool + tool_choice, tools_only }) + p.end();
} }
return p.choice({ content_until_tool + tool_choice, tools_only }) + p.end(); } else if (inputs.parallel_tool_calls) {
ret = p.choice({ content_and_tools, content_only, tools_only }) + p.end();
} else {
auto content_and_tool = content_until_tool + tool_choice;
ret = p.choice({ content_and_tool, content_only, tool_choice }) + p.end();
} }
if (inputs.parallel_tool_calls) { return wrap_for_generation_prompt(p, ret, inputs);
return p.choice({ content_and_tools, content_only, tools_only }) + p.end();
}
auto content_and_tool = content_until_tool + tool_choice;
return p.choice({ content_and_tool, content_only, tool_choice }) + p.end();
}); });
data.parser = parser.save(); data.parser = parser.save();
@ -1139,14 +1150,12 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
// Kimi K2 Thinking - uses unique tool call ID format: functions.<name>:<index> // Kimi K2 Thinking - uses unique tool call ID format: functions.<name>:<index>
// The ID contains both the function name and an incrementing counter // The ID contains both the function name and an incrementing counter
static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl, static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl,
const autoparser::templates_params & inputs) { const autoparser::generation_params & inputs) {
common_chat_params data; common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs); data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true; data.supports_thinking = true;
data.thinking_start_tag = "<think>";
data.thinking_end_tag = "</think>";
data.preserved_tokens = { data.preserved_tokens = {
"<|tool_calls_section_begin|>", "<|tool_calls_section_begin|>",
"<|tool_calls_section_end|>", "<|tool_calls_section_end|>",
@ -1161,6 +1170,18 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE; auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
const std::string SECTION_BEGIN = "<|tool_calls_section_begin|>";
const std::string SECTION_END = "<|tool_calls_section_end|>";
const std::string CALL_BEGIN = "<|tool_call_begin|>";
const std::string ARGS_BEGIN = "<|tool_call_argument_begin|>";
const std::string CALL_END = "<|tool_call_end|>";
const std::string THINK_START = "<think>";
const std::string THINK_END = "</think>";
data.thinking_start_tag = THINK_START;
data.thinking_end_tag = THINK_END;
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
// Kimi K2 Thinking format: // Kimi K2 Thinking format:
// - Reasoning: <think>{reasoning}</think> // - Reasoning: <think>{reasoning}</think>
@ -1172,16 +1193,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
// <|tool_calls_section_end|> // <|tool_calls_section_end|>
// The ID format is: functions.<function_name>:<counter> where counter is 0, 1, 2, ... // The ID format is: functions.<function_name>:<counter> where counter is 0, 1, 2, ...
// Tool call markers // Tool call markers
const std::string SECTION_BEGIN = "<|tool_calls_section_begin|>";
const std::string SECTION_END = "<|tool_calls_section_end|>";
const std::string CALL_BEGIN = "<|tool_call_begin|>";
const std::string ARGS_BEGIN = "<|tool_call_argument_begin|>";
const std::string CALL_END = "<|tool_call_end|>";
const std::string THINK_START = "<think>";
const std::string THINK_END = "</think>";
auto end = p.end(); auto end = p.end();
// Note: this model is CRAZY. It can diverge from its supposed tool calling pattern in so many ways it's not funny. // Note: this model is CRAZY. It can diverge from its supposed tool calling pattern in so many ways it's not funny.
@ -1193,7 +1205,8 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
// Content only parser (no tools) // Content only parser (no tools)
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return reasoning + p.content(p.rest()) + end; return wrap_for_generation_prompt(p, reasoning + p.content(p.rest()) + end,
inputs, THINK_START);
} }
// Build tool call parsers for each available function // Build tool call parsers for each available function
@ -1229,7 +1242,8 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
auto content_before_tools = p.content(p.until_one_of({ SECTION_BEGIN, CALL_BEGIN })); auto content_before_tools = p.content(p.until_one_of({ SECTION_BEGIN, CALL_BEGIN }));
return reasoning + content_before_tools + tool_calls + end; return wrap_for_generation_prompt(p, reasoning + content_before_tools + tool_calls + end,
inputs, THINK_START);
}); });
data.parser = parser.save(); data.parser = parser.save();
@ -1259,7 +1273,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
// - Tool calls: <|tool_call_start|>[function_name(arg1="value1", arg2="value2")]<|tool_call_end|> // - Tool calls: <|tool_call_start|>[function_name(arg1="value1", arg2="value2")]<|tool_call_end|>
// Tool calls can appear multiple times (parallel tool calls) // Tool calls can appear multiple times (parallel tool calls)
static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl, static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl,
const autoparser::templates_params & inputs) { const autoparser::generation_params & inputs) {
common_chat_params data; common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs); data.prompt = common_chat_template_direct_apply(tmpl, inputs);
@ -1278,13 +1292,15 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE; auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
const std::string TOOL_CALL_START = "<|tool_call_start|>"; const std::string TOOL_CALL_START = "<|tool_call_start|>";
const std::string TOOL_CALL_END = "<|tool_call_end|>"; const std::string TOOL_CALL_END = "<|tool_call_end|>";
const std::string THINK_START = "<think>"; const std::string THINK_START = "<think>";
const std::string THINK_END = "</think>"; const std::string THINK_END = "</think>";
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
data.thinking_start_tag = THINK_START;
data.thinking_end_tag = THINK_END;
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
auto end = p.end(); auto end = p.end();
auto reasoning = p.eps(); auto reasoning = p.eps();
@ -1293,7 +1309,8 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
} }
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return reasoning + p.content(p.rest()) + end; return wrap_for_generation_prompt(p, reasoning + p.content(p.rest()) + end, inputs,
THINK_START);
} }
auto tool_calls = p.rule("tool-calls", auto tool_calls = p.rule("tool-calls",
@ -1305,7 +1322,8 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
auto content = p.content(p.until(TOOL_CALL_START)); auto content = p.content(p.until(TOOL_CALL_START));
return reasoning + content + tool_calls + end; return wrap_for_generation_prompt(p, reasoning + content + tool_calls + end, inputs,
THINK_START);
}); });
data.parser = parser.save(); data.parser = parser.save();
@ -1331,7 +1349,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
static common_chat_params common_chat_params_init_gigachat_v3( static common_chat_params common_chat_params_init_gigachat_v3(
const common_chat_template & tmpl, const common_chat_template & tmpl,
const autoparser::templates_params & inputs) { const autoparser::generation_params & inputs) {
common_chat_params data; common_chat_params data;
@ -1345,9 +1363,10 @@ static common_chat_params common_chat_params_init_gigachat_v3(
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE; auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
auto tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n"; const auto *tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n";
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
auto ret = p.eps();
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
// Build a choice of all available tools // Build a choice of all available tools
auto tool_choice = p.choice(); auto tool_choice = p.choice();
@ -1370,13 +1389,14 @@ static common_chat_params common_chat_params_init_gigachat_v3(
auto tool_call = p.rule("tool-call", p.literal(tool_call_start_prefix) + tool_choice); auto tool_call = p.rule("tool-call", p.literal(tool_call_start_prefix) + tool_choice);
auto tool_calls = p.trigger_rule("tool-call-root", p.repeat(tool_call, /* min = */ min_calls, /* max = */ max_calls)); auto tool_calls = p.trigger_rule("tool-call-root", p.repeat(tool_call, /* min = */ min_calls, /* max = */ max_calls));
return p.content(p.until("<|message_sep|>\n\n")) << tool_calls; ret = p.content(p.until("<|message_sep|>\n\n")) << tool_calls;
} else {
// Content only parser
include_grammar = false;
ret = p.content(p.rest());
} }
// Content only parser return wrap_for_generation_prompt(p, ret, inputs);
include_grammar = false;
return p.content(p.rest());
}); });
data.parser = parser.save(); data.parser = parser.save();
@ -1471,87 +1491,10 @@ static json common_chat_extra_context() {
return ctx; return ctx;
} }
static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls, static std::optional<common_chat_params> try_specialized_template(
const struct common_chat_templates_inputs & inputs) { const common_chat_template & tmpl,
autoparser::templates_params params; const std::string & src,
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools); const autoparser::generation_params & params) {
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
? *tmpls->template_tool_use
: *tmpls->template_default;
const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps();
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
params.add_generation_prompt = inputs.add_generation_prompt;
params.tool_choice = inputs.tool_choice;
params.reasoning_format = inputs.reasoning_format;
params.enable_thinking = inputs.enable_thinking;
params.grammar = inputs.grammar;
params.now = inputs.now;
params.add_bos = tmpls->add_bos;
params.add_eos = tmpls->add_eos;
if (src.find("<|channel|>") == std::string::npos) {
// map developer to system for all models except for GPT-OSS
workaround::map_developer_role_to_system(params.messages);
}
if (!tmpl.original_caps().supports_system_role) {
workaround::system_message_not_supported(params.messages);
}
if (tmpl.original_caps().supports_tool_calls) {
// some templates will require the content field in tool call messages
// to still be non-null, this puts an empty string everywhere where the
// content field is null
workaround::requires_non_null_content(params.messages);
}
if (tmpl.original_caps().supports_object_arguments) {
workaround::func_args_not_string(params.messages);
}
params.extra_context = common_chat_extra_context();
for (auto el : inputs.chat_template_kwargs) {
params.extra_context[el.first] = json::parse(el.second);
}
if (!inputs.json_schema.empty()) {
params.json_schema = json::parse(inputs.json_schema);
}
// if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
// LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
// params.parallel_tool_calls = false;
// } else {
params.parallel_tool_calls = inputs.parallel_tool_calls;
//}
if (params.tools.is_array()) {
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
throw std::runtime_error("Cannot specify grammar with tools");
}
if (caps.supports_tool_calls && !caps.supports_tools) {
LOG_WRN(
"Template supports tool calls but does not natively describe tools. The fallback behaviour used may "
"produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
}
}
if (inputs.force_pure_content) {
LOG_WRN("Forcing pure content template, will not render reasoning or tools separately.");
// Create the result structure
common_chat_params data;
auto params_copy = params;
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
data.prompt = common_chat_template_direct_apply(tmpl, params_copy);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
auto parser = build_chat_peg_parser([](common_chat_peg_builder &p) {
return p.content(p.rest());
});
data.parser = parser.save();
return data;
}
// Ministral/Mistral Large 3 - uses special reasoning structure fixes, can't use autoparser // Ministral/Mistral Large 3 - uses special reasoning structure fixes, can't use autoparser
// Note: Mistral Small 3.2 uses [CALL_ID] which Ministral doesn't have, so we can distinguish them // Note: Mistral Small 3.2 uses [CALL_ID] which Ministral doesn't have, so we can distinguish them
if (src.find("[SYSTEM_PROMPT]") != std::string::npos && src.find("[TOOL_CALLS]") != std::string::npos && if (src.find("[SYSTEM_PROMPT]") != std::string::npos && src.find("[TOOL_CALLS]") != std::string::npos &&
@ -1592,14 +1535,105 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
// GigaChatV3 format detection // GigaChatV3 format detection
if (src.find("<|role_sep|>") != std::string::npos && if (src.find("<|role_sep|>") != std::string::npos &&
src.find("<|message_sep|>") != std::string::npos && src.find("<|message_sep|>") != std::string::npos &&
src.find("<|function_call|>") == std::string::npos src.find("<|function_call|>") == std::string::npos) {
) {
LOG_DBG("Using specialized template: GigaChatV3\n"); LOG_DBG("Using specialized template: GigaChatV3\n");
return common_chat_params_init_gigachat_v3(tmpl, params); return common_chat_params_init_gigachat_v3(tmpl, params);
} }
return std::nullopt;
}
static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs) {
autoparser::generation_params params;
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
const auto & tmpl =
params.tools.is_array() && tmpls->template_tool_use ? *tmpls->template_tool_use : *tmpls->template_default;
const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps();
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
params.tool_choice = inputs.tool_choice;
params.reasoning_format = inputs.reasoning_format;
params.enable_thinking = inputs.enable_thinking;
params.grammar = inputs.grammar;
params.now = inputs.now;
params.add_bos = tmpls->add_bos;
params.add_eos = tmpls->add_eos;
if (src.find("<|channel|>") == std::string::npos) {
// map developer to system for all models except for GPT-OSS
workaround::map_developer_role_to_system(params.messages);
}
if (!tmpl.original_caps().supports_system_role) {
workaround::system_message_not_supported(params.messages);
}
if (tmpl.original_caps().supports_tool_calls) {
// some templates will require the content field in tool call messages
// to still be non-null, this puts an empty string everywhere where the
// content field is null
workaround::requires_non_null_content(params.messages);
}
if (tmpl.original_caps().supports_object_arguments) {
workaround::func_args_not_string(params.messages);
}
params.add_generation_prompt = false;
std::string no_gen_prompt = common_chat_template_direct_apply(tmpl, params);
params.add_generation_prompt = true;
std::string gen_prompt = common_chat_template_direct_apply(tmpl, params);
auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
params.generation_prompt = diff.right;
params.add_generation_prompt = inputs.add_generation_prompt;
params.extra_context = common_chat_extra_context();
for (auto el : inputs.chat_template_kwargs) {
params.extra_context[el.first] = json::parse(el.second);
}
if (!inputs.json_schema.empty()) {
params.json_schema = json::parse(inputs.json_schema);
}
params.parallel_tool_calls = inputs.parallel_tool_calls;
if (params.tools.is_array()) {
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
throw std::runtime_error("Cannot specify grammar with tools");
}
if (caps.supports_tool_calls && !caps.supports_tools) {
LOG_WRN(
"Template supports tool calls but does not natively describe tools. The fallback behaviour used may "
"produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
}
}
if (inputs.force_pure_content) {
LOG_WRN("Forcing pure content template, will not render reasoning or tools separately.");
// Create the result structure
common_chat_params data;
auto params_copy = params;
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
data.prompt = common_chat_template_direct_apply(tmpl, params_copy);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.generation_prompt = params.generation_prompt;
auto parser = build_chat_peg_parser([&params](common_chat_peg_builder &p) {
return wrap_for_generation_prompt(p, p.content(p.rest()), params);
});
data.parser = parser.save();
return data;
}
if (auto result = try_specialized_template(tmpl, src, params)) {
result->generation_prompt = params.generation_prompt;
return *result;
}
try { try {
LOG_DBG("Using differential autoparser\n"); LOG_DBG("%s: using differential autoparser\n", __func__);
struct autoparser::autoparser autoparser; struct autoparser::autoparser autoparser;
autoparser.analyze_template(tmpl); autoparser.analyze_template(tmpl);
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser); auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
@ -1607,13 +1641,11 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
if (auto_params.supports_thinking) { if (auto_params.supports_thinking) {
auto_params.thinking_start_tag = autoparser.reasoning.start; auto_params.thinking_start_tag = autoparser.reasoning.start;
auto_params.thinking_end_tag = autoparser.reasoning.end; auto_params.thinking_end_tag = autoparser.reasoning.end;
// FORCED_OPEN and FORCED_CLOSED both put <think> in the generation prompt
// (FORCED_CLOSED forces empty <think></think> when thinking is disabled,
// but forces <think> open when thinking is enabled)
auto_params.thinking_forced_open =
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_OPEN ||
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_CLOSED;
} }
auto_params.generation_prompt = params.generation_prompt;
common_peg_arena arena;
arena.load(auto_params.parser);
LOG_DBG("%s: generated parser:\n%s\n\nparser generation prompt: %s\n", __func__, arena.dump(arena.root()).c_str(), auto_params.generation_prompt.c_str());
return auto_params; return auto_params;
} catch (const std::exception & e) { } catch (const std::exception & e) {
throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what()); throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what());
@ -1711,14 +1743,18 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
LOG_DBG("No parser definition detected, assuming pure content parser."); LOG_DBG("No parser definition detected, assuming pure content parser.");
} }
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), input.c_str()); const std::string effective_input = params.generation_prompt.empty()
? input
: params.generation_prompt + input;
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), effective_input.c_str());
common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_LENIENT; common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_LENIENT;
if (params.debug) { if (params.debug) {
flags |= COMMON_PEG_PARSE_FLAG_DEBUG; flags |= COMMON_PEG_PARSE_FLAG_DEBUG;
} }
common_peg_parse_context ctx(input, flags); common_peg_parse_context ctx(effective_input, flags);
auto result = parser.parse(ctx); auto result = parser.parse(ctx);
if (result.fail()) { if (result.fail()) {

View File

@ -24,7 +24,7 @@ using json = nlohmann::ordered_json;
struct common_chat_templates; struct common_chat_templates;
namespace autoparser { namespace autoparser {
struct templates_params; struct generation_params;
} // namespace autoparser } // namespace autoparser
struct common_chat_tool_call { struct common_chat_tool_call {
@ -212,7 +212,7 @@ struct common_chat_params {
std::string prompt; std::string prompt;
std::string grammar; std::string grammar;
bool grammar_lazy = false; bool grammar_lazy = false;
bool thinking_forced_open = false; std::string generation_prompt;
bool supports_thinking = false; bool supports_thinking = false;
std::string thinking_start_tag; // e.g., "<think>" std::string thinking_start_tag; // e.g., "<think>"
std::string thinking_end_tag; // e.g., "</think>" std::string thinking_end_tag; // e.g., "</think>"
@ -229,14 +229,14 @@ struct common_chat_parser_params {
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning" common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode) // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
bool reasoning_in_content = false; bool reasoning_in_content = false;
bool thinking_forced_open = false; std::string generation_prompt;
bool parse_tool_calls = true; bool parse_tool_calls = true;
bool debug = false; // Enable debug output for PEG parser bool debug = false; // Enable debug output for PEG parser
common_peg_arena parser = {}; common_peg_arena parser = {};
common_chat_parser_params() = default; common_chat_parser_params() = default;
common_chat_parser_params(const common_chat_params & chat_params) { common_chat_parser_params(const common_chat_params & chat_params) {
format = chat_params.format; format = chat_params.format;
thinking_forced_open = chat_params.thinking_forced_open; generation_prompt = chat_params.generation_prompt;
} }
}; };
@ -302,7 +302,7 @@ std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_tem
std::string common_chat_template_direct_apply( std::string common_chat_template_direct_apply(
const common_chat_template & tmpl, const common_chat_template & tmpl,
const autoparser::templates_params & inputs, const autoparser::generation_params & inputs,
const std::optional<json> & messages_override = std::nullopt, const std::optional<json> & messages_override = std::nullopt,
const std::optional<json> & tools_override = std::nullopt, const std::optional<json> & tools_override = std::nullopt,
const std::optional<json> & additional_context = std::nullopt); const std::optional<json> & additional_context = std::nullopt);

View File

@ -3,12 +3,14 @@
#pragma once #pragma once
#include "ggml-opt.h" #include "ggml-opt.h"
#include "ggml.h"
#include "llama-cpp.h" #include "llama-cpp.h"
#include <set> #include <set>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <string_view> #include <string_view>
#include <variant>
#include <vector> #include <vector>
#include <map> #include <map>
@ -178,6 +180,43 @@ enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type
}; };
// Grammar type enumeration
enum common_grammar_type {
COMMON_GRAMMAR_TYPE_NONE, // no grammar set
COMMON_GRAMMAR_TYPE_USER, // user-provided GBNF (--grammar / "grammar" API field)
COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, // auto-generated from JSON schema (--json-schema / "json_schema" API field)
COMMON_GRAMMAR_TYPE_TOOL_CALLS, // auto-generated by chat template parser for function calling
};
// Grammar variant struct with type and grammar string
struct common_grammar {
common_grammar_type type = COMMON_GRAMMAR_TYPE_NONE;
std::string grammar;
// Default constructor - no grammar
common_grammar() = default;
// Constructor with type and grammar string
common_grammar(common_grammar_type t, std::string g) : type(t), grammar(std::move(g)) {
GGML_ASSERT(type != COMMON_GRAMMAR_TYPE_NONE || !grammar.empty());
}
// Check if a grammar is set
bool empty() const { return type == COMMON_GRAMMAR_TYPE_NONE || grammar.empty(); }
};
// Returns the raw grammar string, or empty string if no grammar is set.
inline const std::string & common_grammar_value(const common_grammar & g) {
return g.grammar;
}
// Returns true when the generation_prompt should be prefilled into the grammar sampler.
// Only output-format and tool-call grammars need prefill; user-supplied grammars must not be prefilled.
inline bool common_grammar_needs_prefill(const common_grammar & g) {
return g.type == COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT
|| g.type == COMMON_GRAMMAR_TYPE_TOOL_CALLS;
}
// sampling parameters // sampling parameters
struct common_params_sampling { struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
@ -228,7 +267,7 @@ struct common_params_sampling {
COMMON_SAMPLER_TYPE_TEMPERATURE, COMMON_SAMPLER_TYPE_TEMPERATURE,
}; };
std::string grammar; // optional BNF-like grammar to constrain sampling common_grammar grammar; // optional grammar constraint (user / output-format / tool-calls)
bool grammar_lazy = false; bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars) std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
std::set<llama_token> preserved_tokens; std::set<llama_token> preserved_tokens;
@ -236,10 +275,15 @@ struct common_params_sampling {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
// The assistant generation prompt already prefilled into the prompt.
// Fed to the grammar sampler (to advance past pre-existing tokens) and used
// to determine the reasoning budget sampler's initial state.
// Only applied when the grammar is of output-format or tool-calls type.
std::string generation_prompt;
// reasoning budget sampler parameters // reasoning budget sampler parameters
// these are populated by the server/CLI based on chat template params // these are populated by the server/CLI based on chat template params
int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget
bool reasoning_budget_activate_immediately = false;
std::vector<llama_token> reasoning_budget_start; // start tag token sequence std::vector<llama_token> reasoning_budget_start; // start tag token sequence
std::vector<llama_token> reasoning_budget_end; // end tag token sequence std::vector<llama_token> reasoning_budget_end; // end tag token sequence
std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag) std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag)

View File

@ -163,9 +163,15 @@ static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
ctx->force_pos = 0; ctx->force_pos = 0;
} }
// forward declaration for use in clone
static struct llama_sampler * common_reasoning_budget_init_state(
const struct llama_vocab * vocab, const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens, const std::vector<llama_token> & forced_tokens,
int32_t budget, common_reasoning_budget_state initial_state);
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) { static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx; const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
return common_reasoning_budget_init( return common_reasoning_budget_init_state(
ctx->vocab, ctx->vocab,
ctx->start_matcher.tokens, ctx->start_matcher.tokens,
ctx->end_matcher.tokens, ctx->end_matcher.tokens,
@ -191,13 +197,13 @@ static struct llama_sampler_i common_reasoning_budget_i = {
/* .backend_set_input = */ nullptr, /* .backend_set_input = */ nullptr,
}; };
struct llama_sampler * common_reasoning_budget_init( static struct llama_sampler * common_reasoning_budget_init_state(
const struct llama_vocab * vocab, const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens, const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens, const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens, const std::vector<llama_token> & forced_tokens,
int32_t budget, int32_t budget,
common_reasoning_budget_state initial_state) { common_reasoning_budget_state initial_state) {
// promote COUNTING with budget <= 0 to FORCING // promote COUNTING with budget <= 0 to FORCING
if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) { if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
initial_state = REASONING_BUDGET_FORCING; initial_state = REASONING_BUDGET_FORCING;
@ -217,3 +223,41 @@ struct llama_sampler * common_reasoning_budget_init(
} }
); );
} }
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
const std::vector<llama_token> & prefill_tokens) {
// Determine initial state from prefill: COUNTING if the prefill begins with
// the start sequence but does not also contain the end sequence after it.
common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE;
if (!prefill_tokens.empty() && !start_tokens.empty() &&
prefill_tokens.size() >= start_tokens.size() &&
std::equal(start_tokens.begin(), start_tokens.end(), prefill_tokens.begin())) {
initial_state = REASONING_BUDGET_COUNTING;
// If the end sequence also follows the start in the prefill, reasoning
// was opened and immediately closed — stay IDLE.
if (!end_tokens.empty() &&
prefill_tokens.size() >= start_tokens.size() + end_tokens.size()) {
auto end_start = prefill_tokens.end() - (ptrdiff_t) end_tokens.size();
if (end_start >= prefill_tokens.begin() + (ptrdiff_t) start_tokens.size() &&
std::equal(end_tokens.begin(), end_tokens.end(), end_start)) {
initial_state = REASONING_BUDGET_IDLE;
}
}
}
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
}
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state) {
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
}

View File

@ -24,14 +24,26 @@ enum common_reasoning_budget_state {
// DONE: passthrough forever // DONE: passthrough forever
// //
// Parameters: // Parameters:
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr) // vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
// start_tokens - token sequence that activates counting // start_tokens - token sequence that activates counting
// end_tokens - token sequence for natural deactivation // end_tokens - token sequence for natural deactivation
// forced_tokens - token sequence forced when budget expires // forced_tokens - token sequence forced when budget expires
// budget - max tokens allowed in the reasoning block // budget - max tokens allowed in the reasoning block
// initial_state - initial state of the sampler (e.g. IDLE or COUNTING) // prefill_tokens - tokens already present in the prompt (generation prompt);
// note: COUNTING with budget <= 0 is promoted to FORCING // used to determine the initial state: COUNTING if they begin
// with start_tokens (but don't also end with end_tokens),
// IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING.
// //
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
const std::vector<llama_token> & prefill_tokens = {});
// Variant that takes an explicit initial state (used by tests and clone).
// COUNTING with budget <= 0 is promoted to FORCING.
struct llama_sampler * common_reasoning_budget_init( struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab, const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens, const std::vector<llama_token> & start_tokens,

View File

@ -1,13 +1,16 @@
#include "sampling.h" #include "sampling.h"
#include "common.h" #include "common.h"
#include "ggml.h"
#include "log.h" #include "log.h"
#include "reasoning-budget.h" #include "reasoning-budget.h"
#include <algorithm> #include <algorithm>
#include <cctype>
#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
#include <unordered_map> #include <unordered_map>
#include <vector>
// the ring buffer works similarly to std::deque, but with a fixed capacity // the ring buffer works similarly to std::deque, but with a fixed capacity
// TODO: deduplicate with llama-impl.h // TODO: deduplicate with llama-impl.h
@ -189,9 +192,10 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
std::vector<llama_sampler *> samplers; std::vector<llama_sampler *> samplers;
if (params.grammar.compare(0, 11, "%llguidance") == 0) { const std::string & grammar_str = common_grammar_value(params.grammar);
if (grammar_str.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE #ifdef LLAMA_USE_LLGUIDANCE
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()); grmr = llama_sampler_init_llg(vocab, "lark", grammar_str.c_str());
#else #else
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE #endif // LLAMA_USE_LLGUIDANCE
@ -240,17 +244,46 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
trigger_patterns_c.push_back(regex.c_str()); trigger_patterns_c.push_back(regex.c_str());
} }
if (!params.grammar.empty()) { if (!grammar_str.empty()) {
if (params.grammar_lazy) { if (params.grammar_lazy) {
grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root", grmr = llama_sampler_init_grammar_lazy_patterns(vocab, grammar_str.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(), trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size()); trigger_tokens.data(), trigger_tokens.size());
} else { } else {
grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); grmr = llama_sampler_init_grammar(vocab, grammar_str.c_str(), "root");
} }
} }
} }
// Feed generation prompt tokens to the grammar sampler so it advances past
// tokens the template already placed in the prompt.
// Only applies to output-format and tool-call grammars; user-supplied grammars must not be prefilled.
std::vector<llama_token> prefill_tokens;
if (!params.generation_prompt.empty() && common_grammar_needs_prefill(params.grammar)) {
GGML_ASSERT(vocab != nullptr);
prefill_tokens = common_tokenize(vocab, params.generation_prompt, false, true);
if (!prefill_tokens.empty()) {
std::string first_token = common_token_to_piece(vocab, prefill_tokens[0], true);
if (std::isspace(first_token[0]) && !std::isspace(params.generation_prompt[0])) {
// Some tokenizers will add a space before the first special token, need to remove
prefill_tokens = std::vector<llama_token>(prefill_tokens.begin() + 1, prefill_tokens.end());
}
}
if (grmr) {
try {
for (const auto & token : prefill_tokens) {
llama_sampler_accept(grmr, token);
LOG_DBG("%s: accepted prefill token (%d)\n", __func__, token);
}
} catch (std::exception &e) {
LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__,
common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str());
throw e;
}
}
}
// reasoning budget sampler — added first so it can force tokens before other samplers // reasoning budget sampler — added first so it can force tokens before other samplers
if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) { if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
samplers.push_back(common_reasoning_budget_init( samplers.push_back(common_reasoning_budget_init(
@ -259,7 +292,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
params.reasoning_budget_end, params.reasoning_budget_end,
params.reasoning_budget_forced, params.reasoning_budget_forced,
params.reasoning_budget_tokens, params.reasoning_budget_tokens,
params.reasoning_budget_activate_immediately ? REASONING_BUDGET_COUNTING : REASONING_BUDGET_IDLE)); prefill_tokens));
} }
if (params.has_logit_bias()) { if (params.has_logit_bias()) {

View File

@ -1062,6 +1062,10 @@ class TextModel(ModelBase):
self.gguf_writer.add_head_count_kv(n_head_kv) self.gguf_writer.add_head_count_kv(n_head_kv)
logger.info(f"gguf: key-value head count = {n_head_kv}") logger.info(f"gguf: key-value head count = {n_head_kv}")
if self.hparams.get("is_causal") is False:
self.gguf_writer.add_causal_attention(False)
logger.info("gguf: causal attention = False")
# TODO: Handle "sliding_attention" similarly when models start implementing it # TODO: Handle "sliding_attention" similarly when models start implementing it
rope_params = self.rope_parameters.get("full_attention", self.rope_parameters) rope_params = self.rope_parameters.get("full_attention", self.rope_parameters)
if (rope_type := rope_params.get("rope_type")) is not None: if (rope_type := rope_params.get("rope_type")) is not None:

View File

@ -14,7 +14,7 @@ The unified auto-parser uses a pure differential, compositional approach (inspir
**Analysis + Parser Building in Two Steps**: **Analysis + Parser Building in Two Steps**:
1. `autoparser::autoparser tmpl_analysis(tmpl)` — runs all differential comparisons and populates the analysis structs 1. `autoparser::autoparser tmpl_analysis(tmpl)` — runs all differential comparisons and populates the analysis structs
2. `autoparser::peg_generator::generate_parser(tmpl, params, tmpl_analysis)` — uses the analysis to build a PEG parser and optional GBNF grammar 2. `autoparser::peg_generator::generate_parser(tmpl, generation_params, tmpl_analysis)` — uses the analysis to build a PEG parser and optional GBNF grammar
## Data Structures ## Data Structures
@ -34,7 +34,7 @@ All structs are defined in [common/chat-auto-parser.h](common/chat-auto-parser.h
### `analyze_tools` and its sub-structs ### `analyze_tools` and its sub-structs
- [common/chat-auto-parser.h:176-194](common/chat-auto-parser.h#L176-L194) — `tool_format_analysis`: `mode` enum, `section_start/end`, `per_call_start/end`, JSON field names (`function_field`, `name_field`, `args_field`, `id_field`, `gen_id_field`), and format flags (`fun_name_is_key`, `tools_array_wrapped`, `uses_python_dicts`) - [common/chat-auto-parser.h:176-194](common/chat-auto-parser.h#L176-L194) — `tool_format_analysis`: `mode` enum, `section_start/end`, `per_call_start/end`, JSON field names (`function_field`, `name_field`, `args_field`, `id_field`, `gen_id_field`), and format flags (`fun_name_is_key`, `tools_array_wrapped`)
- [common/chat-auto-parser.h:196-200](common/chat-auto-parser.h#L196-L200) — `tool_function_analysis`: `name_prefix`, `name_suffix`, `close` markers around function names - [common/chat-auto-parser.h:196-200](common/chat-auto-parser.h#L196-L200) — `tool_function_analysis`: `name_prefix`, `name_suffix`, `close` markers around function names
- [common/chat-auto-parser.h:202-210](common/chat-auto-parser.h#L202-L210) — `tool_arguments_analysis`: `start/end` container markers, `name_prefix/suffix`, `value_prefix/suffix`, `separator` - [common/chat-auto-parser.h:202-210](common/chat-auto-parser.h#L202-L210) — `tool_arguments_analysis`: `start/end` container markers, `name_prefix/suffix`, `value_prefix/suffix`, `separator`
- [common/chat-auto-parser.h:212-217](common/chat-auto-parser.h#L212-L217) — `tool_id_analysis`: `pos` enum, `prefix`/`suffix` markers around call ID values - [common/chat-auto-parser.h:212-217](common/chat-auto-parser.h#L212-L217) — `tool_id_analysis`: `pos` enum, `prefix`/`suffix` markers around call ID values
@ -47,12 +47,21 @@ All structs are defined in [common/chat-auto-parser.h](common/chat-auto-parser.h
| Value | Description | | Value | Description |
|-----------------|-----------------------------------------------------------------------------------| |-----------------|-----------------------------------------------------------------------------------|
| `NONE` | No reasoning markers detected | | `NONE` | No reasoning markers detected |
| `TAG_BASED` | Standard tag-based: `<think>...</think>` | | `TAG_BASED` | Tag-based: `<think>...</think>` (start can be empty for delimiter-style formats) |
| `DELIMITER` | Delimiter-based: reasoning ends at a delimiter (e.g., `[BEGIN FINAL RESPONSE]`) |
| `FORCED_OPEN` | Template ends with open reasoning tag when `enable_thinking=true` |
| `FORCED_CLOSED` | `enable_thinking=false` emits both tags; `enable_thinking=true` emits only start |
| `TOOLS_ONLY` | Reasoning only appears in tool call responses, not plain content | | `TOOLS_ONLY` | Reasoning only appears in tool call responses, not plain content |
**Generation Prompt & Reasoning Prefill**: Computed in `common_chat_templates_apply_jinja` before invoking either the specialized handlers or the auto-parser, by rendering the template twice — once with `add_generation_prompt=false` and once with `add_generation_prompt=true` — and storing the diff suffix as `generation_params::generation_prompt`. This string is propagated into `common_chat_params::generation_prompt` and `common_chat_parser_params::generation_prompt`.
The generation prompt is prepended to model output before PEG parsing via `wrap_for_generation_prompt()`. The portion *before* the reasoning start marker (if any) is prepended as a literal to ensure any boilerplate added by the template is consumed. The full string is also fed to the grammar sampler via `llama_sampler_accept` (stored in `common_params_sampling::grammar_prefill`), advancing the grammar past tokens already in the prompt. It is used to determine the reasoning budget sampler's initial state — COUNTING if the prefill tokens begin with the reasoning start sequence (but don't also contain the end sequence), IDLE otherwise.
**`grammar_prefill`** (`common_params_sampling`): The generation prompt string tokenized and accepted by the grammar sampler at init time. Only applied when `grammar_external` is false (i.e., the grammar was not set explicitly by the user).
Three outcomes for reasoning-prefill handling (in `generate_parser()`):
1. **Start+end in generation prompt** (e.g. `<think></think>\n`): the parser sees reasoning as opened and immediately closed; whitespace-only reasoning content is discarded.
2. **Only start in generation prompt** (e.g. `<think>\n`): the parser sees reasoning as already open.
3. **Start marker present but not at the end** (e.g. Apriel's `<|begin_assistant|>` followed by boilerplate): the marker is a template artifact; the start literal is cleared so reasoning uses delimiter-style (end-only). For templates that ignore `add_generation_prompt` (empty diff), the rendered `data.prompt` is used as fallback — but only for non-TOOLS_ONLY modes, since in TOOLS_ONLY the start tag is model-generated and may appear in prior conversation turns.
**`content_mode`**: How the template wraps assistant content. **`content_mode`**: How the template wraps assistant content.
| Value | Description | | Value | Description |
@ -261,16 +270,16 @@ Text is segmentized into markers and non-marker fragments using `segmentize_mark
- Searches `diff.right` (output with reasoning) for the reasoning content needle - Searches `diff.right` (output with reasoning) for the reasoning content needle
- Uses PEG parsers to find surrounding markers: - Uses PEG parsers to find surrounding markers:
- If both pre/post markers found in `diff.right``TAG_BASED` (both tags visible in diff = no forced close) - If both pre/post markers found in `diff.right``TAG_BASED`
- If both found but post marker only in the full output B → `FORCED_CLOSED` - If both found but post marker only in the full output B → `TAG_BASED` (template forces markers; handled via prefill)
- If only post marker found → `DELIMITER` - If only post marker found → `TAG_BASED` (delimiter-style, empty start)
- Sets `reasoning.start` and `reasoning.end` - Sets `reasoning.start` and `reasoning.end`
**R2 — `compare_thinking_enabled()`**: Compares `enable_thinking=false` vs `true` with a generation prompt. **R2 — `compare_thinking_enabled()`**: Compares `enable_thinking=false` vs `true` with a generation prompt.
- Detects `FORCED_OPEN`: `enable_thinking=true` adds a non-empty marker at the end of the prompt (where model will start generating) — sets `reasoning.start`, mode = `FORCED_OPEN` - Detects template-added reasoning markers: `enable_thinking=true` appends a non-empty marker → sets `reasoning.start`, mode = `TAG_BASED`
- Detects `FORCED_CLOSED`: `enable_thinking=false` produces both start+end markers; `enable_thinking=true` produces only start marker - Handles the reverse case (`enable_thinking=false` appends the marker instead): extracts both start (from the preceding segment) and end markers; mode = `TAG_BASED`
- Handles the reverse case: if both start and end are still empty, looks for a single-segment diff on each side to extract both markers - The reasoning prefill (markers added by the template) is later extracted in `common_chat_templates_apply_jinja` and prepended to model output before parsing
**R3 — `compare_reasoning_scope()`**: Compares assistant message with reasoning+text-content vs reasoning+tool-calls. **R3 — `compare_reasoning_scope()`**: Compares assistant message with reasoning+text-content vs reasoning+tool-calls.
@ -343,7 +352,7 @@ Classification logic:
A workaround array in `common/chat-diff-analyzer.cpp` applies post-hoc patches after analysis. Each workaround is a lambda that inspects the template source and overrides analysis results. Current workarounds: A workaround array in `common/chat-diff-analyzer.cpp` applies post-hoc patches after analysis. Each workaround is a lambda that inspects the template source and overrides analysis results. Current workarounds:
1. **Old Qwen/DeepSeek thinking templates** — source contains `content.split('</think>')`: sets `reasoning.mode = FORCED_OPEN` with `<think>`/`</think>` markers if no reasoning was detected 1. **Old Qwen/DeepSeek thinking templates** — source contains `content.split('</think>')` but not `<SPECIAL_12>`: sets `reasoning.mode = TAG_BASED` with `<think>`/`</think>` markers if no reasoning was detected
2. **Granite 3.3** — source contains specific "Write your thoughts" text: forces `TAG_BASED` reasoning with `<think>`/`</think>` and `WRAPPED_WITH_REASONING` content with `<response>`/`</response>` 2. **Granite 3.3** — source contains specific "Write your thoughts" text: forces `TAG_BASED` reasoning with `<think>`/`</think>` and `WRAPPED_WITH_REASONING` content with `<response>`/`</response>`
3. **Cohere Command R+** — source contains `<|CHATBOT_TOKEN|>`: sets `ALWAYS_WRAPPED` content mode if no content start is already set 3. **Cohere Command R+** — source contains `<|CHATBOT_TOKEN|>`: sets `ALWAYS_WRAPPED` content mode if no content start is already set
4. **Functionary 3.1** — source contains `set has_code_interpreter`: forces `PLAIN` content, specific `per_call_start/end`, clears preserved tokens to only keep Functionary-specific markers 4. **Functionary 3.1** — source contains `set has_code_interpreter`: forces `PLAIN` content, specific `per_call_start/end`, clears preserved tokens to only keep Functionary-specific markers
@ -355,12 +364,13 @@ Each analyzer struct (`analyze_reasoning`, `analyze_content`, `analyze_tools`) i
#### Reasoning Parser (`analyze_reasoning::build_parser`) #### Reasoning Parser (`analyze_reasoning::build_parser`)
| Mode | Parser | | Mode | Parser |
|-----------------------------------|---------------------------------------------------------------------| |-----------------------------------------------|---------------------------------------------------------------------------|
| Not extracting reasoning | `eps()` | | Not extracting reasoning | `eps()` |
| `FORCED_OPEN` or `FORCED_CLOSED` | `reasoning(until(end)) + end` — opening tag was in the prompt | | `TAG_BASED` or `TOOLS_ONLY` (non-empty start) | `optional(start + reasoning(until(end)) + end + space())` |
| `TAG_BASED` or `TOOLS_ONLY` | `optional(start + reasoning(until(end)) + end)` | | `TAG_BASED` or `TOOLS_ONLY` (empty start) | `optional(reasoning(until(end)) + end + space())` — delimiter-style |
| `DELIMITER` | `optional(reasoning(until(end)) + end)` — no start marker |
Note: The start marker may be empty either because the analyzer detected delimiter-style reasoning, or because `generate_parser()` cleared a template artifact start marker (see Generation Prompt & Reasoning Prefill above). Whitespace-only reasoning content (e.g. from a `<think></think>` prefill) is discarded by the mapper.
#### Content Parser (`analyze_content::build_parser`) #### Content Parser (`analyze_content::build_parser`)
@ -410,9 +420,7 @@ All three tool parsers return:
reasoning + optional(content(until(trigger_marker))) + tool_calls + end() reasoning + optional(content(until(trigger_marker))) + tool_calls + end()
``` ```
### Python Dict Format Each returned parser is wrapped by `wrap_for_generation_prompt()`, which prepends a literal for any boilerplate prefix of the generation prompt (the portion before the reasoning start marker).
When `format.uses_python_dicts` is true (detected when single-quoted strings appear in JSON argument context), `build_parser()` pre-registers a `json-string` rule that accepts both single-quoted and double-quoted strings. This is done before any `p.json()` call so all JSON parsing inherits the flexible rule.
## Mapper ## Mapper
@ -421,22 +429,22 @@ When `format.uses_python_dicts` is true (detected when single-quoted strings app
- **Buffered arguments**: Before `tool_name` is known, argument text goes to `args_buffer`; once the name is set, the buffer is flushed to `current_tool->arguments` - **Buffered arguments**: Before `tool_name` is known, argument text goes to `args_buffer`; once the name is set, the buffer is flushed to `current_tool->arguments`
- **`args_target()`**: Returns a reference to whichever destination is currently active (buffer or tool args), eliminating branching - **`args_target()`**: Returns a reference to whichever destination is currently active (buffer or tool args), eliminating branching
- **`closing_quote_pending`**: Tracks whether a closing `"` needs to be appended when a string argument value is finalized (for schema-declared string types in tagged format) - **`closing_quote_pending`**: Tracks whether a closing `"` needs to be appended when a string argument value is finalized (for schema-declared string types in tagged format)
- **Quote normalization**: Python-style quotes (`'key': 'value'`) are converted to JSON (`"key": "value"`) - **Whitespace-only reasoning**: Reasoning content that consists entirely of whitespace (e.g. from a `<think></think>` prefill) is cleared so the message shows no reasoning
- **Brace auto-closing**: At tool close, unclosed `{` braces are closed automatically - **Brace auto-closing**: At tool close, unclosed `{` braces are closed automatically
## Files ## Files
| File | Purpose | | File | Purpose |
|-------------------------------------------|----------------------------------------------------------------------| |-------------------------------------------|---------------------------------------------------------------------------------|
| `common/chat-auto-parser.h` | All analysis structs, enums, `autoparser`, `peg_generator`, `templates_params` | | `common/chat-auto-parser.h` | All analysis structs, enums, `autoparser`, `peg_generator`, `generation_params` |
| `common/chat-auto-parser-generator.cpp` | Parser generator: `generate_parser()` and `build_parser()` methods | | `common/chat-auto-parser-generator.cpp` | Parser generator: `generate_parser()` and `build_parser()` methods |
| `common/chat-diff-analyzer.cpp` | Differential analysis implementation and workarounds | | `common/chat-diff-analyzer.cpp` | Differential analysis implementation and workarounds |
| `common/chat-auto-parser-helpers.h/cpp` | `calculate_diff_split()`, `segmentize_markers()`, | | `common/chat-auto-parser-helpers.h/cpp` | `calculate_diff_split()`, `segmentize_markers()`, `compare_variants()`, |
| | `compare_variants()`, string helpers | | | `wrap_for_generation_prompt()`, string helpers |
| `common/chat-peg-parser.h/cpp` | `common_chat_peg_builder`, `common_chat_peg_mapper`, and helpers | | `common/chat-peg-parser.h/cpp` | `common_chat_peg_builder`, `common_chat_peg_mapper`, and helpers |
| `common/chat.cpp` | Entry point: `common_chat_templates_apply_jinja()` | | `common/chat.cpp` | Entry point: `common_chat_templates_apply_jinja()` |
| `tools/parser/debug-template-parser.cpp` | Debug tool for template analysis | | `tools/parser/debug-template-parser.cpp` | Debug tool for template analysis |
| `tools/parser/template-analysis.cpp` | Template analysis tool | | `tools/parser/template-analysis.cpp` | Template analysis tool |
## Testing & Debugging ## Testing & Debugging
@ -516,10 +524,10 @@ To support a new template format:
## Edge Cases and Quirks ## Edge Cases and Quirks
1. **Forced Thinking**: When `enable_thinking=true` and the model prompt ends with an open reasoning tag (e.g., `<think>`), the parser enters forced thinking mode and immediately expects reasoning content without waiting for a start marker. 1. **Generation Prompt & Reasoning Prefill**: The generation prompt is extracted by diffing `add_generation_prompt=false` vs `true` in `common_chat_templates_apply_jinja`, so it contains exactly what the template appends — avoiding false positives from prior conversation turns.
2. **Per-Call vs Per-Section Markers**: Some templates wrap each tool call individually (`per_call_start/end`); others wrap the entire section (`section_start/end`). T2 (`check_per_call_markers()`) disambiguates by checking if the second call in a two-call output starts with the section marker. 2. **Per-Call vs Per-Section Markers**: Some templates wrap each tool call individually (`per_call_start/end`); others wrap the entire section (`section_start/end`). T2 (`check_per_call_markers()`) disambiguates by checking if the second call in a two-call output starts with the section marker.
3. **Python Dict Format**: The Seed template family uses single-quoted JSON (`'key': 'value'`). The `uses_python_dicts` flag causes the PEG builder to register a flexible `json-string` rule accepting both quote styles before any JSON rules are built. 3. **Tag Boundary Fixing**: `calculate_diff_split()` iteratively adjusts prefix/suffix boundaries to avoid splitting `<tag>` or `[marker]` tokens, ensuring clean extraction.
4. **Tag Boundary Fixing**: `calculate_diff_split()` iteratively adjusts prefix/suffix boundaries to avoid splitting `<tag>` or `[marker]` tokens, ensuring clean extraction. 4. **Call ID Side Effects**: When a call ID is detected, `per_call_end` may have been incorrectly set to include the call ID suffix. T7 clears `per_call_end` in this case.
5. **Call ID Side Effects**: When a call ID is detected, `per_call_end` may have been incorrectly set to include the call ID suffix. T7 clears `per_call_end` in this case. 5. **Tool Analysis Gating**: `analyze_tools` is only constructed (and all tool analysis phases run) when `jinja_caps.supports_tool_calls` is true. Within tool analysis, `check_per_call_markers()` (T2) only runs if `jinja_caps.supports_parallel_tool_calls`.
6. **Tool Analysis Gating**: `analyze_tools` is only constructed (and all tool analysis phases run) when `jinja_caps.supports_tool_calls` is true. Within tool analysis, `check_per_call_markers()` (T2) only runs if `jinja_caps.supports_parallel_tool_calls`. 6. **`analyze_arguments()` Gating**: Within tool analysis, A1 and A2 (argument name/value marker extraction) only run for `TAG_WITH_TAGGED` format. `extract_argument_separator()` and `extract_args_markers()` run for all non-`JSON_NATIVE` formats.
7. **`analyze_arguments()` Gating**: Within tool analysis, A1 and A2 (argument name/value marker extraction) only run for `TAG_WITH_TAGGED` format. `extract_argument_separator()` and `extract_args_markers()` run for all non-`JSON_NATIVE` formats. 7. **Undetected Tool Format**: If `analyze_tools` concludes tool calling is supported but cannot determine the format, `build_parser()` logs an error and returns `eps()` (graceful degradation) rather than aborting.

View File

@ -28,6 +28,9 @@ Additionally, there the following images, similar to the above:
- `ghcr.io/ggml-org/llama.cpp:full-vulkan`: Same as `full` but compiled with Vulkan support. (platforms: `linux/amd64`) - `ghcr.io/ggml-org/llama.cpp:full-vulkan`: Same as `full` but compiled with Vulkan support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:light-vulkan`: Same as `light` but compiled with Vulkan support. (platforms: `linux/amd64`) - `ghcr.io/ggml-org/llama.cpp:light-vulkan`: Same as `light` but compiled with Vulkan support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:server-vulkan`: Same as `server` but compiled with Vulkan support. (platforms: `linux/amd64`) - `ghcr.io/ggml-org/llama.cpp:server-vulkan`: Same as `server` but compiled with Vulkan support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:full-openvino`: Same as `full` but compiled with OpenVINO support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:light-openvino`: Same as `light` but compiled with OpenVINO support. (platforms: `linux/amd64`)
- `ghcr.io/ggml-org/llama.cpp:server-openvino`: Same as `server` but compiled with OpenVINO support. (platforms: `linux/amd64`)
The GPU enabled images are not currently tested by CI beyond being built. They are not built with any variation from the ones in the Dockerfiles defined in [.devops/](../.devops/) and the GitHub Action defined in [.github/workflows/docker.yml](../.github/workflows/docker.yml). If you need different settings (for example, a different CUDA, ROCm or MUSA library, you'll need to build the images locally for now). The GPU enabled images are not currently tested by CI beyond being built. They are not built with any variation from the ones in the Dockerfiles defined in [.devops/](../.devops/) and the GitHub Action defined in [.github/workflows/docker.yml](../.github/workflows/docker.yml). If you need different settings (for example, a different CUDA, ROCm or MUSA library, you'll need to build the images locally for now).

View File

@ -37,7 +37,7 @@ Legend:
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| DIAG | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | | ❌ | ❌ | | DIAG | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | | ❌ | ❌ |
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ | | DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | | DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ | | DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
@ -47,7 +47,7 @@ Legend:
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | | FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ | | FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | | ❌ | ❌ | | GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | | GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | | GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | | GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
@ -62,7 +62,7 @@ Legend:
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | | IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | ❌ | ❌ | | L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | ❌ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | | LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | | LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | | MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
@ -91,7 +91,7 @@ Legend:
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | | SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | | ❌ | ❌ | | SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | | ❌ | ❌ |
| SET_ROWS | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | | SET_ROWS | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | | SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
@ -101,10 +101,10 @@ Legend:
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | | SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ | | SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ✅ | | ❌ | ❌ | | SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ✅ | | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ | | SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ | | SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ❌ | ❌ | | SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ❌ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | | SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | | SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
@ -115,7 +115,7 @@ Legend:
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | | TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | | TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ | | TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | ❌ | ❌ | | TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ | | TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ | | UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |

File diff suppressed because it is too large Load Diff

View File

@ -1544,8 +1544,8 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx,
end = 2 * ((n_head - 1) - n_head_log2) + 1; end = 2 * ((n_head - 1) - n_head_log2) + 1;
step = 2; step = 2;
count = n_head - n_head_log2; count = n_head - n_head_log2;
aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), m1, count, start, end + 1, step, aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * ggml_type_size(dtype), m1, count, start, end + 1,
dtype); step, dtype);
} }
} }
@ -2943,6 +2943,27 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
// Rotate full tensor (no tail), using trans tensors // Rotate full tensor (no tail), using trans tensors
GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(), acl_cos_reshape_tensor.get(), GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(), acl_cos_reshape_tensor.get(),
acl_sin_reshape_tensor.get(), acl_mode, acl_dst_trans_tensor.get()); acl_sin_reshape_tensor.get(), acl_mode, acl_dst_trans_tensor.get());
} else if (src0->data == dst->data && !ggml_is_contiguous(src0)) {
// In-place on non-contiguous tensor: RotaryPositionEmbedding cannot safely
// read and write the same non-contiguous buffer. Use contiguous temporaries.
size_t contiguous_nb[GGML_MAX_DIMS];
contiguous_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
contiguous_nb[i] = contiguous_nb[i - 1] * src0->ne[i - 1];
}
int64_t total_elements = ggml_nelements(src0);
ggml_cann_pool_alloc inplace_src_alloc(ctx.pool(), total_elements * sizeof(float));
ggml_cann_pool_alloc inplace_dst_alloc(ctx.pool(), total_elements * sizeof(float));
acl_tensor_ptr acl_src_contig = ggml_cann_create_tensor(inplace_src_alloc.get(), ACL_FLOAT, sizeof(float),
src0->ne, contiguous_nb, GGML_MAX_DIMS);
acl_tensor_ptr acl_dst_contig = ggml_cann_create_tensor(inplace_dst_alloc.get(), ACL_FLOAT, sizeof(float),
dst->ne, contiguous_nb, GGML_MAX_DIMS);
cann_copy(ctx, acl_src.get(), acl_src_contig.get());
GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_contig.get(), acl_cos_reshape_tensor.get(),
acl_sin_reshape_tensor.get(), acl_mode, acl_dst_contig.get());
cann_copy(ctx, acl_dst_contig.get(), acl_dst.get());
} else { } else {
// Rotate full tensor (no tail), using original tensors // Rotate full tensor (no tail), using original tensors
GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(), GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(),
@ -3599,6 +3620,44 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
acl_k_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne, src1_bsnd_nb, GGML_MAX_DIMS); acl_k_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne, src1_bsnd_nb, GGML_MAX_DIMS);
acl_v_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne, src2_bsnd_nb, GGML_MAX_DIMS); acl_v_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne, src2_bsnd_nb, GGML_MAX_DIMS);
// Step 2.5: Pad Q, K, V along head dimension if D is not a multiple of 16
// (required by FusedInferAttentionScoreV2)
const int64_t D = src0->ne[0];
const int64_t D_padded = GGML_PAD(D, 16);
const bool needs_padding = (D != D_padded);
ggml_cann_pool_alloc q_pad_allocator(ctx.pool());
ggml_cann_pool_alloc k_pad_allocator(ctx.pool());
ggml_cann_pool_alloc v_pad_allocator(ctx.pool());
if (needs_padding) {
int64_t paddings[] = { 0, D_padded - D, 0, 0, 0, 0, 0, 0 };
auto pad_fa_tensor = [&](acl_tensor_ptr & tensor, const int64_t * bsnd_ne,
ggml_cann_pool_alloc & allocator) {
int64_t pad_ne[GGML_MAX_DIMS] = { D_padded, bsnd_ne[1], bsnd_ne[2], bsnd_ne[3] };
size_t pad_nb[GGML_MAX_DIMS];
pad_nb[0] = faElemSize;
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
pad_nb[i] = pad_nb[i - 1] * pad_ne[i - 1];
}
int64_t nelements = pad_ne[0] * pad_ne[1] * pad_ne[2] * pad_ne[3];
void * buffer = allocator.alloc(nelements * faElemSize);
acl_tensor_ptr padded =
ggml_cann_create_tensor(buffer, faDataType, faElemSize, pad_ne, pad_nb, GGML_MAX_DIMS);
aclnn_pad(ctx, tensor.get(), padded.get(), paddings);
tensor = std::move(padded);
};
pad_fa_tensor(acl_q_tensor, src0_bsnd_ne, q_pad_allocator);
pad_fa_tensor(acl_k_tensor, src1_bsnd_ne, k_pad_allocator);
pad_fa_tensor(acl_v_tensor, src2_bsnd_ne, v_pad_allocator);
src0_bsnd_ne[0] = D_padded;
src1_bsnd_ne[0] = D_padded;
src2_bsnd_ne[0] = D_padded;
}
// Step 3: create the PSEShift tensor if needed // Step 3: create the PSEShift tensor if needed
// this tensor is considered as mask (f16) in the llama.cpp // this tensor is considered as mask (f16) in the llama.cpp
acl_tensor_ptr bcast_pse_tensor; acl_tensor_ptr bcast_pse_tensor;
@ -3688,17 +3747,16 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
acl_tensor_ptr fa_dst_tensor; acl_tensor_ptr fa_dst_tensor;
acl_tensor_ptr acl_dst_tensor;
ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
if (dst->type == GGML_TYPE_F32) { if (dst->type == GGML_TYPE_F32 || needs_padding) {
void * out_f16_buffer = out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
int64_t * out_f16_ne = src0_bsnd_ne; int64_t * out_f16_ne = src0_bsnd_ne;
size_t out_f16_nb[GGML_MAX_DIMS]; size_t out_f16_nb[GGML_MAX_DIMS];
out_f16_nb[0] = faElemSize; out_f16_nb[0] = faElemSize;
for (int i = 1; i < GGML_MAX_DIMS; ++i) { for (int i = 1; i < GGML_MAX_DIMS; ++i) {
out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
} }
int64_t out_nelements = out_f16_ne[0] * out_f16_ne[1] * out_f16_ne[2] * out_f16_ne[3];
void * out_f16_buffer = out_f16_allocator.alloc(out_nelements * faElemSize);
fa_dst_tensor = fa_dst_tensor =
ggml_cann_create_tensor(out_f16_buffer, faDataType, faElemSize, out_f16_ne, out_f16_nb, GGML_MAX_DIMS); ggml_cann_create_tensor(out_f16_buffer, faDataType, faElemSize, out_f16_ne, out_f16_nb, GGML_MAX_DIMS);
@ -3730,8 +3788,33 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
nullptr // softmaxLse nullptr // softmaxLse
); );
if (dst->type == GGML_TYPE_F32) { // Step 6: post-processing — slice padded output and/or cast to f32
// Step 6: post-processing, permute and cast to f32 if (needs_padding) {
ggml_cann_pool_alloc sliced_f16_allocator(ctx.pool());
if (dst->type == GGML_TYPE_F32) {
int64_t sliced_ne[GGML_MAX_DIMS] = { D, src0_bsnd_ne[1], src0_bsnd_ne[2], src0_bsnd_ne[3] };
size_t sliced_nb[GGML_MAX_DIMS];
sliced_nb[0] = faElemSize;
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
sliced_nb[i] = sliced_nb[i - 1] * sliced_ne[i - 1];
}
int64_t sliced_nelements = sliced_ne[0] * sliced_ne[1] * sliced_ne[2] * sliced_ne[3];
void * sliced_buffer = sliced_f16_allocator.alloc(sliced_nelements * faElemSize);
acl_tensor_ptr sliced_f16_tensor = ggml_cann_create_tensor(sliced_buffer, faDataType, faElemSize,
sliced_ne, sliced_nb, GGML_MAX_DIMS);
GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(),
(int64_t) -1, (int64_t) 0, D, (int64_t) 1, sliced_f16_tensor.get());
acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
aclnn_cast(ctx, sliced_f16_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type));
} else {
acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(),
(int64_t) -1, (int64_t) 0, D, (int64_t) 1, acl_dst_tensor.get());
}
} else if (dst->type == GGML_TYPE_F32) {
acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst); acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
aclnn_cast(ctx, fa_dst_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type)); aclnn_cast(ctx, fa_dst_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type));
} }

View File

@ -2503,10 +2503,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
// different head sizes of K and V are not supported yet // different head sizes of K and V are not supported yet
return false; return false;
} }
if (op->src[0]->ne[0] % 16 != 0) {
// TODO: padding to support
return false;
}
float logitSoftcap = 0.0f; float logitSoftcap = 0.0f;
memcpy(&logitSoftcap, (const float *) (op->op_params) + 2, sizeof(float)); memcpy(&logitSoftcap, (const float *) (op->op_params) + 2, sizeof(float));
if (logitSoftcap != 0.0f) { if (logitSoftcap != 0.0f) {

View File

@ -570,24 +570,36 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
set(KLEIDIAI_ARCHIVE_MD5 "54049037570ab0ee0a0d126b2ba5ece1") set(KLEIDIAI_ARCHIVE_MD5 "54049037570ab0ee0a0d126b2ba5ece1")
if (POLICY CMP0135) set(KLEIDIAI_FETCH_ARGS
cmake_policy(SET CMP0135 NEW) URL ${KLEIDIAI_DOWNLOAD_URL}
URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5}
)
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
list(APPEND KLEIDIAI_FETCH_ARGS DOWNLOAD_EXTRACT_TIMESTAMP NEW)
endif() endif()
# TODO: Use FetchContent_MakeAvailable with EXCLUDE_FROM_ALL after bumping minimum CMake version to 3.28+ if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.28")
# Using FetchContent_Populate instead to avoid EXCLUDE_FROM_ALL which requires CMake 3.28 FetchContent_Declare(KleidiAI_Download
FetchContent_Declare(KleidiAI_Download ${KLEIDIAI_FETCH_ARGS}
URL ${KLEIDIAI_DOWNLOAD_URL} EXCLUDE_FROM_ALL
DOWNLOAD_EXTRACT_TIMESTAMP NEW )
URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})
FetchContent_GetProperties(KleidiAI_Download FetchContent_MakeAvailable(KleidiAI_Download)
SOURCE_DIR KLEIDIAI_SRC
POPULATED KLEIDIAI_POPULATED)
if (NOT KLEIDIAI_POPULATED)
FetchContent_Populate(KleidiAI_Download)
FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC) FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC)
else()
FetchContent_Declare(KleidiAI_Download
${KLEIDIAI_FETCH_ARGS}
)
FetchContent_GetProperties(KleidiAI_Download
SOURCE_DIR KLEIDIAI_SRC
POPULATED KLEIDIAI_POPULATED
)
if (NOT KLEIDIAI_POPULATED)
FetchContent_Populate(KleidiAI_Download)
FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC)
endif()
endif() endif()
add_compile_definitions(GGML_USE_CPU_KLEIDIAI) add_compile_definitions(GGML_USE_CPU_KLEIDIAI)

View File

@ -45,6 +45,7 @@ static int opt_verbose = 0;
static int opt_profile = 0; static int opt_profile = 0;
static int opt_hostbuf = 1; // hostbuf ON by default static int opt_hostbuf = 1; // hostbuf ON by default
static int opt_experimental = 0; static int opt_experimental = 0;
static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only
// Enable all stages by default // Enable all stages by default
static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE; static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE;
@ -1693,7 +1694,7 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
// Start the DSP-side service. We need to pass the queue ID to the // Start the DSP-side service. We need to pass the queue ID to the
// DSP in a FastRPC call; the DSP side will import the queue and start // DSP in a FastRPC call; the DSP side will import the queue and start
// listening for packets in a callback. // listening for packets in a callback.
err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx); err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx);
if (err != 0) { if (err != 0) {
GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err); GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err);
throw std::runtime_error("ggml-hex: iface start failed (see log for details)"); throw std::runtime_error("ggml-hex: iface start failed (see log for details)");
@ -3372,6 +3373,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
const char * str_etm = getenv("GGML_HEXAGON_ETM"); const char * str_etm = getenv("GGML_HEXAGON_ETM");
const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); const char * str_nhvx = getenv("GGML_HEXAGON_NHVX");
const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX");
const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
const char * str_arch = getenv("GGML_HEXAGON_ARCH"); const char * str_arch = getenv("GGML_HEXAGON_ARCH");
@ -3381,8 +3383,9 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask; opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask;
opt_opsync = str_opsync ? atoi(str_opsync) : 0; opt_opsync = str_opsync ? atoi(str_opsync) : 0;
opt_profile = str_profile ? atoi(str_profile) : 0; opt_profile = str_profile ? atoi(str_profile) : 0;
opt_etm = str_etm ? atoi(str_etm) : 0; opt_etm = str_etm ? atoi(str_etm) : 0;
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx;
opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev;
if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) {

View File

@ -40,6 +40,24 @@ target_compile_definitions(${HTP_LIB} PRIVATE
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,> $<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
# HMX acceleration: available on v73+ architectures
set(HTP_HMX_VERSIONS v73 v75 v79 v81)
# _hmx_idx is -1 when DSP_VERSION (presumably set by the toolchain config
# earlier in this file — verify) is not in the HMX-capable list.
list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
if (_hmx_idx GREATER_EQUAL 0)
    # Only compile the HMX matmul kernels when the target DSP supports HMX.
    target_sources(${HTP_LIB} PRIVATE
        hmx-matmul-ops.c
    )
    # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
    set_source_files_properties(
        hmx-matmul-ops.c
        PROPERTIES COMPILE_OPTIONS "-mhmx"
    )
    # Lets C sources guard HMX-specific code paths behind HTP_HAS_HMX.
    target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1)
endif()
build_idl(htp_iface.idl ${HTP_LIB}) build_idl(htp_iface.idl ${HTP_LIB})
set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON) set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON)

View File

@ -175,6 +175,86 @@ static inline uint32_t dma_queue_capacity(dma_queue * q) {
return q->capacity; return q->capacity;
} }
// ---------------------------------------------------------------------------
// Overflow-safe DMA push: all UDMA type1 descriptor fields (roiwidth,
// roiheight, srcstride, dststride) are 16-bit, max 65535. This helper
// transparently handles values that exceed the 16-bit limit and submits
// chained DMA transactions.
//
// Case 1 (fast path): all params fit in 16 bits -> direct dma_queue_push.
// Case 2 (contiguous block): width == srcstride == dststride. Reshape the
//     flat transfer into a 2D descriptor with sub_width <= 65535. Produces a
//     single descriptor, preserving async DMA behavior.
// Case 3 (stride overflow): srcstride or dststride > 65535. Issue rows
//     one at a time. The first N-1 rows are pushed+popped synchronously;
//     the last row is left async so the caller can pop it.
//
// Contract: on success exactly one descriptor remains in flight, so the
// caller's single dma_queue_pop() drains the transfer as if a plain
// dma_queue_push had been issued. Returns false if any underlying push fails.
// ---------------------------------------------------------------------------
#define UDMA_MAX_FIELD_VAL 65535u

static inline bool dma_queue_push_chained(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t width, size_t nrows) {
    // Fast path: everything fits in 16 bits.
    if (__builtin_expect(
            width <= UDMA_MAX_FIELD_VAL &&
            nrows <= UDMA_MAX_FIELD_VAL &&
            src_stride <= UDMA_MAX_FIELD_VAL &&
            dst_stride <= UDMA_MAX_FIELD_VAL, 1)) {
        return dma_queue_push(q, dptr, dst_stride, src_stride, width, nrows);
    }

    // Case 2: contiguous block (width == src_stride == dst_stride).
    // Reshape total bytes into sub_width * sub_nrows where sub_width <= 65535.
    if (width == src_stride && width == dst_stride) {
        size_t total = width * nrows;
        // Pick the largest 128-byte-aligned sub_width that divides total evenly.
        size_t sub_width = UDMA_MAX_FIELD_VAL & ~(size_t)127; // 65408
        while (sub_width > 0 && total % sub_width != 0) {
            sub_width -= 128;
        }
        if (sub_width == 0) {
            // Fallback: use original width (must fit) with adjusted nrows.
            // This shouldn't happen for 128-aligned DMA sizes.
            // NOTE(review): if width itself exceeds 65535 here (total has no
            // 128-aligned divisor <= 65408), the push below would overflow the
            // 16-bit field — confirm all callers use 128-aligned sizes.
            sub_width = width;
        }
        size_t sub_nrows = total / sub_width;

        // Handle sub_nrows > 65535 by issuing chunked descriptors.
        const uint8_t *src = (const uint8_t *)dptr.src;
        uint8_t *dst = (uint8_t *)dptr.dst;
        size_t rows_done = 0;
        while (rows_done < sub_nrows) {
            size_t chunk = sub_nrows - rows_done;
            if (chunk > UDMA_MAX_FIELD_VAL) chunk = UDMA_MAX_FIELD_VAL;
            // Equal width/strides describe a flat contiguous sub-transfer.
            dma_ptr p = dma_make_ptr(dst + rows_done * sub_width, src + rows_done * sub_width);
            if (!dma_queue_push(q, p, sub_width, sub_width, sub_width, chunk))
                return false;
            rows_done += chunk;
            // Complete all chunks without waiting except the last one, so the
            // caller's single dma_queue_pop drains the final descriptor.
            if (rows_done < sub_nrows)
                dma_queue_pop_nowait(q);
        }
        return true;
    }

    // Case 3: stride overflow — fall back to row-by-row.
    // NOTE(review): assumes width <= 65535 when strides differ; a row wider
    // than 65535 bytes would still overflow the descriptor field — verify
    // against callers if such shapes are possible.
    {
        const uint8_t *src = (const uint8_t *)dptr.src;
        uint8_t *dst = (uint8_t *)dptr.dst;
        for (size_t r = 0; r < nrows; ++r) {
            // Single-row transfer: strides are unused, pass 0.
            dma_ptr p = dma_make_ptr(dst + r * dst_stride,
                                     src + r * src_stride);
            if (!dma_queue_push(q, p, 0, 0, width, 1))
                return false;
            // Drain every row except the last, which the caller pops.
            if (r + 1 < nrows)
                dma_queue_pop_nowait(q);
        }
        return true;
    }
}
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif

View File

@ -29,10 +29,22 @@ static inline uint64_t hex_get_pktcnt() {
return pktcnt; return pktcnt;
} }
static inline int32_t hex_is_aligned(void * addr, uint32_t align) { static inline size_t hmx_ceil_div(size_t num, size_t den) {
return (num + den - 1) / den;
}
static inline int32_t hex_is_aligned(const void * addr, uint32_t align) {
return ((size_t) addr & (align - 1)) == 0; return ((size_t) addr & (align - 1)) == 0;
} }
static inline size_t hex_align_up(size_t v, size_t align) {
return hmx_ceil_div(v, align) * align;
}
static inline size_t hex_align_down(size_t v, size_t align) {
return (v / align) * align;
}
static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) { static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
uint32_t left_off = (size_t) addr & (chunk_size - 1); uint32_t left_off = (size_t) addr & (chunk_size - 1);
uint32_t right_off = left_off + n; uint32_t right_off = left_off + n;
@ -43,6 +55,14 @@ static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
return m * ((n + m - 1) / m); return m * ((n + m - 1) / m);
} }
static inline size_t hex_smin(size_t a, size_t b) {
return a < b ? a : b;
}
static inline size_t hex_smax(size_t a, size_t b) {
return a > b ? a : b;
}
static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) { static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height)); const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
Q6_l2fetch_AP((void *) p, control); Q6_l2fetch_AP((void *) p, control);

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,72 @@
// HMX operation entry-point declarations.
// Ported from htp-ops-lib/include/dsp/ops.h (renamed, benchmark kernels removed). (https://github.com/haozixu/htp-ops-lib)
#ifndef HMX_OPS_H
#define HMX_OPS_H
#include <stddef.h>
#include <stdint.h>
// Some toolchains compile this header as C++ where `restrict` is not a
// keyword; map it to the GNU extension in that case.
#ifndef restrict
# define restrict __restrict
#endif
#ifdef __cplusplus
extern "C" {
#endif
struct htp_context; // forward declaration
// Parameter bundle for the batched F16 matmul wrapper. Batch semantics match
// ggml_mul_mat(): src0 (weights) broadcasts to src1 (activation) in dims 2/3.
typedef struct {
    float *dst;                     // output buffer (F32)
    const float *activation;        // src1 data (F32)
    const __fp16 *permuted_weight;  // src0 data, tile-permuted F16
    int m;                          // activation/output rows per batch (32-aligned)
    int k;                          // inner dimension (src0 ne[0])
    int n;                          // weight columns (src0 ne[1])
    int act_stride;                 // activation row stride, in elements
    int weight_stride;              // weight row stride, in elements
    int dst_stride;                 // output row stride, in elements
    int ne02;                       // src0 batch dims 2/3
    int ne03;
    int ne12;                       // src1 batch dims 2/3
    int ne13;
    size_t src0_nb2;                // byte strides for dims 2/3 of each tensor
    size_t src0_nb3;
    size_t src1_nb2;
    size_t src1_nb3;
    size_t dst_nb2;
    size_t dst_nb3;
} hmx_matmul_w16a32_batched_params_t;
// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output
// act_stride: activation row stride in elements (= k for contiguous, or
//   nb[1]/sizeof(float) for permuted tensors like attention Q).
// weight_stride: weight row stride in elements (= k for compact weights, or
//   nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK).
// Returns 0 on success, non-zero on failure (caller falls back to HVX).
int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx,
                                float *restrict dst,
                                const float *activation,
                                const __fp16 *permuted_weight,
                                int m, int k, int n,
                                int act_stride,
                                int weight_stride);
// Batched F16 wrapper over hmx_mat_mul_permuted_w16a32.
// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3.
// Returns 0 on success, non-zero on failure.
int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx,
                                        const hmx_matmul_w16a32_batched_params_t *params);
// HMX matrix multiplication — tile-permuted quantised weights (Q4_0/Q8_0/IQ4_NL)
// weight_type is the HTP_TYPE_* id of the quantised weight format.
// Returns 0 on success, non-zero on failure.
int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx,
                                     float *restrict dst,
                                     const float *activation,
                                     const uint8_t *permuted_weight,
                                     int m, int k, int n,
                                     int weight_type);
#ifdef __cplusplus
}
#endif
#endif // HMX_OPS_H

View File

@ -0,0 +1,34 @@
// Conditional fine-grained profiling macros for HMX operations.
//
// Define ENABLE_PROFILE_TIMERS (via compiler flag or before including this
// header) to instrument sub-operation latencies with HAP qtimer. When the
// macro is not defined the TIMER_* helpers expand to nothing so there is zero
// overhead.
//
// Usage:
// TIMER_DEFINE(my_phase); // declare accumulator variable
// TIMER_START(my_phase); // snapshot start time
// ... work ...
// TIMER_STOP(my_phase); // accumulate elapsed ticks
// FARF(ALWAYS, "my_phase: %lld us", TIMER_US(my_phase));
#ifndef HMX_PROFILE_H
#define HMX_PROFILE_H
#include <HAP_perf.h>
// Uncomment (or pass -DENABLE_PROFILE_TIMERS) to enable instrumentation.
// #define ENABLE_PROFILE_TIMERS
#if defined(ENABLE_PROFILE_TIMERS)
// Accumulator variable; one TIMER_DEFINE can aggregate many start/stop pairs.
# define TIMER_DEFINE(name) int64_t name##_ticks = 0
// NOTE: TIMER_START declares a local (name##_t0), so each START must share a
// scope with (or enclose) its matching TIMER_STOP.
# define TIMER_START(name) int64_t name##_t0 = HAP_perf_get_qtimer_count()
# define TIMER_STOP(name) name##_ticks += HAP_perf_get_qtimer_count() - name##_t0
// Convert accumulated qtimer ticks to microseconds (int64_t).
# define TIMER_US(name) HAP_perf_qtimer_count_to_us(name##_ticks)
#else
// Profiling disabled: helpers compile away entirely; TIMER_US yields 0.
# define TIMER_DEFINE(name)
# define TIMER_START(name)
# define TIMER_STOP(name)
# define TIMER_US(name) 0LL
#endif
#endif // HMX_PROFILE_H

View File

@ -0,0 +1,88 @@
// HMX tile-level inline helpers (FP16 32x32 tile operations).
// Ported from htp-ops-lib/include/dsp/hmx_utils.h. (https://github.com/haozixu/htp-ops-lib)
#ifndef HMX_UTILS_H
#define HMX_UTILS_H
#include <hexagon_types.h>
#include <stddef.h>
// One HMX FP16 tile is 32x32 = 1024 elements = 2048 bytes.
#define HMX_FP16_TILE_N_ROWS 32
#define HMX_FP16_TILE_N_COLS 32
#define HMX_FP16_TILE_N_ELMS 1024
#define HMX_FP16_TILE_SIZE 2048
// `unused` silences warnings in translation units that include only some helpers.
#define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline))
// Load the HMX output-scale (bias) register from memory.
// `scales` presumably points at the 256-byte area written by
// hmx_init_column_scales — TODO confirm against the mxmem2 encoding.
static HMX_INLINE_ALWAYS void hmx_set_output_scales(const void *scales) {
    asm volatile("bias = mxmem2(%0)" :: "r"(scales));
}
// Initialise aligned 256-byte area with scale vector + zero padding.
// Writes two HVX vectors: the scales followed by a zero vector.
static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) {
    HVX_Vector *pv = (HVX_Vector *)out_scales;
    *pv++ = v_scale;
    *pv = Q6_V_vzero();
}
// Load multiple contiguous tiles with :deep streaming.
// Rt = total region size - 1; the hardware streams through [Rs, Rs + Rt].
// IMPORTANT: the tile region [Rs, Rs + Rt] must NOT cross a VTCM 4 MB bank
// boundary, otherwise the mxmem instruction will raise a precise bus error.
// Callers must ensure their VTCM layout satisfies this constraint.
static HMX_INLINE_ALWAYS void hmx_load_tiles_fp16(const __fp16 *row_tiles,
                                                  const __fp16 *col_tiles,
                                                  size_t n_tiles) {
    size_t limit = n_tiles * HMX_FP16_TILE_SIZE - 1;
    asm volatile(
        "{ activation.hf = mxmem(%0, %1):deep\n"
        "weight.hf = mxmem(%2, %3) }\n"
        :: "r"(row_tiles), "r"(limit), "r"(col_tiles), "r"(limit)
        : "memory");
}
// Load a single activation+weight tile pair (no :deep streaming).
// Rt defines the accessible region [Rs, Rs+Rt]. Following the reference formula
// (limit = n_tiles * HMX_FP16_TILE_SIZE - 1), for a single tile Rt = 2047.
// The original code used Rt=0x7FFF (32 KB region); when dynamic VTCM allocation
// places a tile near a 4 MB bank boundary, the oversized region crosses it and
// triggers a precise bus error (0x2601). Rt=2047 confines accesses to exactly
// one 2048-byte tile while covering all 16 HVX vectors (offsets 0..2047).
static HMX_INLINE_ALWAYS void hmx_load_tile_pair_fp16(const __fp16 *act_tile,
                                                      const __fp16 *wt_tile) {
    asm volatile(
        "{ activation.hf = mxmem(%0, %1)\n"
        "weight.hf = mxmem(%2, %3) }\n"
        :: "r"(act_tile), "r"(2047),
           "r"(wt_tile), "r"(2047)
        : "memory");
}
// Convert the HMX accumulator to FP16 and store it at `out`.
static HMX_INLINE_ALWAYS void hmx_consume_accumulator_fp16(__fp16 *out) {
    // Use the combined convert-and-store instruction (matches the reference
    // Q6_mxmem_AR_after_hf intrinsic). The previous two-instruction sequence
    // "cvt.hf = acc(2); mxmem = cvt" used an undocumented Rs=2 parameter.
    asm volatile(
        "mxmem(%0, %1):after.hf = acc\n"
        :: "r"(out), "r"(0)
        : "memory");
}
// Compute inner product of two vectors of tiles and store result.
static HMX_INLINE_ALWAYS void hmx_dot_fp16(__fp16 *out,
                                           const __fp16 *row_tiles,
                                           const __fp16 *col_tiles,
                                           size_t n_tiles) {
    hmx_load_tiles_fp16(row_tiles, col_tiles, n_tiles);
    hmx_consume_accumulator_fp16(out);
}
// --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) ---
// Simple bump allocator: returns the current cursor and advances it by `size`.
// No alignment handling, no bounds checking, no free — the caller owns the
// cursor and resets it between operations.
static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {
    uint8_t *p = *vtcm_ptr;
    *vtcm_ptr += size;
    return p;
}
#endif // HMX_UTILS_H

View File

@ -30,6 +30,12 @@ struct htp_context {
atomic_bool vtcm_needs_release; atomic_bool vtcm_needs_release;
uint32_t opmask; uint32_t opmask;
// HMX acceleration fields (v73+, enabled by compile-time HTP_HAS_HMX)
#ifdef HTP_HAS_HMX
int hmx_enabled; // Runtime flag: HMX initialisation succeeded
size_t vtcm_scratch_size; // Usable dynamic scratch (vtcm_size minus tail reservation)
#endif
}; };
#endif /* HTP_CTX_H */ #endif /* HTP_CTX_H */

View File

@ -32,13 +32,14 @@ enum htp_status {
// Duplicated here because we can't include full ggml.h in the htp build. // Duplicated here because we can't include full ggml.h in the htp build.
// We have some static_asserts in the cpp code to ensure things are in sync. // We have some static_asserts in the cpp code to ensure things are in sync.
enum htp_data_type { enum htp_data_type {
HTP_TYPE_F32 = 0, HTP_TYPE_F32 = 0,
HTP_TYPE_F16 = 1, HTP_TYPE_F16 = 1,
HTP_TYPE_Q4_0 = 2, HTP_TYPE_Q4_0 = 2,
HTP_TYPE_Q8_0 = 8, HTP_TYPE_Q8_0 = 8,
HTP_TYPE_I32 = 26, HTP_TYPE_IQ4_NL = 20,
HTP_TYPE_I64 = 27, HTP_TYPE_I32 = 26,
HTP_TYPE_MXFP4 = 39, HTP_TYPE_I64 = 27,
HTP_TYPE_MXFP4 = 39,
HTP_TYPE_COUNT HTP_TYPE_COUNT
}; };
@ -87,6 +88,8 @@ static inline size_t htp_t_block_size(uint32_t t) {
return QK4_0; return QK4_0;
case HTP_TYPE_Q8_0: case HTP_TYPE_Q8_0:
return QK8_0; return QK8_0;
case HTP_TYPE_IQ4_NL:
return QK4_NL;
case HTP_TYPE_MXFP4: case HTP_TYPE_MXFP4:
return QK_MXFP4; return QK_MXFP4;
default: default:
@ -105,6 +108,8 @@ static inline size_t htp_type_nbytes(uint32_t t) {
return sizeof(block_q4_0); return sizeof(block_q4_0);
case HTP_TYPE_Q8_0: case HTP_TYPE_Q8_0:
return sizeof(block_q8_0); return sizeof(block_q8_0);
case HTP_TYPE_IQ4_NL:
return sizeof(block_iq4_nl);
case HTP_TYPE_MXFP4: case HTP_TYPE_MXFP4:
return sizeof(block_mxfp4); return sizeof(block_mxfp4);
default: default:

View File

@ -7,7 +7,7 @@
#include "remote.idl" #include "remote.idl"
interface htp_iface : remote_handle64 { interface htp_iface : remote_handle64 {
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx); AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx);
AEEResult stop(); AEEResult stop();
AEEResult enable_etm(); AEEResult enable_etm();
AEEResult disable_etm(); AEEResult disable_etm();

View File

@ -9,6 +9,9 @@
#include "hex-utils.h" #include "hex-utils.h"
#include "hvx-types.h" #include "hvx-types.h"
#define hvx_vmem(A) *((HVX_Vector *)(A))
#define hvx_vmemu(A) *((HVX_UVector *)(A))
static inline void hvx_vec_store_u(void * restrict dst, uint32_t n, HVX_Vector v) { static inline void hvx_vec_store_u(void * restrict dst, uint32_t n, HVX_Vector v) {
// Rotate as needed. // Rotate as needed.
v = Q6_V_vlalign_VVR(v, v, (size_t) dst); v = Q6_V_vlalign_VVR(v, v, (size_t) dst);
@ -112,11 +115,15 @@ static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) {
return Q6_Q_and_QQ(p_exp, p_frac); return Q6_Q_and_QQ(p_exp, p_frac);
} }
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) { static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1) {
const HVX_Vector zero = Q6_V_vsplat_R(0); const HVX_Vector zero = Q6_V_vzero();
HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero); HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero);
HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero); HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero);
HVX_Vector v = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0))); return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0));
}
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
HVX_Vector v = Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1));
#if __HVX_ARCH__ < 79 #if __HVX_ARCH__ < 79
// replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0) // replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0)
@ -128,6 +135,30 @@ static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
return v; return v;
} }
#if __HVX_ARCH__ >= 79
static inline HVX_VectorPair hvx_vec_f16_to_f32_shuff(HVX_Vector v) {
const HVX_Vector one = hvx_vec_splat_f16(1.0);
HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(v, one);
return Q6_W_vcombine_VV(Q6_V_hi_W(p), Q6_V_lo_W(p));
}
static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
const HVX_Vector one = hvx_vec_splat_f16(1.0);
HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(v), one);
return Q6_W_vcombine_VV(Q6_V_hi_W(p), Q6_V_lo_W(p));
}
#else
static inline HVX_VectorPair hvx_vec_f16_to_f32_shuff(HVX_Vector v) {
const HVX_Vector one = hvx_vec_splat_f16(1.0);
HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(v, one);
return Q6_W_vcombine_VV(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)));
}
static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
const HVX_Vector one = hvx_vec_splat_f16(1.0);
HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(v), one);
return Q6_W_vcombine_VV(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)));
}
#endif
/* Q6_Vsf_equals_Vw is only available on v73+.*/ /* Q6_Vsf_equals_Vw is only available on v73+.*/
#if __HVX_ARCH__ < 73 #if __HVX_ARCH__ < 73
static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in) static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in)

View File

@ -25,6 +25,10 @@
#include "htp-ops.h" #include "htp-ops.h"
#include "worker-pool.h" #include "worker-pool.h"
#ifdef HTP_HAS_HMX
#include "hmx-ops.h"
#endif // HTP_HAS_HMX
AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
struct htp_context * ctx; struct htp_context * ctx;
int err = 0; int err = 0;
@ -163,6 +167,9 @@ static int vtcm_acquire(struct htp_context * ctx) {
} }
ctx->vtcm_inuse = true; ctx->vtcm_inuse = true;
return 0; return 0;
} }
@ -246,7 +253,7 @@ static void vtcm_free(struct htp_context * ctx) {
static void htp_packet_callback(dspqueue_t queue, int error, void * context); static void htp_packet_callback(dspqueue_t queue, int error, void * context);
static void htp_error_callback(dspqueue_t queue, int error, void * context); static void htp_error_callback(dspqueue_t queue, int error, void * context);
AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx) { AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx) {
struct htp_context * ctx = (struct htp_context *) handle; struct htp_context * ctx = (struct htp_context *) handle;
if (!ctx) { if (!ctx) {
@ -280,6 +287,21 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
return AEE_ENOMEMORY; return AEE_ENOMEMORY;
} }
#ifdef HTP_HAS_HMX
if (use_hmx) {
ctx->vtcm_scratch_size = ctx->vtcm_size;
ctx->hmx_enabled = 1;
FARF(HIGH, "HMX enabled: vtcm-scratch %zu", ctx->vtcm_scratch_size);
} else {
// HMX disabled: skip HMX initialisation so the
// dispatch loop falls through to the HVX compute paths.
ctx->hmx_enabled = 0;
ctx->vtcm_scratch_size = ctx->vtcm_size;
FARF(HIGH, "HMX disabled (use_hmx=0): vtcm-scratch %zu", ctx->vtcm_scratch_size);
}
#endif
qurt_sysenv_max_hthreads_t hw_threads; qurt_sysenv_max_hthreads_t hw_threads;
qurt_sysenv_get_max_hw_threads(&hw_threads); qurt_sysenv_get_max_hw_threads(&hw_threads);
uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF; uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF;
@ -340,6 +362,12 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
for (int i = 0; i < ctx->n_threads; i++) { for (int i = 0; i < ctx->n_threads; i++) {
dma_queue_delete(ctx->dma[i]); dma_queue_delete(ctx->dma[i]);
} }
#ifdef HTP_HAS_HMX
if (ctx->hmx_enabled) {
ctx->hmx_enabled = 0;
}
#endif
vtcm_free(ctx); vtcm_free(ctx);
@ -375,8 +403,9 @@ static int send_htp_rsp(struct htp_context * c,
struct dspqueue_buffer * bufs, struct dspqueue_buffer * bufs,
size_t n_bufs, size_t n_bufs,
struct profile_data * prof) { struct profile_data * prof) {
// Prep response struct // Prep response struct (zero-init to clear cmp/unused union)
struct htp_general_rsp rsp; struct htp_general_rsp rsp;
memset(&rsp, 0, sizeof(rsp));
rsp.op = op; rsp.op = op;
rsp.status = status; rsp.status = status;
rsp.prof_usecs = prof->usecs; rsp.prof_usecs = prof->usecs;
@ -1037,6 +1066,210 @@ static void proc_flash_attn_ext_req(struct htp_context * ctx,
send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof); send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof);
} }
#ifdef HTP_HAS_HMX
// ---------------------------------------------------------------------------
// HMX operation wrappers — self-contained, bypass htp_ops_context / htp_spad.
// VTCM, DMA and thread dispatch are managed inside the HMX kernels.
// ---------------------------------------------------------------------------
// Dispatch a MUL_MAT request to the HMX kernels, falling back to the generic
// HVX path (proc_matmul_req) whenever a request violates an HMX constraint.
// Work is split by rows of the activation (src1): the first m_hmx = M & ~31
// rows go through HMX (Phase 1), the remaining m_tail rows through HVX
// (Phase 2). On any failure an error response (or full HVX fallback) is sent.
//
// Buffer layout (fixed by the host): bufs[0] = src0 weights,
// bufs[1] = src1 activation (F32), bufs[2] = dst output (F32).
static void proc_hmx_matmul_req(struct htp_context * ctx,
                                struct htp_general_req * req,
                                struct dspqueue_buffer * bufs,
                                size_t n_bufs) {
    // HMX weight tile requires N to be 32-aligned.
    if (req->src0.ne[1] % 32 != 0) {
        proc_matmul_req(ctx, req, bufs, n_bufs);
        return;
    }
    const bool is_batched = (req->src0.ne[2] * req->src0.ne[3] > 1 ||
                             req->src1.ne[2] * req->src1.ne[3] > 1);
    // Quantised HMX kernels only handle flat 2D matmul (host already rejects
    // batched quantised, but guard here too). F16 batched matmul is handled
    // by the dedicated wrapper in hmx-matmul-ops.c.
    if (is_batched &&
        req->src0.type != HTP_TYPE_F16) {
        proc_matmul_req(ctx, req, bufs, n_bufs);
        return;
    }
    // HMX assumes contiguous row-major layout. Fall back for permuted
    // tensors where strides are non-monotonic (e.g. transposed KV cache).
    if (req->src0.nb[0] > req->src0.nb[1] ||
        req->src1.nb[0] > req->src1.nb[1]) {
        proc_matmul_req(ctx, req, bufs, n_bufs);
        return;
    }
    // M alignment: when M > 32 but not 32-aligned, we split into
    // HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows).
    // When M <= 32 and not 32-aligned, fall back entirely to HVX.
    const int m_total = (int) req->src1.ne[1];
    const int m_tail  = m_total % 32;
    const int m_hmx   = m_total - m_tail;
    if (m_hmx == 0) {
        proc_matmul_req(ctx, req, bufs, n_bufs);
        return;
    }
    // HMX only supports F16, Q4_0, Q8_0, IQ4_NL weights.
    // Other types (e.g. MXFP4) fall back to HVX.
    {
        uint32_t wtype = req->src0.type;
        if (wtype != HTP_TYPE_F16 &&
            wtype != HTP_TYPE_Q4_0 &&
            wtype != HTP_TYPE_Q8_0 &&
            wtype != HTP_TYPE_IQ4_NL) {
            proc_matmul_req(ctx, req, bufs, n_bufs);
            return;
        }
        // Quantised HMX path requires K aligned to 256 (x4x2 super-block).
        // F16 HMX path requires K aligned to 32 (tile width).
        if (wtype != HTP_TYPE_F16 && req->src0.ne[0] % 256 != 0) {
            proc_matmul_req(ctx, req, bufs, n_bufs);
            return;
        }
        if (wtype == HTP_TYPE_F16 && req->src0.ne[0] % 32 != 0) {
            proc_matmul_req(ctx, req, bufs, n_bufs);
            return;
        }
    }
    // The response references only the dst buffer. Zero-initialise so any
    // dspqueue_buffer fields beyond the five set below are not sent as stack
    // garbage (mirrors the memset() done for the rsp struct in send_htp_rsp).
    struct dspqueue_buffer rsp_bufs[1] = { 0 };
    rsp_bufs[0].fd     = bufs[2].fd;
    rsp_bufs[0].ptr    = bufs[2].ptr;
    rsp_bufs[0].size   = bufs[2].size;
    rsp_bufs[0].offset = bufs[2].offset;
    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |
                          DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);
    // src0 = weights, src1 = activation, dst = output
    void *  wgt = (void *) bufs[0].ptr;
    float * act = (float *) bufs[1].ptr;
    float * dst = (float *) bufs[2].ptr;
    int k = (int) req->src0.ne[0]; // inner dimension
    int n = (int) req->src0.ne[1]; // weight columns
    struct profile_data prof;
    profile_start(&prof);
    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
    // --- Phase 1: HMX on the first m_hmx (32-aligned) rows ---
    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
        int ret = -1;
        const int ne02 = (int) req->src0.ne[2];
        const int ne03 = (int) req->src0.ne[3];
        const int ne12 = (int) req->src1.ne[2];
        const int ne13 = (int) req->src1.ne[3];
        // Row strides in elements. For compact tensors these equal k; for
        // permuted attention views they can be larger, so pass the real stride.
        const int act_stride    = (int)(req->src1.nb[1] / sizeof(float));
        const int weight_stride = (int)(req->src0.nb[1] / sizeof(__fp16));
        switch (req->src0.type) {
            case HTP_TYPE_F16:
                if (is_batched) {
                    hmx_matmul_w16a32_batched_params_t batch_params = {
                        .dst             = dst,
                        .activation      = act,
                        .permuted_weight = (const __fp16 *) wgt,
                        .m               = m_hmx,
                        .k               = k,
                        .n               = n,
                        .act_stride      = act_stride,
                        .weight_stride   = weight_stride,
                        .dst_stride      = (int)(req->dst.nb[1] / sizeof(float)),
                        .ne02            = ne02,
                        .ne03            = ne03,
                        .ne12            = ne12,
                        .ne13            = ne13,
                        .src0_nb2        = req->src0.nb[2],
                        .src0_nb3        = req->src0.nb[3],
                        .src1_nb2        = req->src1.nb[2],
                        .src1_nb3        = req->src1.nb[3],
                        .dst_nb2         = req->dst.nb[2],
                        .dst_nb3         = req->dst.nb[3],
                    };
                    ret = hmx_mat_mul_permuted_w16a32_batched(ctx, &batch_params);
                } else {
                    ret = hmx_mat_mul_permuted_w16a32(ctx, dst, act,
                                                      (const __fp16 *) wgt,
                                                      m_hmx, k, n,
                                                      act_stride,
                                                      weight_stride);
                }
                break;
            default:
                // Guards above ensure this is Q4_0 / Q8_0 / IQ4_NL, flat 2D.
                ret = hmx_mat_mul_permuted_qk_0_d16a32(ctx, dst, act,
                                                       (const uint8_t *) wgt,
                                                       m_hmx, k, n, (int) req->src0.type);
                break;
        }
        if (ret == 0) {
            rsp_status = HTP_STATUS_OK;
        } else {
            // HMX kernel failed: release VTCM and redo the WHOLE op on HVX.
            // Clear SKIP_QUANTIZE because Phase 1 clobbered any cached
            // quantized activation data in VTCM.
            FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret);
            vtcm_release(ctx);
            req->flags &= ~HTP_OPFLAGS_SKIP_QUANTIZE;
            proc_matmul_req(ctx, req, bufs, n_bufs);
            return;
        }
        vtcm_release(ctx);
    }
    // --- Phase 2: HVX on the remaining m_tail rows ---
    if (m_tail > 0 && rsp_status == HTP_STATUS_OK) {
        struct htp_ops_context octx = { 0 };
        octx.ctx  = ctx;
        octx.src0 = req->src0;       // weights: unchanged
        octx.src1 = req->src1;
        octx.src1.ne[1] = m_tail;    // only tail rows
        octx.dst  = req->dst;
        octx.dst.ne[1] = m_tail;     // only tail rows
        // Always re-quantize tail src1: HMX Phase 1 overwrites VTCM,
        // so any previously cached quantized data (SKIP_QUANTIZE pipeline)
        // is invalid.
        octx.flags     = req->flags & ~HTP_OPFLAGS_SKIP_QUANTIZE;
        octx.op        = req->op;
        octx.n_threads = ctx->n_threads;
        // Offset activation and dst pointers past the HMX-processed rows.
        // Use nb[1] (row stride in bytes) to compute the byte offset.
        octx.src0.data = (uint32_t) bufs[0].ptr;
        octx.src1.data = (uint32_t)((uint8_t *) bufs[1].ptr + (size_t) m_hmx * req->src1.nb[1]);
        octx.dst.data  = (uint32_t)((uint8_t *) bufs[2].ptr + (size_t) m_hmx * req->dst.nb[1]);
        FARF(HIGH, "proc_hmx_matmul: HVX tail m_tail=%d act=%p dst=%p",
             m_tail, (void *)(uintptr_t) octx.src1.data, (void *)(uintptr_t) octx.dst.data);
        if (vtcm_acquire(ctx) == AEE_SUCCESS) {
            uint32_t hvx_ret = op_matmul(&octx);
            vtcm_release(ctx);
            if (hvx_ret != HTP_STATUS_OK) {
                FARF(ERROR, "HVX tail matmul failed (ret=%u)", hvx_ret);
                rsp_status = HTP_STATUS_INTERNAL_ERR;
            }
        } else {
            rsp_status = HTP_STATUS_INTERNAL_ERR;
        }
    }
    profile_stop(&prof);
    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
#endif // HTP_HAS_HMX
static void htp_packet_callback(dspqueue_t queue, int error, void * context) { static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
struct htp_context * ctx = (struct htp_context *) context; struct htp_context * ctx = (struct htp_context *) context;
@ -1089,7 +1322,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
FARF(ERROR, "Bad matmul-req buffer list"); FARF(ERROR, "Bad matmul-req buffer list");
continue; continue;
} }
proc_matmul_req(ctx, &req, bufs, n_bufs); #ifdef HTP_HAS_HMX
if (ctx->hmx_enabled) {
proc_hmx_matmul_req(ctx, &req, bufs, n_bufs);
} else
#endif
{
proc_matmul_req(ctx, &req, bufs, n_bufs);
}
break; break;
case HTP_OP_MUL_MAT_ID: case HTP_OP_MUL_MAT_ID:

View File

@ -53,9 +53,6 @@ endif()
message(STATUS "HIP and hipBLAS found") message(STATUS "HIP and hipBLAS found")
# Workaround old compilers
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} --gpu-max-threads-per-block=1024")
file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh") file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h") list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
@ -132,6 +129,11 @@ endif()
if (CXX_IS_HIPCC) if (CXX_IS_HIPCC)
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX) set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
if (WIN32 AND CMAKE_BUILD_TYPE STREQUAL "Debug")
# CMake on Windows doesn't support the HIP language yet.
# Therefore we workaround debug build's failure on HIP backend this way.
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES COMPILE_FLAGS "-O2 -g")
endif()
target_link_libraries(ggml-hip PRIVATE hip::device) target_link_libraries(ggml-hip PRIVATE hip::device)
else() else()
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE HIP) set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE HIP)

View File

@ -444,19 +444,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 64; // 4 values per idx
const uint ib32 = (idx % 128) / 16; // 0..7 const uint ib32 = (idx % 64) / 8; // 0..7
const uint iq = 16 * ib32 + 2 * (idx % 8); const uint iq = 4 * ib32 + (idx % 4);
const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3; const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;
const uint qshift = (idx & 8) >> 1; const uint qshift = idx & 4;
u8vec2 qs = unpack8((uint(data_a_packed16[ib].qs[iq/2]) >> qshift) & 0x0F0F).xy; u8vec4 qs = unpack8((uint(data_a_packed32[ib].qs[iq]) >> qshift) & 0x0F0F0F0F);
const float d = float(data_a[ib].d); const float d = float(data_a[ib].d);
const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); const vec4 v = d * float(int(sl | (sh << 4)) - 32) * vec4(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y], kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]);
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
#elif defined(DATA_A_IQ4_NL) #elif defined(DATA_A_IQ4_NL)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;

View File

@ -554,7 +554,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
std::string load_vec_quant = "2"; std::string load_vec_quant = "2";
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
load_vec_quant = "8"; load_vec_quant = "8";
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4")) else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4"))
load_vec_quant = "4"; load_vec_quant = "4";
if (tname == "bf16") { if (tname == "bf16") {

View File

@ -95,6 +95,11 @@ struct ggml_webgpu_generic_shader_decisions {
uint32_t wg_size = 0; uint32_t wg_size = 0;
}; };
struct ggml_webgpu_ssm_conv_shader_decisions {
uint32_t block_size;
uint32_t tokens_per_wg;
};
/** Argsort **/ /** Argsort **/
struct ggml_webgpu_argsort_shader_lib_context { struct ggml_webgpu_argsort_shader_lib_context {
@ -131,6 +136,26 @@ struct ggml_webgpu_set_rows_shader_decisions {
uint32_t wg_size; uint32_t wg_size;
}; };
/** Set **/
struct ggml_webgpu_set_pipeline_key {
ggml_type type;
bool inplace;
bool operator==(const ggml_webgpu_set_pipeline_key & other) const {
return type == other.type && inplace == other.inplace;
}
};
struct ggml_webgpu_set_pipeline_key_hash {
size_t operator()(const ggml_webgpu_set_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.inplace);
return seed;
}
};
/** Get Rows **/ /** Get Rows **/
struct ggml_webgpu_get_rows_pipeline_key { struct ggml_webgpu_get_rows_pipeline_key {
@ -151,6 +176,26 @@ struct ggml_webgpu_get_rows_pipeline_key_hash {
} }
}; };
/** Row Norm **/
struct ggml_webgpu_row_norm_pipeline_key {
ggml_op op;
bool inplace;
bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const {
return op == other.op && inplace == other.inplace;
}
};
struct ggml_webgpu_row_norm_pipeline_key_hash {
size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.op);
ggml_webgpu_hash_combine(seed, key.inplace);
return seed;
}
};
/** Pad **/ /** Pad **/
struct ggml_webgpu_pad_pipeline_key { struct ggml_webgpu_pad_pipeline_key {
bool circular; bool circular;
@ -166,6 +211,67 @@ struct ggml_webgpu_pad_pipeline_key_hash {
} }
}; };
/** Solve Tri **/
struct ggml_webgpu_solve_tri_pipeline_key {
int type;
int n;
int k;
bool operator==(const ggml_webgpu_solve_tri_pipeline_key & other) const {
return type == other.type && n == other.n && k == other.k;
}
};
struct ggml_webgpu_solve_tri_pipeline_key_hash {
size_t operator()(const ggml_webgpu_solve_tri_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.n);
ggml_webgpu_hash_combine(seed, key.k);
return seed;
}
};
/** SSM Conv **/
struct ggml_webgpu_ssm_conv_pipeline_key {
int type;
int vectorized;
bool operator==(const ggml_webgpu_ssm_conv_pipeline_key & other) const {
return type == other.type && vectorized == other.vectorized;
}
};
/** Gated Delta Net **/
struct ggml_webgpu_gated_delta_net_pipeline_key {
int type;
int s_v;
int kda;
bool operator==(const ggml_webgpu_gated_delta_net_pipeline_key & other) const {
return type == other.type && s_v == other.s_v && kda == other.kda;
}
};
struct ggml_webgpu_gated_delta_net_pipeline_key_hash {
size_t operator()(const ggml_webgpu_gated_delta_net_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.s_v);
ggml_webgpu_hash_combine(seed, key.kda);
return seed;
}
};
struct ggml_webgpu_ssm_conv_pipeline_key_hash {
size_t operator()(const ggml_webgpu_ssm_conv_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.vectorized);
return seed;
}
};
/** Scale **/ /** Scale **/
struct ggml_webgpu_scale_pipeline_key { struct ggml_webgpu_scale_pipeline_key {
@ -244,13 +350,15 @@ struct ggml_webgpu_binary_pipeline_key_hash {
/** Unary **/ /** Unary **/
struct ggml_webgpu_unary_pipeline_key { struct ggml_webgpu_unary_pipeline_key {
int type; int type;
int op; int op;
bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
bool inplace; bool inplace;
ggml_tri_type ttype; // only used for GGML_OP_TRI
bool operator==(const ggml_webgpu_unary_pipeline_key & other) const { bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace; return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace &&
ttype == other.ttype;
} }
}; };
@ -261,6 +369,7 @@ struct ggml_webgpu_unary_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.op); ggml_webgpu_hash_combine(seed, key.op);
ggml_webgpu_hash_combine(seed, key.is_unary); ggml_webgpu_hash_combine(seed, key.is_unary);
ggml_webgpu_hash_combine(seed, key.inplace); ggml_webgpu_hash_combine(seed, key.inplace);
ggml_webgpu_hash_combine(seed, key.ttype);
return seed; return seed;
} }
}; };
@ -435,20 +544,30 @@ class ggml_webgpu_shader_lib {
std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order
std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
std::unordered_map<ggml_webgpu_row_norm_pipeline_key, webgpu_pipeline, ggml_webgpu_row_norm_pipeline_key_hash>
row_norm_pipelines; // op/inplace
std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash> std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
get_rows_pipelines; // src_type, vectorized get_rows_pipelines; // src_type, vectorized
std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash> std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
unary_pipelines; // type/op/inplace unary_pipelines; // type/op/inplace
std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash> std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
scale_pipelines; // inplace scale_pipelines; // inplace
std::unordered_map<ggml_webgpu_solve_tri_pipeline_key, webgpu_pipeline, ggml_webgpu_solve_tri_pipeline_key_hash>
solve_tri_pipelines; // type
std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
ssm_conv_pipelines; // type/vectorized
std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key,
webgpu_pipeline,
ggml_webgpu_gated_delta_net_pipeline_key_hash>
gated_delta_net_pipelines; // type/S_v/kda
std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash> std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
pad_pipelines; // circular/non-circular pad_pipelines; // circular/non-circular
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash> std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
binary_pipelines; // type/op/inplace/overlap binary_pipelines; // type/op/inplace/overlap
std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash> std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
concat_pipelines; // type concat_pipelines; // type
std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash> std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
repeat_pipelines; // type repeat_pipelines; // type
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash> std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
flash_attn_pipelines; flash_attn_pipelines;
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key, std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
@ -462,6 +581,7 @@ class ggml_webgpu_shader_lib {
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash> std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
set_rows_pipelines; set_rows_pipelines;
std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines;
public: public:
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
@ -479,6 +599,45 @@ class ggml_webgpu_shader_lib {
return sum_rows_pipelines[1]; return sum_rows_pipelines[1];
} }
    // Build (or fetch from the cache) the pipeline for the row-normalization
    // ops GGML_OP_RMS_NORM / GGML_OP_L2_NORM, keyed by (op, inplace).
    webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
        ggml_webgpu_row_norm_pipeline_key key = {
            .op = context.dst->op,
            .inplace = context.inplace,
        };
        // Fast path: reuse a previously compiled pipeline for this key.
        auto it = row_norm_pipelines.find(key);
        if (it != row_norm_pipelines.end()) {
            return it->second;
        }
        // Preprocessor defines select the op variant inside the shared shader.
        std::vector<std::string> defines;
        std::string variant;
        switch (key.op) {
            case GGML_OP_RMS_NORM:
                defines.push_back("RMS_NORM");
                variant = "rms_norm";
                break;
            case GGML_OP_L2_NORM:
                defines.push_back("L2_NORM");
                variant = "l2_norm";
                break;
            default:
                GGML_ABORT("Unsupported op for row_norm shader");
        }
        if (key.inplace) {
            defines.push_back("INPLACE");
            variant += "_inplace";
        }
        // Use up to 128 threads per workgroup, capped by the device limit.
        const uint32_t row_norm_wg_size = 128u;
        uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size);
        defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
        auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
        row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
        return row_norm_pipelines[key];
    }
webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) { webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) {
bool vec4 = context.src0->ne[0] % 4 == 0; bool vec4 = context.src0->ne[0] % 4 == 0;
@ -546,6 +705,46 @@ class ggml_webgpu_shader_lib {
return set_rows_pipelines[key]; return set_rows_pipelines[key];
} }
    // Build (or fetch from the cache) the GGML_OP_SET pipeline, keyed by
    // (dst type, inplace). Only F32 and I32 destinations are supported.
    webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) {
        ggml_webgpu_set_pipeline_key key = { .type = context.dst->type, .inplace = context.inplace };
        auto it = set_pipelines.find(key);
        if (it != set_pipelines.end()) {
            return it->second;
        }
        std::vector<std::string> defines;
        std::string variant = "set";
        switch (key.type) {
            case GGML_TYPE_F32:
                defines.push_back("TYPE_F32");
                variant += "_f32";
                break;
            case GGML_TYPE_I32:
                defines.push_back("TYPE_I32");
                variant += "_i32";
                break;
            default:
                GGML_ABORT("Unsupported type for set shader");
        }
        if (key.inplace) {
            defines.push_back("INPLACE");
            variant += "_inplace";
        }
        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
        auto processed = preprocessor.preprocess(wgsl_set, defines);
        // Record the workgroup size on the pipeline so the encoder can size
        // its dispatch (see ggml_webgpu_set).
        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
        decisions->wg_size = context.max_wg_size;
        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
        pipeline.context = decisions;
        set_pipelines[key] = pipeline;
        return set_pipelines[key];
    }
webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) { webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) {
auto it = cumsum_pipelines.find(1); auto it = cumsum_pipelines.find(1);
if (it != cumsum_pipelines.end()) { if (it != cumsum_pipelines.end()) {
@ -632,6 +831,7 @@ class ggml_webgpu_shader_lib {
switch (key.src_type) { switch (key.src_type) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
defines.push_back("FLOAT_PARALLEL");
if (key.vectorized) { if (key.vectorized) {
defines.push_back("F32_VEC"); defines.push_back("F32_VEC");
defines.push_back("SRC_TYPE=vec4<f32>"); defines.push_back("SRC_TYPE=vec4<f32>");
@ -646,6 +846,7 @@ class ggml_webgpu_shader_lib {
variant += "_f32"; variant += "_f32";
break; break;
case GGML_TYPE_F16: case GGML_TYPE_F16:
defines.push_back("FLOAT_PARALLEL");
defines.push_back("F16"); defines.push_back("F16");
defines.push_back("SRC_TYPE=f16"); defines.push_back("SRC_TYPE=f16");
defines.push_back("DST_TYPE=f32"); defines.push_back("DST_TYPE=f32");
@ -653,6 +854,7 @@ class ggml_webgpu_shader_lib {
variant += "_f16"; variant += "_f16";
break; break;
case GGML_TYPE_I32: case GGML_TYPE_I32:
defines.push_back("FLOAT_PARALLEL");
defines.push_back("I32"); defines.push_back("I32");
defines.push_back("SRC_TYPE=i32"); defines.push_back("SRC_TYPE=i32");
defines.push_back("DST_TYPE=i32"); defines.push_back("DST_TYPE=i32");
@ -731,6 +933,128 @@ class ggml_webgpu_shader_lib {
return scale_pipelines[key]; return scale_pipelines[key];
} }
    // Build (or fetch from the cache) the GGML_OP_SOLVE_TRI pipeline. The
    // shader is fully specialized at preprocess time (N, WG_SIZE, K_TILE,
    // BATCH_N), so the cache key includes the problem sizes n and k, not just
    // the tensor type.
    webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) {
        ggml_webgpu_solve_tri_pipeline_key key = {
            .type = context.dst->type,
            .n = (int) context.src0->ne[0],
            .k = (int) context.src1->ne[0],
        };
        auto it = solve_tri_pipelines.find(key);
        if (it != solve_tri_pipelines.end()) {
            return it->second;
        }
        std::vector<std::string> defines;
        std::string variant = "solve_tri";
        switch (key.type) {
            case GGML_TYPE_F32:
                variant += "_f32";
                break;
            default:
                GGML_ABORT("Unsupported type for solve_tri shader");
        }
        // Workgroup sized to the matrix dimension n, capped by the device
        // limit; the k-tile width matches the workgroup size.
        const uint32_t wg_size = std::min((uint32_t) key.n, context.max_wg_size);
        const uint32_t k_tile = wg_size;
        // Per-row workgroup storage: (n + wg_size) f32 values; BATCH_N is how
        // many such rows fit in the device's workgroup-storage budget.
        // NOTE(review): assumes this mirrors the shared-memory layout of
        // wgsl_solve_tri — confirm against the shader.
        const uint32_t bytes_per_row = ((uint32_t) key.n + wg_size) * GGML_WEBGPU_F32_SIZE_BYTES;
        const uint32_t batch_n = (uint32_t) (context.wg_mem_limit_bytes / bytes_per_row);
        defines.push_back(std::string("N=") + std::to_string(key.n));
        defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
        defines.push_back(std::string("K_TILE=") + std::to_string(k_tile));
        defines.push_back(std::string("BATCH_N=") + std::to_string(batch_n));
        auto processed = preprocessor.preprocess(wgsl_solve_tri, defines);
        // Record wg_size so the encoder can compute its dispatch extents.
        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
        decisions->wg_size = wg_size;
        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
        pipeline.context = decisions;
        solve_tri_pipelines[key] = pipeline;
        return solve_tri_pipelines[key];
    }
    // Build (or fetch from the cache) the GGML_OP_SSM_CONV pipeline, keyed by
    // (type, vectorized). Only F32 is supported.
    webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) {
        ggml_webgpu_ssm_conv_pipeline_key key = {
            .type = context.dst->type,
            // vec4 path is only taken when the kernel width (src1->ne[0]) is exactly 4
            .vectorized = context.src1->ne[0] == 4,
        };
        auto it = ssm_conv_pipelines.find(key);
        if (it != ssm_conv_pipelines.end()) {
            return it->second;
        }
        std::vector<std::string> defines;
        std::string variant = "ssm_conv";
        switch (key.type) {
            case GGML_TYPE_F32:
                variant += "_f32";
                break;
            default:
                GGML_ABORT("Unsupported type for ssm_conv shader");
        }
        if (key.vectorized) {
            defines.push_back("VECTORIZED");
            variant += "_vec4";
        }
        // Fixed tiling baked into the shader: 32-wide blocks, 8 tokens per
        // workgroup. NOTE(review): meaning inferred from the define names —
        // confirm against wgsl_ssm_conv.
        constexpr uint32_t block_size = 32u;
        constexpr uint32_t tokens_per_wg = 8u;
        defines.push_back("BLOCK_SIZE=" + std::to_string(block_size) + "u");
        defines.push_back("TOKENS_PER_WG=" + std::to_string(tokens_per_wg) + "u");
        auto processed = preprocessor.preprocess(wgsl_ssm_conv, defines);
        // Record the tiling so the encoder can size its dispatch grid.
        auto decisions = std::make_shared<ggml_webgpu_ssm_conv_shader_decisions>();
        decisions->block_size = block_size;
        decisions->tokens_per_wg = tokens_per_wg;
        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
        pipeline.context = decisions;
        ssm_conv_pipelines[key] = pipeline;
        return ssm_conv_pipelines[key];
    }
webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_gated_delta_net_pipeline_key key = {
.type = context.dst->type,
.s_v = (int) context.src2->ne[0],
.kda = context.src3->ne[0] == context.src2->ne[0],
};
auto it = gated_delta_net_pipelines.find(key);
if (it != gated_delta_net_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "gated_delta_net";
switch (key.type) {
case GGML_TYPE_F32:
variant += "_f32";
break;
default:
GGML_ABORT("Unsupported type for gated_delta_net shader");
}
if (key.kda) {
defines.push_back("KDA");
variant += "_kda";
}
defines.push_back("S_V=" + std::to_string(key.s_v) + "u");
defines.push_back("WG_SIZE=" + std::to_string(key.s_v) + "u");
auto processed = preprocessor.preprocess(wgsl_gated_delta_net, defines);
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
gated_delta_net_pipelines[key] = pipeline;
return gated_delta_net_pipelines[key];
}
webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) { webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 }; ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 };
@ -1058,6 +1382,7 @@ class ggml_webgpu_shader_lib {
.op = op, .op = op,
.is_unary = is_unary, .is_unary = is_unary,
.inplace = context.inplace, .inplace = context.inplace,
.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0),
}; };
auto it = unary_pipelines.find(key); auto it = unary_pipelines.find(key);
@ -1088,6 +1413,29 @@ class ggml_webgpu_shader_lib {
variant += "_inplace"; variant += "_inplace";
} }
if (op == GGML_OP_TRI) {
switch (key.ttype) {
case GGML_TRI_TYPE_LOWER:
defines.push_back("TRI_TYPE_LOWER");
variant += "_tri_type_lower";
break;
case GGML_TRI_TYPE_LOWER_DIAG:
defines.push_back("TRI_TYPE_LOWER_DIAG");
variant += "_tri_type_lower_diag";
break;
case GGML_TRI_TYPE_UPPER:
defines.push_back("TRI_TYPE_UPPER");
variant += "_tri_type_upper";
break;
case GGML_TRI_TYPE_UPPER_DIAG:
defines.push_back("TRI_TYPE_UPPER_DIAG");
variant += "_tri_upper_diag";
break;
default:
GGML_ABORT("Unsupported ggml_tri_type for unary shader");
}
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_unary, defines); auto processed = preprocessor.preprocess(wgsl_unary, defines);

View File

@ -366,7 +366,6 @@ struct webgpu_context_struct {
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
@ -881,6 +880,68 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g
params, entries, wg_x); params, entries, wg_x);
} }
// Encode a GGML_OP_SET node: writes src1 into a strided view of the output.
// When dst aliases src0 the src0 binding is dropped and only the src1 region
// is written; otherwise every dst element is produced.
static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
    const bool inplace = ggml_webgpu_tensor_equal(src0, dst);
    ggml_webgpu_shader_lib_context shader_lib_ctx = {
        .src0 = src0,
        .src1 = src1,
        .dst = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
        .inplace = inplace,
    };
    webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx);
    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
    // In-place: one thread per src1 element; copy mode: one per dst element.
    const uint32_t ne = inplace ? (uint32_t) ggml_nelements(src1) : (uint32_t) ggml_nelements(dst);
    const uint32_t dst_type_size = (uint32_t) ggml_type_size(dst->type);
    // Push constants; all byte offsets/strides are converted to element units.
    // op_params carry the view geometry (indices [0..2] and [3] are divided by
    // the element size here) — NOTE(review): layout assumed to follow the
    // ggml_set convention ([0..2] = nb1..nb3, [3] = offset); confirm.
    std::vector<uint32_t> params = {
        ne,
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
        (uint32_t) (((const int32_t *) dst->op_params)[3] / dst_type_size),
        (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
        1u,  // presumably the view's innermost stride in elements — TODO confirm against wgsl_set
        (uint32_t) (((const int32_t *) dst->op_params)[0] / dst_type_size),
        (uint32_t) (((const int32_t *) dst->op_params)[1] / dst_type_size),
        (uint32_t) (((const int32_t *) dst->op_params)[2] / dst_type_size),
        (uint32_t) src1->ne[0],
        (uint32_t) src1->ne[1],
        (uint32_t) src1->ne[2],
        (uint32_t) src1->ne[3],
    };
    // Bindings: [src0 only when copying], then src1, then dst.
    std::vector<wgpu::BindGroupEntry> entries;
    uint32_t binding_index = 0;
    if (!inplace) {
        entries.push_back({ .binding = 0,
                            .buffer = ggml_webgpu_tensor_buf(src0),
                            .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
                            .size = ggml_webgpu_tensor_binding_size(ctx, src0) });
        binding_index++;
    }
    entries.push_back({ .binding = binding_index,
                        .buffer = ggml_webgpu_tensor_buf(src1),
                        .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
                        .size = ggml_webgpu_tensor_binding_size(ctx, src1) });
    entries.push_back({ .binding = binding_index + 1,
                        .buffer = ggml_webgpu_tensor_buf(dst),
                        .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
                        .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
    // One thread per element, grouped by the pipeline's workgroup size.
    uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = { ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
@ -936,6 +997,208 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
} }
// Encode a GGML_OP_SOLVE_TRI node: solves the triangular system given by src0
// against the right-hand side src1, writing the solution into dst.
static webgpu_command ggml_webgpu_solve_tri(webgpu_context & ctx,
                                            ggml_tensor * src0,
                                            ggml_tensor * src1,
                                            ggml_tensor * dst) {
    // The pipeline is specialized on the problem sizes and the device's
    // workgroup-storage budget, so both limits are forwarded to the lib.
    ggml_webgpu_shader_lib_context shader_lib_ctx = {
        .src0 = src0,
        .src1 = src1,
        .dst = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
        .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
    };
    webgpu_pipeline pipeline = ctx->shader_lib->get_solve_tri_pipeline(shader_lib_ctx);
    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
    // Push constants: per-tensor misalignment offsets, then strides (all in
    // element units), then the RHS width and the two batch extents.
    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
        (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
        (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
        (uint32_t) src1->ne[0],
        (uint32_t) dst->ne[2],
        (uint32_t) dst->ne[3],
    };
    // Bindings: src0 (matrix), src1 (RHS), dst (solution).
    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
          .buffer = ggml_webgpu_tensor_buf(src0),
          .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
          .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
        { .binding = 1,
          .buffer = ggml_webgpu_tensor_buf(src1),
          .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
          .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
        { .binding = 2,
          .buffer = ggml_webgpu_tensor_buf(dst),
          .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
          .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
    };
    // x covers the RHS width in workgroup-sized tiles; y covers the batch.
    const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size);
    const uint32_t wg_y = (uint32_t) (dst->ne[2] * dst->ne[3]);
    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
}
// Encode a GGML_OP_SSM_CONV node (state-space-model convolution) over
// src0/src1 into dst. Operand roles are not visible here — NOTE(review):
// confirm which of src0/src1 is the signal vs. the kernel against ggml.
static webgpu_command ggml_webgpu_ssm_conv(webgpu_context & ctx,
                                           ggml_tensor * src0,
                                           ggml_tensor * src1,
                                           ggml_tensor * dst) {
    ggml_webgpu_shader_lib_context shader_lib_ctx = {
        .src0 = src0,
        .src1 = src1,
        .dst = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
    };
    webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_conv_pipeline(shader_lib_ctx);
    // The tiling (block_size / tokens_per_wg) was baked into the shader at
    // pipeline-creation time and recorded on the pipeline context.
    auto * decisions = static_cast<ggml_webgpu_ssm_conv_shader_decisions *>(pipeline.context.get());
    // Number of token tiles along dst->ne[1]; also passed to the shader so it
    // can decode the flattened y dispatch index (see wg_y below).
    const uint32_t token_tiles = CEIL_DIV((uint32_t) dst->ne[1], decisions->tokens_per_wg);
    // Push constants: misalignment offsets and strides in element units,
    // followed by the relevant extents and the tile count.
    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
        (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
        (uint32_t) src1->ne[0],
        (uint32_t) src0->ne[1],
        (uint32_t) dst->ne[1],
        (uint32_t) dst->ne[2],
        token_tiles,
    };
    // Bindings: src0, src1, dst.
    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
          .buffer = ggml_webgpu_tensor_buf(src0),
          .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
          .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
        { .binding = 1,
          .buffer = ggml_webgpu_tensor_buf(src1),
          .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
          .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
        { .binding = 2,
          .buffer = ggml_webgpu_tensor_buf(dst),
          .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
          .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
    };
    // x: src0->ne[1] split into block_size-wide groups; y: token tiles for
    // each dst->ne[2] slice (flattened into a single dimension).
    const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size);
    const uint32_t wg_y = token_tiles * (uint32_t) dst->ne[2];
    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
}
// Encode a GGML_OP_GATED_DELTA_NET node. Takes the operator's six source
// tensors plus dst and dispatches one workgroup per (head, sequence) pair.
// NOTE(review): the semantic roles of src0..src5 (q/k/v/gates/state/...) are
// not visible here — confirm against the ggml op definition.
static webgpu_command ggml_webgpu_gated_delta_net(webgpu_context & ctx,
                                                  ggml_tensor * src0,
                                                  ggml_tensor * src1,
                                                  ggml_tensor * src2,
                                                  ggml_tensor * src3,
                                                  ggml_tensor * src4,
                                                  ggml_tensor * src5,
                                                  ggml_tensor * dst) {
    ggml_webgpu_shader_lib_context shader_lib_ctx = {
        .src0 = src0,
        .src1 = src1,
        .src2 = src2,
        .src3 = src3,
        .src4 = src4,
        .dst = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
    };
    webgpu_pipeline pipeline = ctx->shader_lib->get_gated_delta_net_pipeline(shader_lib_ctx);
    // Logical extents taken from src2: S_v x heads x tokens x sequences.
    const uint32_t s_v = (uint32_t) src2->ne[0];
    const uint32_t h = (uint32_t) src2->ne[1];
    const uint32_t n_tokens = (uint32_t) src2->ne[2];
    const uint32_t n_seqs = (uint32_t) src2->ne[3];
    // Attention-style scaling factor, bit-cast into the u32 param buffer so it
    // survives the uniform upload unchanged.
    const float scale = 1.0f / sqrtf((float) s_v);
    uint32_t scale_u32;
    memcpy(&scale_u32, &scale, sizeof(scale_u32));
    // Push constants: extents, total element count of the s_v*h*tokens*seqs
    // volume, strides (element units) for src0/src2/src4, src0's head extent,
    // the src2/src0 sequence broadcast ratio, and the scale bits.
    std::vector<uint32_t> params = {
        h,
        n_tokens,
        n_seqs,
        s_v * h * n_tokens * n_seqs,
        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
        (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)),
        (uint32_t) (src2->nb[2] / ggml_type_size(src2->type)),
        (uint32_t) (src2->nb[3] / ggml_type_size(src2->type)),
        (uint32_t) (src4->nb[1] / ggml_type_size(src4->type)),
        (uint32_t) (src4->nb[2] / ggml_type_size(src4->type)),
        (uint32_t) (src4->nb[3] / ggml_type_size(src4->type)),
        (uint32_t) src0->ne[1],
        (uint32_t) (src2->ne[3] / src0->ne[3]),
        scale_u32,
    };
    // Bindings 0..5 are the six sources in order; binding 6 is the output.
    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
          .buffer = ggml_webgpu_tensor_buf(src0),
          .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
          .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
        { .binding = 1,
          .buffer = ggml_webgpu_tensor_buf(src1),
          .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
          .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
        { .binding = 2,
          .buffer = ggml_webgpu_tensor_buf(src2),
          .offset = ggml_webgpu_tensor_align_offset(ctx, src2),
          .size = ggml_webgpu_tensor_binding_size(ctx, src2) },
        { .binding = 3,
          .buffer = ggml_webgpu_tensor_buf(src3),
          .offset = ggml_webgpu_tensor_align_offset(ctx, src3),
          .size = ggml_webgpu_tensor_binding_size(ctx, src3) },
        { .binding = 4,
          .buffer = ggml_webgpu_tensor_buf(src4),
          .offset = ggml_webgpu_tensor_align_offset(ctx, src4),
          .size = ggml_webgpu_tensor_binding_size(ctx, src4) },
        { .binding = 5,
          .buffer = ggml_webgpu_tensor_buf(src5),
          .offset = ggml_webgpu_tensor_align_offset(ctx, src5),
          .size = ggml_webgpu_tensor_binding_size(ctx, src5) },
        { .binding = 6,
          .buffer = ggml_webgpu_tensor_buf(dst),
          .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
          .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
    };
    // Dispatch grid: x = heads, y = sequences.
    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, h, n_seqs);
}
static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx, static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
ggml_tensor * src, ggml_tensor * src,
ggml_tensor * idx, ggml_tensor * idx,
@ -1017,6 +1280,8 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
ggml_tensor * src, ggml_tensor * src,
ggml_tensor * idx, ggml_tensor * idx,
ggml_tensor * dst) { ggml_tensor * dst) {
const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32;
ggml_webgpu_shader_lib_context shader_lib_ctx = { ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src, .src0 = src,
.src1 = nullptr, .src1 = nullptr,
@ -1061,7 +1326,10 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
.size = ggml_webgpu_tensor_binding_size(ctx, dst) } .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
}; };
uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size); uint32_t blocks_per_row = (uint32_t) (dst->ne[0] / (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0 ? 4 : 1));
uint32_t total_rows = (uint32_t) (dst->ne[1] * dst->ne[2] * dst->ne[3]);
uint32_t total_threads = float_parallel ? blocks_per_row * total_rows : total_rows;
uint32_t wg_x = CEIL_DIV(total_threads, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
} }
@ -1598,8 +1866,8 @@ static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
} }
static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
int inplace = ggml_webgpu_tensor_equal(src, dst); bool inplace = ggml_webgpu_tensor_equal(src, dst);
std::vector<uint32_t> params = { std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
@ -1630,8 +1898,15 @@ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * s
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }); .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
} }
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params, ggml_webgpu_shader_lib_context shader_lib_ctx = {
entries, ggml_nrows(src)); .src0 = src,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.inplace = inplace,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(src));
} }
static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
@ -2170,6 +2445,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
case GGML_OP_CPY: case GGML_OP_CPY:
case GGML_OP_CONT: case GGML_OP_CONT:
return ggml_webgpu_cpy(ctx, src0, node); return ggml_webgpu_cpy(ctx, src0, node);
case GGML_OP_SET:
return ggml_webgpu_set(ctx, src0, src1, node);
case GGML_OP_SET_ROWS: case GGML_OP_SET_ROWS:
return ggml_webgpu_set_rows(ctx, src0, src1, node); return ggml_webgpu_set_rows(ctx, src0, src1, node);
case GGML_OP_GET_ROWS: case GGML_OP_GET_ROWS:
@ -2192,7 +2469,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
case GGML_OP_REPEAT: case GGML_OP_REPEAT:
return ggml_webgpu_repeat(ctx, src0, node); return ggml_webgpu_repeat(ctx, src0, node);
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
return ggml_webgpu_rms_norm(ctx, src0, node); case GGML_OP_L2_NORM:
return ggml_webgpu_row_norm(ctx, src0, node);
case GGML_OP_ROPE: case GGML_OP_ROPE:
return ggml_webgpu_rope(ctx, src0, src1, src2, node); return ggml_webgpu_rope(ctx, src0, src1, src2, node);
case GGML_OP_GLU: case GGML_OP_GLU:
@ -2209,7 +2487,15 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
case GGML_OP_SQRT: case GGML_OP_SQRT:
case GGML_OP_SIN: case GGML_OP_SIN:
case GGML_OP_COS: case GGML_OP_COS:
case GGML_OP_DIAG:
case GGML_OP_TRI:
return ggml_webgpu_unary_op(ctx, src0, node); return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_SOLVE_TRI:
return ggml_webgpu_solve_tri(ctx, src0, src1, node);
case GGML_OP_SSM_CONV:
return ggml_webgpu_ssm_conv(ctx, src0, src1, node);
case GGML_OP_GATED_DELTA_NET:
return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node);
case GGML_OP_PAD: case GGML_OP_PAD:
return ggml_webgpu_pad(ctx, src0, node); return ggml_webgpu_pad(ctx, src0, node);
case GGML_OP_ARGMAX: case GGML_OP_ARGMAX:
@ -2614,15 +2900,6 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants); ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
} }
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
webgpu_ctx->rms_norm_pipelines[0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants);
webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
}
static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
@ -2907,7 +3184,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf"); wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
ggml_webgpu_init_cpy_pipeline(webgpu_ctx); ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
ggml_webgpu_init_rope_pipeline(webgpu_ctx); ggml_webgpu_init_rope_pipeline(webgpu_ctx);
ggml_webgpu_init_glu_pipeline(webgpu_ctx); ggml_webgpu_init_glu_pipeline(webgpu_ctx);
ggml_webgpu_init_soft_max_pipeline(webgpu_ctx); ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
@ -2958,7 +3234,7 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
/* .is_host = */ NULL, // defaults to false /* .is_host = */ NULL, // defaults to false
}, },
/* .device = */ /* .device = */
dev, dev,
/* .context = */ NULL /* .context = */ NULL
}; };
@ -3041,6 +3317,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) || (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) ||
(op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32); (op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32);
break; break;
case GGML_OP_SET:
supports_op = src0->type == src1->type && src0->type == op->type &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32);
break;
case GGML_OP_SET_ROWS: case GGML_OP_SET_ROWS:
supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 && supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 &&
(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32)); (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
@ -3118,6 +3398,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
break; break;
} }
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break; break;
case GGML_OP_ROPE: case GGML_OP_ROPE:
@ -3180,6 +3461,27 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
} }
} }
break; break;
case GGML_OP_TRI:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;
case GGML_OP_DIAG:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;
case GGML_OP_SOLVE_TRI:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
break;
case GGML_OP_SSM_CONV:
supports_op = op->type == GGML_TYPE_F32;
break;
case GGML_OP_GATED_DELTA_NET:
{
const uint32_t s_v = (uint32_t) src2->ne[0];
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 &&
src2->type == GGML_TYPE_F32 && op->src[3]->type == GGML_TYPE_F32 &&
op->src[4]->type == GGML_TYPE_F32 && op->src[5]->type == GGML_TYPE_F32 &&
s_v <= ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
}
break;
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
break; break;

View File

@ -0,0 +1,132 @@
// Gated delta-net recurrence over a token sequence.
// One workgroup processes one (head, sequence) pair; each invocation owns one
// column (local_id.x) of the S_V x S_V recurrent state and walks the tokens
// sequentially, emitting one attention output element per token plus the
// final state at the end.
// Two compile-time variants: with KDA defined the gate g is a per-channel
// vector of length S_V, otherwise it is a single scalar per (token, head).
// NOTE(review): indexing assumes WG_SIZE == S_V (col indexes S_V-sized
// arrays directly) — confirm on the host side.
@group(0) @binding(0)
var<storage, read_write> src_q: array<f32>;
@group(0) @binding(1)
var<storage, read_write> src_k: array<f32>;
@group(0) @binding(2)
var<storage, read_write> src_v: array<f32>;
@group(0) @binding(3)
var<storage, read_write> src_g: array<f32>;
@group(0) @binding(4)
var<storage, read_write> src_beta: array<f32>;
@group(0) @binding(5)
var<storage, read_write> src_state: array<f32>;
@group(0) @binding(6)
var<storage, read_write> dst: array<f32>;
struct Params {
    h: u32,          // number of heads
    n_tokens: u32,   // tokens per sequence
    n_seqs: u32,     // number of sequences (not read in main; kept for host symmetry)
    s_off: u32,      // element offset into dst where the final state is written
    sq1: u32,        // q/k strides (elements)
    sq2: u32,
    sq3: u32,
    sv1: u32,        // v strides (elements)
    sv2: u32,
    sv3: u32,
    sb1: u32,        // g/beta strides (elements)
    sb2: u32,
    sb3: u32,
    neq1: u32,       // q head count; head_id wraps via head_id % neq1
    rq3: u32,        // sequences per q batch; seq_id maps via seq_id / rq3
    scale: f32,      // applied to each attention output element
};
@group(0) @binding(7)
var<uniform> params: Params;
// Per-token staging of k, q and (KDA only) the exponentiated gate vector,
// shared across the workgroup.
var<workgroup> sh_k: array<f32, S_V>;
var<workgroup> sh_q: array<f32, S_V>;
var<workgroup> sh_g: array<f32, S_V>;
@compute @workgroup_size(WG_SIZE)
fn main(
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>
) {
    let head_id = workgroup_id.x;
    let seq_id = workgroup_id.y;
    let col = local_id.x;
    // Map (head, seq) onto q/k indices, which may be broadcast across heads/batches.
    let iq1 = head_id % params.neq1;
    let iq3 = seq_id / params.rq3;
    let state_size = S_V * S_V;
    let state_base = (seq_id * params.h + head_id) * state_size;
    // This invocation keeps one row (col) of the state in registers.
    var state: array<f32, S_V>;
    for (var i = 0u; i < S_V; i++) {
        state[i] = src_state[state_base + col * S_V + i];
    }
    // Output offset of this head's S_V-wide slice for token 0; advanced per token.
    var attn_off = (seq_id * params.n_tokens * params.h + head_id) * S_V;
    for (var t = 0u; t < params.n_tokens; t++) {
        let q_off = iq3 * params.sq3 + t * params.sq2 + iq1 * params.sq1;
        let k_off = q_off;
        let v_off = seq_id * params.sv3 + t * params.sv2 + head_id * params.sv1;
        let gb_off = seq_id * params.sb3 + t * params.sb2 + head_id * params.sb1;
        // Stage this token's q/k (each lane loads one element).
        sh_q[col] = src_q[q_off + col];
        sh_k[col] = src_k[k_off + col];
#ifdef KDA
        let g_base = gb_off * S_V;
        sh_g[col] = exp(src_g[g_base + col]);
#endif
        workgroupBarrier();
        let v_val = src_v[v_off + col];
        let beta_val = src_beta[gb_off];
        var kv_col = 0.0;
        var delta_col = 0.0;
        var attn_col = 0.0;
#ifdef KDA
        // kv_col = sum_i g[i] * state[col][i] * k[i]: current prediction for this column.
        for (var i = 0u; i < S_V; i++) {
            kv_col += (sh_g[i] * state[i]) * sh_k[i];
        }
        // Delta rule: correct toward v with learning rate beta.
        delta_col = (v_val - kv_col) * beta_val;
        // Decay state by the per-channel gate, apply the rank-1 update, and
        // accumulate the attention output against q in the same pass.
        for (var i = 0u; i < S_V; i++) {
            state[i] = sh_g[i] * state[i] + sh_k[i] * delta_col;
            attn_col += state[i] * sh_q[i];
        }
#else
        // Scalar-gate variant: one decay factor per (token, head).
        let g_val = exp(src_g[gb_off]);
        for (var i = 0u; i < S_V; i++) {
            kv_col += state[i] * sh_k[i];
        }
        delta_col = (v_val - g_val * kv_col) * beta_val;
        for (var i = 0u; i < S_V; i++) {
            state[i] = g_val * state[i] + sh_k[i] * delta_col;
            attn_col += state[i] * sh_q[i];
        }
#endif
        dst[attn_off + col] = attn_col * params.scale;
        attn_off += S_V * params.h;
        // Keep lanes in step before the next token overwrites shared memory.
        workgroupBarrier();
    }
    // Persist the updated state after all tokens, at s_off past the outputs.
    for (var i = 0u; i < S_V; i++) {
        dst[params.s_off + state_base + col * S_V + i] = state[i];
    }
}

View File

@ -640,6 +640,35 @@ var<uniform> params: Params;
@compute @workgroup_size(WG_SIZE) @compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) { fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
#ifdef FLOAT_PARALLEL
let blocks_per_row = params.ne0 / BLOCK_SIZE;
let row_count = params.n_rows * params.ne2 * params.ne3;
if (gid.x >= blocks_per_row * row_count) {
return;
}
let block_idx = gid.x % blocks_per_row;
var row_idx = gid.x / blocks_per_row;
let i_dst3 = row_idx / (params.ne2 * params.n_rows);
row_idx = row_idx % (params.ne2 * params.n_rows);
let i_dst2 = row_idx / params.n_rows;
let i_dst1 = row_idx % params.n_rows;
let i_idx2 = i_dst3 % params.idx2;
let i_idx1 = i_dst2 % params.idx1;
let i_idx0 = i_dst1;
let i_idx = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;
let idx_val = u32(idx[i_idx]);
let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3;
let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3;
copy_elements(i_src_row, i_dst_row, block_idx);
#else
if (gid.x >= params.n_rows * params.ne2 * params.ne3) { if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
return; return;
} }
@ -664,5 +693,5 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) { for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) {
copy_elements(i_src_row, i_dst_row, i); copy_elements(i_src_row, i_dst_row, i);
} }
#endif
} }

View File

@ -1,21 +1,11 @@
#define(VARIANTS) #ifdef INPLACE
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
[ src[dst_offset] = scale * src[src_offset];
{ }
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_SUFFIX": "inplace",
"DECLS": ["INPLACE"]
},
]
#end(VARIANTS)
#define(DECLS)
#decl(NOT_INPLACE)
@group(0) @binding(1)
var<uniform> params: Params;
#else
fn update(src_offset: u32, dst_offset: u32, scale: f32) { fn update(src_offset: u32, dst_offset: u32, scale: f32) {
dst[dst_offset] = scale * src[src_offset]; dst[dst_offset] = scale * src[src_offset];
} }
@ -25,23 +15,7 @@ var<storage, read_write> dst: array<f32>;
@group(0) @binding(2) @group(0) @binding(2)
var<uniform> params: Params; var<uniform> params: Params;
#endif
#enddecl(NOT_INPLACE)
#decl(INPLACE)
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
src[dst_offset] = scale * src[src_offset];
}
@group(0) @binding(1)
var<uniform> params: Params;
#enddecl(INPLACE)
#end(DECLS)
#define(SHADER)
struct Params { struct Params {
offset_src: u32, // in elements offset_src: u32, // in elements
@ -68,12 +42,9 @@ struct Params {
@group(0) @binding(0) @group(0) @binding(0)
var<storage, read_write> src: array<f32>; var<storage, read_write> src: array<f32>;
DECLS var<workgroup> scratch: array<f32, WG_SIZE>;
override wg_size: u32; @compute @workgroup_size(WG_SIZE)
var<workgroup> scratch: array<f32, wg_size>;
@compute @workgroup_size(wg_size)
fn main(@builtin(workgroup_id) wid: vec3<u32>, fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) { @builtin(local_invocation_id) lid: vec3<u32>) {
@ -86,7 +57,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1; let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
let elems = (params.ne0 + wg_size - 1) / wg_size; let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
var sum = 0.0f; var sum = 0.0f;
var col = lid.x; var col = lid.x;
@ -95,12 +66,12 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
break; break;
} }
sum += pow(src[i_src_row + col], 2.0); sum += pow(src[i_src_row + col], 2.0);
col += wg_size; col += WG_SIZE;
} }
scratch[lid.x] = sum; scratch[lid.x] = sum;
workgroupBarrier(); workgroupBarrier();
var offset = wg_size / 2; var offset: u32 = WG_SIZE / 2;
while (offset > 0) { while (offset > 0) {
if (lid.x < offset) { if (lid.x < offset) {
scratch[lid.x] += scratch[lid.x + offset]; scratch[lid.x] += scratch[lid.x + offset];
@ -110,14 +81,18 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
} }
sum = scratch[0]; sum = scratch[0];
#ifdef RMS_NORM
let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
#elif defined(L2_NORM)
let scale = 1.0/max(sqrt(sum), params.eps);
#endif
col = lid.x; col = lid.x;
for (var j: u32 = 0; j < elems; j++) { for (var j: u32 = 0; j < elems; j++) {
if (col >= params.ne0) { if (col >= params.ne0) {
break; break;
} }
update(i_src_row + col, i_dst_row + col, scale); update(i_src_row + col, i_dst_row + col, scale);
col += wg_size; col += WG_SIZE;
} }
} }
#end(SHADER)

View File

@ -0,0 +1,109 @@
// GGML_OP_SET: write the elements of src1 into a strided "view" region of the
// destination tensor.
// Two variants:
//   INPLACE     — dst aliases src0; only the view region is written, one
//                 thread per src1 element (params.ne = src1 element count).
//   non-INPLACE — dst is a fresh buffer; one thread per dst element
//                 (params.ne = dst element count), each copying either the
//                 corresponding src1 element (inside the view) or src0.
// Element type is selected at preprocess time (f32 default, i32 via TYPE_I32).
#ifdef TYPE_I32
#define TYPE i32
#else
#define TYPE f32
#endif
// Binding indices shift down by one in the INPLACE variant since src0 is absent.
#ifndef INPLACE
@group(0) @binding(0)
var<storage, read_write> src0: array<TYPE>;
#define SRC1_BINDING 1
#else
#define SRC1_BINDING 0
#endif
#define DST_BINDING SRC1_BINDING + 1
#define PARAMS_BINDING SRC1_BINDING + 2
@group(0) @binding(SRC1_BINDING)
var<storage, read_write> src1: array<TYPE>;
@group(0) @binding(DST_BINDING)
var<storage, read_write> dst: array<TYPE>;
struct Params {
    ne: u32,            // total threads: src1 elements (INPLACE) or dst elements
    offset_src0: u32,   // element offsets into the buffers
    offset_src1: u32,
    offset_view: u32,   // element offset of the view region within dst
    stride_src10: u32,  // src1 strides (elements)
    stride_src11: u32,
    stride_src12: u32,
    stride_src13: u32,
    stride_dst10: u32,  // view strides within dst (elements)
    stride_dst11: u32,
    stride_dst12: u32,
    stride_dst13: u32,
    src1_ne0: u32,      // src1 / view logical dimensions
    src1_ne1: u32,
    src1_ne2: u32,
    src1_ne3: u32,
};
@group(0) @binding(PARAMS_BINDING)
var<uniform> params: Params;
// Linear src1 index -> (i0, i1, i2, i3) coordinates.
fn decode_src1_coords(idx: u32) -> vec4<u32> {
    var i = idx;
    let plane = params.src1_ne2 * params.src1_ne1 * params.src1_ne0;
    let i3 = i / plane;
    i = i % plane;
    let row = params.src1_ne1 * params.src1_ne0;
    let i2 = i / row;
    i = i % row;
    let i1 = i / params.src1_ne0;
    let i0 = i % params.src1_ne0;
    return vec4<u32>(i0, i1, i2, i3);
}
// Offset relative to the view start -> coordinates, by dividing out the view
// strides. NOTE(review): assumes the strides nest (stride13 >= stride12 >= ...)
// so integer division recovers coordinates — set up by the host.
fn decode_view_coords(rel: u32) -> vec4<u32> {
    let i3 = rel / params.stride_dst13;
    let rem3 = rel % params.stride_dst13;
    let i2 = rem3 / params.stride_dst12;
    let rem2 = rem3 % params.stride_dst12;
    let i1 = rem2 / params.stride_dst11;
    let i0 = rem2 % params.stride_dst11;
    return vec4<u32>(i0, i1, i2, i3);
}
fn view_rel_from_coords(coords: vec4<u32>) -> u32 {
    return coords.x * params.stride_dst10 + coords.y * params.stride_dst11 +
           coords.z * params.stride_dst12 + coords.w * params.stride_dst13;
}
fn src1_idx_from_coords(coords: vec4<u32>) -> u32 {
    return coords.x * params.stride_src10 + coords.y * params.stride_src11 +
           coords.z * params.stride_src12 + coords.w * params.stride_src13;
}
// Round-trip check: a relative offset belongs to the view only if re-encoding
// its decoded coordinates reproduces it (rejects gaps between strided rows).
fn in_set_view(rel: u32, coords: vec4<u32>) -> bool {
    return view_rel_from_coords(coords) == rel;
}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x >= params.ne) {
        return;
    }
#ifdef INPLACE
    // Scatter: thread gid.x moves one src1 element into its slot in the view.
    let coords = decode_src1_coords(gid.x);
    let src1_idx = params.offset_src1 + src1_idx_from_coords(coords);
    let dst_idx = params.offset_view + view_rel_from_coords(coords);
    dst[dst_idx] = src1[src1_idx];
#else
    // Gather: thread gid.x fills one dst element. Positions before the view
    // start get the out-of-range sentinel params.ne so they fall through to src0.
    let rel = select(params.ne, gid.x - params.offset_view, gid.x >= params.offset_view);
    let coords = decode_view_coords(rel);
    if (rel < params.stride_dst13 * params.src1_ne3 && in_set_view(rel, coords)) {
        dst[gid.x] = src1[params.offset_src1 + src1_idx_from_coords(coords)];
    } else {
        dst[gid.x] = src0[params.offset_src0 + gid.x];
    }
#endif
}

View File

@ -0,0 +1,121 @@
// GGML_OP_SOLVE_TRI: solve A * X = B by forward substitution, batched over
// (ne2, ne3). src0 holds A (N x N, only elements A[r][t] with t <= r are
// read, i.e. lower-triangular including the diagonal), src1 holds B (N x k).
// Each invocation solves one column of X; rows are processed in tiles of
// BATCH_N staged through workgroup memory.
// NOTE(review): the shB load addresses columns as workgroup_id.x * WG_SIZE +
// tile_col with tile_col < K_TILE, which only covers all columns if
// K_TILE == WG_SIZE — confirm the host pipeline constants.
@group(0) @binding(0)
var<storage, read_write> src0: array<f32>;
@group(0) @binding(1)
var<storage, read_write> src1: array<f32>;
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
struct Params {
    offset_src0: u32,   // element offsets into the buffers
    offset_src1: u32,
    offset_dst: u32,
    stride_src00: u32,  // A strides (elements)
    stride_src01: u32,
    stride_src02: u32,
    stride_src03: u32,
    stride_src10: u32,  // B strides (elements)
    stride_src11: u32,
    stride_src12: u32,
    stride_src13: u32,
    stride_dst0: u32,   // X strides (elements)
    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,
    k: u32,             // number of right-hand-side columns
    ne2: u32,           // batch dims
    ne3: u32,
};
@group(0) @binding(3)
var<uniform> params: Params;
// Workgroup tiles: BATCH_N rows of A (full width N) and of B (K_TILE columns).
var<workgroup> shA: array<f32, BATCH_N * N>;
var<workgroup> shB: array<f32, BATCH_N * K_TILE>;
fn src0_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
    return params.offset_src0 +
           col * params.stride_src00 +
           row * params.stride_src01 +
           i2 * params.stride_src02 +
           i3 * params.stride_src03;
}
fn src1_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
    return params.offset_src1 +
           col * params.stride_src10 +
           row * params.stride_src11 +
           i2 * params.stride_src12 +
           i3 * params.stride_src13;
}
fn dst_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
    return params.offset_dst +
           col * params.stride_dst0 +
           row * params.stride_dst1 +
           i2 * params.stride_dst2 +
           i3 * params.stride_dst3;
}
@compute @workgroup_size(WG_SIZE)
fn main(
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>
) {
    let batch = workgroup_id.y;
    let col = workgroup_id.x * WG_SIZE + local_id.x;
    let i3 = batch / params.ne2;
    let i2 = batch % params.ne2;
    // Only the first K_TILE lanes solve; all lanes help with tile loads.
    let active_lane = local_id.x < K_TILE;
    let active_col = active_lane && col < params.k;
    // Full solution column, accumulated across row tiles (lives in registers).
    var X: array<f32, N>;
    for (var row_base = 0u; row_base < N; row_base += BATCH_N) {
        let cur_n = min(BATCH_N, N - row_base);
        // Cooperative load of the next cur_n rows of A (all N columns).
        for (var i = local_id.x; i < cur_n * N; i += WG_SIZE) {
            let tile_row = i / N;
            let tile_col = i % N;
            shA[i] = src0[src0_idx(row_base + tile_row, tile_col, i2, i3)];
        }
        // Cooperative load of the matching B rows; pad out-of-range columns with 0.
        for (var i = local_id.x; i < cur_n * K_TILE; i += WG_SIZE) {
            let tile_row = i / K_TILE;
            let tile_col = i % K_TILE;
            let global_col = workgroup_id.x * WG_SIZE + tile_col;
            let sh_idx = tile_row * K_TILE + tile_col;
            if (global_col < params.k) {
                shB[sh_idx] = src1[src1_idx(row_base + tile_row, global_col, i2, i3)];
            } else {
                shB[sh_idx] = 0.0;
            }
        }
        workgroupBarrier();
        if (active_col) {
            // Forward substitution within the tile:
            //   x_r = (b_r - sum_{t<r} A[r][t] * x_t) / A[r][r]
            for (var row_offset = 0u; row_offset < cur_n; row_offset++) {
                let r = row_base + row_offset;
                var b = shB[row_offset * K_TILE + local_id.x];
                let a_row = row_offset * N;
                for (var t = 0u; t < r; t++) {
                    b -= shA[a_row + t] * X[t];
                }
                let x = b / shA[a_row + r];
                X[r] = x;
                dst[dst_idx(r, col, i2, i3)] = x;
            }
        }
        // Don't reload the tiles until every lane is done reading them.
        workgroupBarrier();
    }
}

View File

@ -0,0 +1,65 @@
// GGML_OP_SSM_CONV: per-row 1D convolution (dot product of a length-nc window
// of src0 with the matching row of src1), over (rows, tokens, sequences).
// The VECTORIZED variant is a fully unrolled 4-term dot product — presumably
// selected by the host only when nc == 4; verify against the pipeline setup.
@group(0) @binding(0)
var<storage, read_write> src0: array<f32>;
@group(0) @binding(1)
var<storage, read_write> src1: array<f32>;
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
struct Params {
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,
    stride_src01: u32,
    stride_src02: u32,
    stride_src11: u32,
    stride_dst0: u32,
    stride_dst1: u32,
    stride_dst2: u32,
    nc: u32,
    nr: u32,
    n_t: u32,
    n_s: u32,
    token_tiles: u32,
};
@group(0) @binding(3)
var<uniform> params: Params;
@compute @workgroup_size(BLOCK_SIZE, TOKENS_PER_WG)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    // Recover (sequence, token) from the flattened y dimension: each sequence
    // owns params.token_tiles workgroups of TOKENS_PER_WG tokens.
    let wg_y = gid.y / TOKENS_PER_WG;
    let tok_in_wg = gid.y % TOKENS_PER_WG;
    let seq = wg_y / params.token_tiles;
    let tok = (wg_y % params.token_tiles) * TOKENS_PER_WG + tok_in_wg;
    let row = gid.x;
    // Padding lanes from rounded-up dispatch sizes do nothing.
    if (row >= params.nr || tok >= params.n_t || seq >= params.n_s) {
        return;
    }
    // The token axis of src0 is contiguous (stride 1), so tok is added directly.
    let x_base = params.offset_src0 + seq * params.stride_src02 + row * params.stride_src01 + tok;
    let w_base = params.offset_src1 + row * params.stride_src11;
    var acc = 0.0;
#ifdef VECTORIZED
    acc = src0[x_base] * src1[w_base]
        + src0[x_base + 1u] * src1[w_base + 1u]
        + src0[x_base + 2u] * src1[w_base + 2u]
        + src0[x_base + 3u] * src1[w_base + 3u];
#else
    for (var j = 0u; j < params.nc; j++) {
        acc += src0[x_base + j] * src1[w_base + j];
    }
#endif
    dst[params.offset_dst + seq * params.stride_dst2 + tok * params.stride_dst1 + row * params.stride_dst0] = acc;
}

View File

@ -5,7 +5,6 @@ enable f16;
#define TYPE f32 #define TYPE f32
#endif #endif
@group(0) @binding(0) @group(0) @binding(0)
var<storage, read_write> src: array<TYPE>; var<storage, read_write> src: array<TYPE>;
@ -57,12 +56,20 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
return; return;
} }
var i = gid.x; var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0); let ne2 = params.ne2;
i = i % (params.ne2 * params.ne1 * params.ne0); #ifdef DIAG
let i2 = i / (params.ne1 * params.ne0); let ne1 = params.ne0;
i = i % (params.ne1 * params.ne0); #else
let i1 = i / params.ne0; let ne1 = params.ne1;
let i0 = i % params.ne0; #endif
let ne0 = params.ne0;
let i3 = i / (ne2 * ne1 * ne0);
i = i % (ne2 * ne1 * ne0);
let i2 = i / (ne1 * ne0);
i = i % (ne1 * ne0);
let i1 = i / ne0;
let i0 = i % ne0;
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
i2 * params.stride_src2 + i3 * params.stride_src3; i2 * params.stride_src2 + i3 * params.stride_src3;
@ -184,6 +191,20 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let res_f32 = cos(f32(src[params.offset_src + src_idx])); let res_f32 = cos(f32(src[params.offset_src + src_idx]));
let res = TYPE(res_f32); let res = TYPE(res_f32);
#endif #endif
#ifdef DIAG
let res = select(0.0, src[params.offset_src + i0 + i2 * params.stride_src2 + i3 * params.stride_src3], i0 == i1);
#endif
#ifdef TRI
#ifdef TRI_TYPE_LOWER
let res = select(0.0, src[params.offset_src + src_idx], i0 < i1);
#elif TRI_TYPE_LOWER_DIAG
let res = select(0.0, src[params.offset_src + src_idx], i0 <= i1);
#elif TRI_TYPE_UPPER
let res = select(0.0, src[params.offset_src + src_idx], i0 > i1);
#elif TRI_TYPE_UPPER_DIAG
let res = select(0.0, src[params.offset_src + src_idx], i0 >= i1);
#endif
#endif
#ifdef INPLACE #ifdef INPLACE
src[params.offset_src + src_idx] = res; src[params.offset_src + src_idx] = res;

View File

@ -425,8 +425,7 @@ class GGUFWriter:
fout = self.fout[file_id] fout = self.fout[file_id]
# pop the first tensor info # pop the first tensor info
# TODO: cleaner way to get the first key first_tensor_name = next(iter(self.tensors[file_id]))
first_tensor_name = [name for name, _ in zip(self.tensors[file_id].keys(), range(1))][0]
ti = self.tensors[file_id].pop(first_tensor_name) ti = self.tensors[file_id].pop(first_tensor_name)
assert ti.nbytes == tensor.nbytes assert ti.nbytes == tensor.nbytes

View File

@ -7,7 +7,6 @@
{%- set available_tool_string = '' -%} {%- set available_tool_string = '' -%}
{%- set add_tool_id = true -%} {%- set add_tool_id = true -%}
{%- set add_thoughts = true -%} {# whether to include <thinking> reasoning blocks #} {%- set add_thoughts = true -%} {# whether to include <thinking> reasoning blocks #}
{%- set add_generation_prompt = true -%} {# whether to emit reasoning starter before assistant response #}
{# Optional token placeholders (safe defaults) #} {# Optional token placeholders (safe defaults) #}
{%- set bos_token = bos_token or '' -%} {%- set bos_token = bos_token or '' -%}
{%- set eos_token = eos_token or '' -%} {%- set eos_token = eos_token or '' -%}

View File

@ -15,10 +15,10 @@
{%- set ns.is_tool = false -%} {%- set ns.is_tool = false -%}
{%- for tool in message['tool_calls']-%} {%- for tool in message['tool_calls']-%}
{%- if not ns.is_first -%} {%- if not ns.is_first -%}
{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<tool▁call▁end>'}} {{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] | tojson + '\n' + '```' + '<tool▁call▁end>'}}
{%- set ns.is_first = true -%} {%- set ns.is_first = true -%}
{%- else -%} {%- else -%}
{{'\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<tool▁call▁end>'}} {{'\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] | tojson + '\n' + '```' + '<tool▁call▁end>'}}
{%- endif -%} {%- endif -%}
{%- endfor -%} {%- endfor -%}
{{'<tool▁calls▁end><end▁of▁sentence>'}} {{'<tool▁calls▁end><end▁of▁sentence>'}}

View File

@ -28,25 +28,25 @@
{%- set ns.is_last_user = true -%}{{'<User>' + message['content']}} {%- set ns.is_last_user = true -%}{{'<User>' + message['content']}}
{%- endif -%} {%- endif -%}
{%- if message['role'] == 'assistant' and message['tool_calls'] -%} {%- if message['role'] == 'assistant' and message['tool_calls'] -%}
{%- if ns.is_last_user -%}{{'<Assistant></think>'}} {%- if ns.is_last_user -%}{{'<Assistant><think></think>'}}
{%- endif -%} {%- endif -%}
{%- set ns.is_last_user = false -%} {%- set ns.is_last_user = false -%}
{%- set ns.is_first = false -%} {%- set ns.is_first = false -%}
{%- set ns.is_tool = false -%} {%- set ns.is_tool = false -%}
{%- for tool in message['tool_calls'] -%} {%- for tool in message['tool_calls'] -%}
{%- if not ns.is_first -%} {%- if not ns.is_first -%}
{%- if not message['content'] -%}{{'<tool▁calls▁begin><tool▁call▁begin>'+ tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] + '<tool▁call▁end>'}} {%- if not message['content'] -%}{{'<tool▁calls▁begin><tool▁call▁begin>'+ tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] | tojson + '<tool▁call▁end>'}}
{%- else -%}{{message['content'] + '<tool▁calls▁begin><tool▁call▁begin>' + tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] + '<tool▁call▁end>'}} {%- else -%}{{message['content'] + '<tool▁calls▁begin><tool▁call▁begin>' + tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] | tojson + '<tool▁call▁end>'}}
{%- endif -%} {%- endif -%}
{%- set ns.is_first = true -%} {%- set ns.is_first = true -%}
{%- else -%}{{'<tool▁call▁begin>'+ tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] + '<tool▁call▁end>'}} {%- else -%}{{'<tool▁call▁begin>'+ tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] | tojson + '<tool▁call▁end>'}}
{%- endif -%} {%- endif -%}
{%- endfor -%}{{'<tool▁calls▁end><end▁of▁sentence>'}} {%- endfor -%}{{'<tool▁calls▁end><end▁of▁sentence>'}}
{%- endif -%} {%- endif -%}
{%- if message['role'] == 'assistant' and not message['tool_calls'] -%} {%- if message['role'] == 'assistant' and not message['tool_calls'] -%}
{%- if ns.is_last_user -%}{{'<Assistant>'}} {%- if ns.is_last_user -%}{{'<Assistant>'}}
{%- if message['prefix'] is defined and message['prefix'] and thinking -%}{{'<think>'}} {%- if message['prefix'] is defined and message['prefix'] and thinking -%}{{'<think>'}}
{%- else -%}{{'</think>'}} {%- else -%}{{'<think></think>'}}
{%- endif -%} {%- endif -%}
{%- endif -%} {%- endif -%}
{%- set ns.is_last_user = false -%} {%- set ns.is_last_user = false -%}
@ -65,7 +65,7 @@
{%- endif -%} {%- endif -%}
{%- endfor -%} {%- endfor -%}
{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool -%}{{'<Assistant>'}} {%- if add_generation_prompt and ns.is_last_user and not ns.is_tool -%}{{'<Assistant>'}}
{%- if not thinking -%}{{'</think>'}} {%- if not thinking -%}{{'<think></think>'}}
{%- else -%}{{'<think>'}} {%- else -%}{{'<think>'}}
{%- endif -%} {%- endif -%}
{%- endif %} {%- endif %}

View File

@ -49,7 +49,7 @@ Example function tool call syntax:
{%- endif -%} {%- endif -%}
{%- set tool_name = tc['function']['name'] -%} {%- set tool_name = tc['function']['name'] -%}
{%- set tool_args = tc['function']['arguments'] -%} {%- set tool_args = tc['function']['arguments'] -%}
{{- '<tool▁call▁begin>' + tc['type'] + '<tool▁sep>' + tool_name + '\n' + '```json' + '\n' + tool_args + '\n' + '```' + '<tool▁call▁end>' -}} {{- '<tool▁call▁begin>' + tc['type'] + '<tool▁sep>' + tool_name + '\n' + '```json' + '\n' + tool_args | tojson + '\n' + '```' + '<tool▁call▁end>' -}}
{%- endfor -%} {%- endfor -%}
{{- '<tool▁calls▁end><end▁of▁sentence>' -}} {{- '<tool▁calls▁end><end▁of▁sentence>' -}}
{%- endif -%} {%- endif -%}

View File

@ -42,9 +42,9 @@
{%- if 'tool_calls' in message and message['tool_calls'] -%} {%- if 'tool_calls' in message and message['tool_calls'] -%}
{%- for tool_call in message['tool_calls'] -%} {%- for tool_call in message['tool_calls'] -%}
{%- if tool_call["function"]["name"] == "python" -%} {%- if tool_call["function"]["name"] == "python" -%}
{{ '<|python_tag|>' + tool_call['function']['arguments'] }} {{ '<|python_tag|>' + tool_call['function']['arguments'] | tojson }}
{%- else -%} {%- else -%}
{{ '<function=' + tool_call['function']['name'] + '>' + tool_call['function']['arguments'] + '</function>' }} {{ '<function=' + tool_call['function']['name'] + '>' + tool_call['function']['arguments'] | tojson + '</function>' }}
{%- endif -%} {%- endif -%}
{%- endfor -%} {%- endfor -%}
{{ '<|eom_id|>' }} {{ '<|eom_id|>' }}

View File

@ -95,9 +95,9 @@ if __name__ == '__main__':
'-p', 'Hey', '-p', 'Hey',
'--no-warmup', '--no-warmup',
'--log-disable', '--log-disable',
'-no-cnv'] '-st']
if m.hf_file != 'tinyllamas/stories260K.gguf' and 'Mistral-Nemo' not in m.hf_repo: if m.hf_file != 'tinyllamas/stories260K.gguf' and 'Mistral-Nemo' not in m.hf_repo:
cmd.append('-fa') cmd += ('-fa', 'on')
try: try:
subprocess.check_call(cmd) subprocess.check_call(cmd)
except subprocess.CalledProcessError: except subprocess.CalledProcessError:

View File

@ -0,0 +1,157 @@
#!/usr/bin/env python3
import sys
from collections import defaultdict
def parse_log_file(filepath):
    """Parse a compiler log file and extract per-function VGPR usage.

    Args:
        filepath: Path to the log file to parse.

    Returns:
        defaultdict mapping mangled function name -> {'vgprs': int,
        'spill': int, 'location': str}; unknown names default to zeroed
        entries. The location is shortened to "<filename>:<line>" when the
        last path component carries the line number.

    Exits the process with status 1 if the file does not exist.
    """
    import re

    functions = defaultdict(lambda: {'vgprs': 0, 'spill': 0, 'location': ''})

    try:
        with open(filepath, 'r') as f:
            content = f.read()
    except FileNotFoundError:
        print(f"Error: File {filepath} not found", file=sys.stderr)  # noqa: NP100
        sys.exit(1)

    # Each entry looks like "<path>:<line>: ... Function Name: <name> ...
    # VGPRs: <n> ... VGPRs Spill: <n>"; DOTALL lets '.*?' span newlines
    # between the fields of one entry.
    pattern = r'([^:]+:\d+):.*?Function Name: (\S+).*?VGPRs: (\d+).*?VGPRs Spill: (\d+)'
    for location, func_name, vgprs, spill in re.findall(pattern, content, re.DOTALL):
        functions[func_name]['vgprs'] = int(vgprs)
        functions[func_name]['spill'] = int(spill)
        # str.split always yields at least one element, so the original
        # "if len(parts) > 0" guard was dead code: just take the last path
        # component, and fall back to the full location if it has no ':'.
        short_location = location.split('/')[-1]
        functions[func_name]['location'] = short_location if ':' in short_location else location

    return functions
def main():
    """Check a compiler log for kernels whose total VGPR usage exceeds 256.

    Usage: ./vgpr_check.py <log_file>

    Known offenders listed in ``ignored`` are reported with an [IGNORED] tag
    but tolerated; any new function over the limit is printed in red and
    makes the script exit with status 1.
    """
    if len(sys.argv) < 2:
        print("Usage: ./vgpr_check.py <log_file>", file=sys.stderr)  # noqa: NP100
        sys.exit(1)
    log_file = sys.argv[1]
    # Mangled symbol names of kernels known to exceed the 256-VGPR budget.
    # They are reported for visibility but do not fail the check.
    ignored = {
        '_ZL21gated_linear_attn_f32ILi128EEviiiifPKfS1_S1_S1_S1_S1_Pf',
        '_ZL18flash_attn_ext_f16ILi64ELi64ELi16ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi80ELi80ELi16ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi96ELi96ELi16ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi64ELi64ELi32ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL13rwkv_wkv7_f32ILi128EEviiiiPKfS1_S1_S1_S1_S1_S1_Pf',
        '_ZL18flash_attn_ext_f16ILi80ELi80ELi16ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi112ELi112ELi16ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi80ELi80ELi32ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi96ELi96ELi16ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi2ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi96ELi96ELi32ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi112ELi112ELi16ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi112ELi112ELi32ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi1ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi80ELi80ELi2ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi96ELi96ELi2ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi112ELi112ELi2ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi2ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi2ELi8ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi112ELi112ELi16ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi16ELi4ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi32ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi96ELi96ELi4ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi112ELi112ELi4ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi4ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi4ELi4ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi80ELi80ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi96ELi96ELi64ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi112ELi112ELi64ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi64ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi64ELi1ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi64ELi64ELi8ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi80ELi80ELi8ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi96ELi96ELi8ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi112ELi112ELi8ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi80ELi80ELi8ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi4ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi96ELi96ELi8ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi112ELi112ELi8ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi2ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi112ELi112ELi8ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi8ELi8ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL24mul_mat_q_stream_k_fixupIL9ggml_type22ELi8ELb1EEvPKiS2_PfPKfiiimimimi',
        '_ZL9mul_mat_qIL9ggml_type3ELi32ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
        '_ZL9mul_mat_qIL9ggml_type3ELi48ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
        '_ZL9mul_mat_qIL9ggml_type20ELi32ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
        '_ZL9mul_mat_qIL9ggml_type17ELi64ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
        '_ZL18flash_attn_ext_f16ILi80ELi80ELi4ELi4ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL15flash_attn_tileILi256ELi256ELi32ELi1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL9mul_mat_qIL9ggml_type19ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
        '_ZL9mul_mat_qIL9ggml_type17ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
        '_ZL9mul_mat_qIL9ggml_type22ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
        '_ZL9mul_mat_qIL9ggml_type19ELi128ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
        '_ZL9mul_mat_qIL9ggml_type19ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
        '_ZL9mul_mat_qIL9ggml_type7ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
        '_ZL9mul_mat_qIL9ggml_type3ELi128ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
        '_ZL9mul_mat_qIL9ggml_type3ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
        '_ZL9mul_mat_qIL9ggml_type7ELi128ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
        '_ZL9mul_mat_qIL9ggml_type7ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
        '_ZL9mul_mat_qIL9ggml_type11ELi112ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
        '_ZL9mul_mat_qIL9ggml_type11ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
        '_ZL24mul_mat_q_stream_k_fixupIL9ggml_type11ELi128ELb0EEvPKiS2_PfPKfiiimimimi',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi32ELi1ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL9mul_mat_qIL9ggml_type2ELi112ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
        '_ZL18flash_attn_ext_f16ILi112ELi112ELi32ELi2ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi112ELi112ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi32ELi1ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi32ELi2ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi128ELi128ELi4ELi8ELb1ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
        '_ZL18flash_attn_ext_f16ILi96ELi96ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
    }
    functions = parse_log_file(log_file)
    found_issues = False
    # First report known offenders (deduplicated, tagged [IGNORED]).
    printed_ignored = set()
    for func_name, data in sorted(functions.items()):
        total_vgprs = int(data['vgprs']) + int(data['spill'])
        if total_vgprs > 256 and func_name in ignored and func_name not in printed_ignored:
            location = data.get('location', log_file)
            print(f"{location}: {func_name} - Total VGPRs: {total_vgprs} ({data['vgprs']} + {data['spill']}) [IGNORED]")  # noqa: NP100
            printed_ignored.add(func_name)
    # Then report new offenders in red; any of these fails the check.
    # (The branch already guarantees func_name is not ignored, so the color
    # codes and the failure flag are unconditional here.)
    for func_name, data in sorted(functions.items()):
        total_vgprs = int(data['vgprs']) + int(data['spill'])
        if total_vgprs > 256 and func_name not in ignored:
            location = data.get('location', log_file)
            print(f"\033[91m{location}: {func_name} - Total VGPRs: {total_vgprs} ({data['vgprs']} + {data['spill']}) \033[0m")  # noqa: NP100
            found_issues = True
    sys.exit(1 if found_issues else 0)
# Entry point when the script is executed directly as a CLI tool.
if __name__ == "__main__":
    main()

View File

@ -39,6 +39,9 @@ opmask=
nhvx= nhvx=
[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX" [ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX"
hmx=
[ "$HMX" != "" ] && hmx="GGML_HEXAGON_USE_HMX=$HMX"
ndev= ndev=
[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV" [ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV"
@ -51,7 +54,7 @@ adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \ cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \ LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \ ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $experimental $sched $opmask $profile $nhvx $ndev $hb \ $verbose $experimental $sched $opmask $profile $nhvx $hmx $ndev $hb \
./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \ ./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--ctx-size 8192 --ubatch-size 256 -fa on \ --ctx-size 8192 --ubatch-size 256 -fa on \

View File

@ -39,6 +39,9 @@ opmask=
nhvx= nhvx=
[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX" [ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX"
hmx=
[ "$HMX" != "" ] && hmx="GGML_HEXAGON_USE_HMX=$HMX"
ndev= ndev=
[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV" [ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV"
@ -51,7 +54,7 @@ adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \ cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \ LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \ ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $experimental $sched $opmask $profile $nhvx $ndev $hb \ $verbose $experimental $sched $opmask $profile $nhvx $hmx $ndev $hb \
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \ ./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--ctx-size 8192 --ubatch-size 256 -fa on \ --ctx-size 8192 --ubatch-size 256 -fa on \

View File

@ -45,6 +45,9 @@ opmask=
nhvx= nhvx=
[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX" [ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX"
hmx=
[ "$HMX" != "" ] && hmx="GGML_HEXAGON_USE_HMX=$HMX"
ndev= ndev=
[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV" [ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV"
@ -58,7 +61,7 @@ adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \ cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \ LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \ ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $experimental $sched $opmask $profile $nhvx $ndev $mtmd_backend \ $verbose $experimental $sched $opmask $profile $hmx $nhvx $ndev $mtmd_backend \
./$branch/bin/llama-mtmd-cli --no-mmap -m $basedir/../gguf/$model \ ./$branch/bin/llama-mtmd-cli --no-mmap -m $basedir/../gguf/$model \
--mmproj $basedir/../gguf/$mmproj \ --mmproj $basedir/../gguf/$mmproj \
--image $basedir/../gguf/$image \ --image $basedir/../gguf/$image \

View File

@ -36,6 +36,9 @@ opmask=
nhvx= nhvx=
[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX" [ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX"
hmx=
[ "$HMX" != "" ] && hmx="GGML_HEXAGON_USE_HMX=$HMX"
ndev= ndev=
[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV" [ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV"
@ -50,5 +53,5 @@ adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \ cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \ LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \ ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $experimental $sched $opmask $profile $nhvx $ndev $hb ./$branch/bin/$tool $@ \ $verbose $experimental $sched $opmask $profile $nhvx $hmx $ndev $hb ./$branch/bin/$tool $@ \
" "

View File

@ -2129,19 +2129,28 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
throw std::runtime_error("cannot find tokenizer vocab in model file\n"); throw std::runtime_error("cannot find tokenizer vocab in model file\n");
} }
const uint32_t n_tokens = gguf_get_arr_n(ctx, token_idx);
const float * scores = nullptr; const float * scores = nullptr;
const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
if (score_idx != -1) { if (score_idx != -1) {
const uint32_t n_scores = gguf_get_arr_n(ctx, score_idx);
if (n_scores < n_tokens) {
throw std::runtime_error("Index out of array bounds for scores (" + std::to_string(n_scores) + " < " + std::to_string(n_tokens) + ")\n");
}
scores = (const float * ) gguf_get_arr_data(ctx, score_idx); scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
} }
const int * toktypes = nullptr; const int * toktypes = nullptr;
const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
if (toktype_idx != -1) { if (toktype_idx != -1) {
const uint32_t n_toktypes = gguf_get_arr_n(ctx, toktype_idx);
if (n_toktypes < n_tokens) {
throw std::runtime_error("Index out of array bounds for toktypes (" + std::to_string(n_toktypes) + " < " + std::to_string(n_tokens) + ")\n");
}
toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
} }
uint32_t n_tokens = gguf_get_arr_n(ctx, token_idx);
id_to_token.resize(n_tokens); id_to_token.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) { for (uint32_t i = 0; i < n_tokens; i++) {

View File

@ -121,6 +121,9 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa
cur = ggml_add(ctx0, cur, ffn_inp); cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "l_out", il); cb(cur, "l_out", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer // input for next layer
inpL = cur; inpL = cur;
} }

View File

@ -111,8 +111,13 @@ llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_
} }
inpL = ggml_add(ctx0, cur, ffn_inp); cur = ggml_add(ctx0, cur, ffn_inp);
cb(inpL, "l_out", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
} }
cur = build_norm(inpL, cur = build_norm(inpL,

View File

@ -86,6 +86,10 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa
cur = ggml_add(ctx0, cur, ffn_inp); cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur; inpL = cur;
} }

View File

@ -82,6 +82,7 @@ llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_grap
cur = ggml_add(ctx0, cur, ffn_inp); cur = ggml_add(ctx0, cur, ffn_inp);
// input for next layer
inpL = cur; inpL = cur;
} }
cur = inpL; cur = inpL;

View File

@ -66,8 +66,14 @@ llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params
LLM_FFN_SILU, LLM_FFN_PAR, il); LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
inpL = ggml_add(ctx0, cur, ffn_inp);
cb(inpL, "l_out", il); cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
} }
cur = build_norm(inpL, cur = build_norm(inpL,
model.output_norm, model.output_norm,

View File

@ -362,6 +362,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
cur = build_cvec(cur, il); cur = build_cvec(cur, il);
cb(cur, "l_out", il); cb(cur, "l_out", il);
// input for next layer
inpL = cur; inpL = cur;
} }
cur = inpL; cur = inpL;

View File

@ -177,6 +177,9 @@ llm_build_lfm2<iswa>::llm_build_lfm2(const llama_model & model, const llm_graph_
cb(ffn_norm_out, "model.layers.{}.ffn_out", il); cb(ffn_norm_out, "model.layers.{}.ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_out); cur = ggml_add(ctx0, cur, ffn_out);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
} }
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);

View File

@ -71,6 +71,7 @@ llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_pa
cur = ggml_add(ctx0, cur, residual); cur = ggml_add(ctx0, cur, residual);
cb(cur, "ffn_residual", il); cb(cur, "ffn_residual", il);
// input for next layer
inpL = cur; inpL = cur;
} }

View File

@ -109,6 +109,8 @@ llm_build_plamo3<iswa>::llm_build_plamo3(const llama_model & model, const llm_gr
cur = build_cvec(cur, il); cur = build_cvec(cur, il);
cb(cur, "l_out", il); cb(cur, "l_out", il);
// input for next layer
inpL = cur; inpL = cur;
} }

View File

@ -64,6 +64,9 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa
cur = ggml_add(ctx0, cur, ffn_residual); cur = ggml_add(ctx0, cur, ffn_residual);
cb(cur, "post_ffn", il); cb(cur, "post_ffn", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// Input for next layer // Input for next layer
inpL = cur; inpL = cur;
} }

View File

@ -64,6 +64,9 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr
cur = ggml_add(ctx0, cur, ffn_residual); cur = ggml_add(ctx0, cur, ffn_residual);
cb(cur, "post_moe", il); cb(cur, "post_moe", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// Input for next layer // Input for next layer
inpL = cur; inpL = cur;
} }

View File

@ -56,6 +56,9 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
cur = ggml_add(ctx0, cur, ffn_residual); cur = ggml_add(ctx0, cur, ffn_residual);
cb(cur, "post_moe", il); cb(cur, "post_moe", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// Input for next layer // Input for next layer
inpL = cur; inpL = cur;
} }

View File

@ -101,6 +101,7 @@ llm_build_smallthinker<iswa>::llm_build_smallthinker(const llama_model & model,
cur = ffn_out; cur = ffn_out;
cur = ggml_add(ctx0, cur, ffn_inp); cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il); cur = build_cvec(cur, il);
cb(cur, "l_out", il); cb(cur, "l_out", il);

View File

@ -145,9 +145,11 @@ llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const ll
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
cur = ggml_add(ctx0, cur, ffn_inp); cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il); cur = build_cvec(cur, il);
cb(cur, "l_out", il); cb(cur, "l_out", il);
// input for next layer
inpL = cur; inpL = cur;
} }

View File

@ -1292,11 +1292,11 @@ static void test_nemotron_reasoning_detection(testing & t) {
// Check reasoning markers // Check reasoning markers
t.assert_equal("reasoning_start should be '<think>'", "<think>", analysis.reasoning.start); t.assert_equal("reasoning_start should be '<think>'", "<think>", analysis.reasoning.start);
t.assert_equal("reasoning_end should be '</think>\\n'", "</think>\n", analysis.reasoning.end); t.assert_equal("reasoning_end should be '</think>'", "</think>", analysis.reasoning.end);
// Check reasoning mode detection // Check reasoning mode detection
// Nemotron uses forced closed reasoning with add_generation_prompt // Nemotron uses tag-based reasoning; prefill handles the template's forced markers
t.assert_equal("reasoning should be FORCED_CLOSED", reasoning_mode::FORCED_CLOSED, analysis.reasoning.mode); t.assert_equal("reasoning should be TAG_BASED", reasoning_mode::TAG_BASED, analysis.reasoning.mode);
// Make sure reasoning markers don't spill over to content markers // Make sure reasoning markers don't spill over to content markers
t.assert_equal("content start should be empty", "", analysis.content.start); t.assert_equal("content start should be empty", "", analysis.content.start);

View File

@ -145,7 +145,7 @@ static void test_example_native(testing & t) {
common_reasoning_format reasoning_format; common_reasoning_format reasoning_format;
json json_schema; json json_schema;
bool parallel_tool_calls; bool parallel_tool_calls;
bool thinking_forced_open; std::string generation_prompt;
std::string input; std::string input;
// Expect // Expect
@ -157,14 +157,8 @@ static void test_example_native(testing & t) {
auto build_parser = [](const test_case & tc) { auto build_parser = [](const test_case & tc) {
return build_chat_peg_parser([&](common_chat_peg_builder & p) { return build_chat_peg_parser([&](common_chat_peg_builder & p) {
auto reasoning_in_content = (tc.reasoning_format == COMMON_REASONING_FORMAT_NONE); auto reasoning_in_content = (tc.reasoning_format == COMMON_REASONING_FORMAT_NONE);
auto reasoning = p.eps(); // Always use optional TAG_BASED pattern; generation_prompt is prepended to input
if (tc.thinking_forced_open) { auto reasoning = p.optional("<think>" + p.reasoning(p.until("</think>")) + "</think>" + p.space());
// If thinking is forced open, expect a closing tag
reasoning = p.reasoning(p.until("</think>")) + "</think>" + p.space();
} else {
// Otherwise, optionally accept thinking wrapped in tags
reasoning = p.optional("<think>" + p.reasoning(p.until("</think>")) + "</think>" + p.space());
}
// tool calling parser // tool calling parser
if (tc.tools.is_array() && !tc.tools.empty()) { if (tc.tools.is_array() && !tc.tools.empty()) {
@ -190,78 +184,91 @@ static void test_example_native(testing & t) {
std::vector<test_case> test_cases = std::vector<test_case>{ std::vector<test_case> test_cases = std::vector<test_case>{
{ {
/* .name = */ "content with thinking_forced_open = false", /* .name = */ "content with reasoning (no generation_prompt)",
/* .tools = */ {}, /* .tools = */ {},
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
/* .json_schema = */ {}, /* .json_schema = */ {},
/* .parallel_tool_calls = */ false, /* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ false, /* .generation_prompt = */ "",
/* .input = */ ("<think>The user said hello, I must say hello back</think>\nHello"), /* .input = */ ("<think>The user said hello, I must say hello back</think>\nHello"),
/* .expect_reasoning = */ "The user said hello, I must say hello back", /* .expect_reasoning = */ "The user said hello, I must say hello back",
/* .expect_content = */ "Hello", /* .expect_content = */ "Hello",
/* .expect_tool_calls = */ {}, /* .expect_tool_calls = */ {},
}, },
{ {
/* .name = */ "content with thinking_forced_open = false and no reasoning", /* .name = */ "content without reasoning (no generation_prompt)",
/* .tools = */ {}, /* .tools = */ {},
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
/* .json_schema = */ {}, /* .json_schema = */ {},
/* .parallel_tool_calls = */ false, /* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ false, /* .generation_prompt = */ "",
/* .input = */ ("Hello"), /* .input = */ ("Hello"),
/* .expect_reasoning = */ "", /* .expect_reasoning = */ "",
/* .expect_content = */ "Hello", /* .expect_content = */ "Hello",
/* .expect_tool_calls = */ {}, /* .expect_tool_calls = */ {},
}, },
{ {
/* .name = */ "content with thinking_forced_open = false and reasoning_format = none", /* .name = */ "content with reasoning_format = none (tags appear in content)",
/* .tools = */ {}, /* .tools = */ {},
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
/* .json_schema = */ {}, /* .json_schema = */ {},
/* .parallel_tool_calls = */ false, /* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ true, /* .generation_prompt = */ "",
/* .input = */ ("<think>The user said hello, I must say hello back</think>\nHello"), /* .input = */ ("<think>The user said hello, I must say hello back</think>\nHello"),
/* .expect_reasoning = */ "", /* .expect_reasoning = */ "",
/* .expect_content = */ "<think>The user said hello, I must say hello back</think>\nHello", /* .expect_content = */ "<think>The user said hello, I must say hello back</think>\nHello",
/* .expect_tool_calls = */ {}, /* .expect_tool_calls = */ {},
}, },
{ {
/* .name = */ "content with thinking_forced_open = true", /* .name = */ "content with reasoning generation_prompt",
/* .tools = */ {}, /* .tools = */ {},
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
/* .json_schema = */ {}, /* .json_schema = */ {},
/* .parallel_tool_calls = */ false, /* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ true, /* .generation_prompt = */ "<think>",
/* .input = */ ("The user said hello, I must say hello back</think>\nHello"), /* .input = */ ("The user said hello, I must say hello back</think>\nHello"),
/* .expect_reasoning = */ "The user said hello, I must say hello back", /* .expect_reasoning = */ "The user said hello, I must say hello back",
/* .expect_content = */ "Hello", /* .expect_content = */ "Hello",
/* .expect_tool_calls = */ {}, /* .expect_tool_calls = */ {},
}, },
{ {
/* .name = */ "content with thinking_forced_open = true and reasoning_format = none", /* .name = */ "content with reasoning generation_prompt and reasoning_format = none",
/* .tools = */ {}, /* .tools = */ {},
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
/* .json_schema = */ {}, /* .json_schema = */ {},
/* .parallel_tool_calls = */ false, /* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ true, /* .generation_prompt = */ "",
/* .input = */ ("The user said hello, I must say hello back</think>\nHello"), /* .input = */ ("The user said hello, I must say hello back</think>\nHello"),
/* .expect_reasoning = */ "", /* .expect_reasoning = */ "",
/* .expect_content = */ "The user said hello, I must say hello back</think>\nHello", /* .expect_content = */ "The user said hello, I must say hello back</think>\nHello",
/* .expect_tool_calls = */ {}, /* .expect_tool_calls = */ {},
}, },
{ {
/* .name = */ "tools with tool_choice = auto and no parallel_tool_calls", /* .name = */ "content with closed reasoning generation_prompt (empty reasoning discarded)",
/* .tools = */ {},
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
/* .json_schema = */ {},
/* .parallel_tool_calls = */ false,
/* .generation_prompt = */ "<think></think>",
/* .input = */ ("Hello"),
/* .expect_reasoning = */ "",
/* .expect_content = */ "Hello",
/* .expect_tool_calls = */ {},
},
{
/* .name = */ "tools with reasoning generation_prompt",
/* .tools = */ create_tools(), /* .tools = */ create_tools(),
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO, /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
/* .json_schema = */ {}, /* .json_schema = */ {},
/* .parallel_tool_calls = */ false, /* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ true, /* .generation_prompt = */ "<think>",
/* .input = */ /* .input = */
("I must get the weather in New York</think>\n" ("I must get the weather in New York</think>\n"
"<tool_call>[" "<tool_call>["
@ -277,13 +284,13 @@ static void test_example_native(testing & t) {
} }, } },
}, },
{ {
/* .name = */ "tools with tool_choice = auto and parallel_tool_calls", /* .name = */ "parallel tools with reasoning generation_prompt",
/* .tools = */ create_tools(), /* .tools = */ create_tools(),
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO, /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
/* .json_schema = */ {}, /* .json_schema = */ {},
/* .parallel_tool_calls = */ true, /* .parallel_tool_calls = */ true,
/* .thinking_forced_open = */ true, /* .generation_prompt = */ "<think>",
/* .input = */ /* .input = */
("I must get the weather in New York and San Francisco and a 3 day forecast of each.</think>\nLet me " ("I must get the weather in New York and San Francisco and a 3 day forecast of each.</think>\nLet me "
"search that for you." "search that for you."
@ -321,7 +328,7 @@ static void test_example_native(testing & t) {
} }, } },
}, },
{ {
/* .name = */ "response_format with thinking_forced_open = true", /* .name = */ "response_format with reasoning generation_prompt",
/* .tools = */ {}, /* .tools = */ {},
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
@ -333,7 +340,7 @@ static void test_example_native(testing & t) {
{ "due_date", { { "type", "string" } } } } }, { "due_date", { { "type", "string" } } } } },
{ "required", { "invoice_number", "amount", "due_date" } } }, { "required", { "invoice_number", "amount", "due_date" } } },
/* .parallel_tool_calls = */ false, /* .parallel_tool_calls = */ false,
/* .thinking_forced_open = */ true, /* .generation_prompt = */ "<think>",
/* .input = */ /* .input = */
("I must produce the invoice in the requested format</think>\n" ("I must produce the invoice in the requested format</think>\n"
R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})"), R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})"),
@ -361,7 +368,8 @@ static void test_example_native(testing & t) {
t.log(line); t.log(line);
} }
common_peg_parse_context ctx(tc.input); std::string effective_input = tc.generation_prompt + tc.input;
common_peg_parse_context ctx(effective_input);
auto result = parser.parse(ctx); auto result = parser.parse(ctx);
t.assert_true("success", result.success()); t.assert_true("success", result.success());

View File

@ -822,8 +822,7 @@ struct make_peg_parser {
} }
common_chat_msg parse(const std::string & msg, bool is_partial) const { common_chat_msg parse(const std::string & msg, bool is_partial) const {
common_chat_parser_params parser_params; common_chat_parser_params parser_params(params_);
parser_params.format = params_.format;
parser_params.debug = detailed_debug_; parser_params.debug = detailed_debug_;
return common_chat_peg_parse(arena_, msg, is_partial, parser_params); return common_chat_peg_parse(arena_, msg, is_partial, parser_params);
} }
@ -996,6 +995,16 @@ static void test_peg_parser(common_chat_templates * tmpls,
grammar_triggered = true; grammar_triggered = true;
} }
// For non-lazy grammars, prepend reasoning prefill to grammar input, just like
// PEG parsing does. The grammar includes the full reasoning pattern (e.g. optional
// <think>...</think>), but the model output may start mid-reasoning if the template
// already placed the opening tag in the prompt.
// For lazy grammars, the grammar only activates from the trigger position, so the
// reasoning prefill is irrelevant — reasoning is handled by the PEG parser.
if (!parser.params_.generation_prompt.empty() && earliest_trigger_pos == std::string::npos) {
constrained = parser.params_.generation_prompt + constrained;
}
// Test the constrained portion against the grammar // Test the constrained portion against the grammar
if (grammar_triggered && !tc.is_partial) { if (grammar_triggered && !tc.is_partial) {
auto result = match_string_detailed(constrained, grammar.get()); auto result = match_string_detailed(constrained, grammar.get());
@ -1271,11 +1280,13 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
tst.test("[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?") tst.test("[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO) .reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
.expect(message_assist_thoughts) .expect(message_assist_thoughts)
.run(); .run();
tst.test(R"([TOOL_CALLS]special_function[ARGS]{"arg1":1})") tst.test(R"([TOOL_CALLS]special_function[ARGS]{"arg1":1})")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO) .reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
.tools({ special_function_tool }) .tools({ special_function_tool })
.expect(message_assist_call) .expect(message_assist_call)
.run(); .run();
@ -1284,6 +1295,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
"[THINK]I'm\nthinking[/THINK]" "[THINK]I'm\nthinking[/THINK]"
R"([TOOL_CALLS]special_function[ARGS]{"arg1":1})") R"([TOOL_CALLS]special_function[ARGS]{"arg1":1})")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO) .reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
.tools({ special_function_tool }) .tools({ special_function_tool })
.expect(message_assist_call_thoughts) .expect(message_assist_call_thoughts)
.run(); .run();
@ -1317,12 +1329,15 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// NVIDIA Nemotron-3 Nano // NVIDIA Nemotron-3 Nano
auto tst = peg_tester("models/templates/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.jinja", detailed_debug); auto tst = peg_tester("models/templates/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").enable_thinking(false).expect(message_assist).run(); tst.test("Hello, world!\nWhat's up?").
enable_thinking(false).
reasoning_format(COMMON_REASONING_FORMAT_AUTO).
expect(message_assist).run();
tst.test("I'm\nthinking\n</think>\nHello, world!\nWhat's up?") tst.test("I'm\nthinking\n</think>\nHello, world!\nWhat's up?")
.enable_thinking(false) .enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_NONE) .reasoning_format(COMMON_REASONING_FORMAT_NONE)
.expect_content("I'm\nthinking\n</think>\nHello, world!\nWhat's up?") .expect_content("<think>I'm\nthinking\n</think>\nHello, world!\nWhat's up?")
.run(); .run();
tst.test("I'm\nthinking\n</think>\nHello, world!\nWhat's up?") tst.test("I'm\nthinking\n</think>\nHello, world!\nWhat's up?")
@ -1482,7 +1497,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.expect(simple_assist_msg("The answer is 42.", "Let me think about this...")) .expect(simple_assist_msg("The answer is 42.", "Let me think about this..."))
.run(); .run();
tst.test("Hello, world!").expect(simple_assist_msg("Hello, world!")).run(); tst.test("</think>Hello, world!").reasoning_format(COMMON_REASONING_FORMAT_AUTO).expect(simple_assist_msg("Hello, world!")).run();
} }
{ {
// NousResearch-Hermes-2-Pro and Hermes-3 (tool calling models) // NousResearch-Hermes-2-Pro and Hermes-3 (tool calling models)
@ -1798,6 +1813,8 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
"<tool▁calls▁begin><tool▁call▁begin>get_time<tool▁sep>{\"city\": " "<tool▁calls▁begin><tool▁call▁begin>get_time<tool▁sep>{\"city\": "
"\"XYZCITY\"}<tool▁call▁end><tool▁calls▁end>") "\"XYZCITY\"}<tool▁call▁end><tool▁calls▁end>")
.tools({ get_time_tool }) .tools({ get_time_tool })
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.expect(message_with_tool_calls("get_time", "{\"city\":\"XYZCITY\"}")) .expect(message_with_tool_calls("get_time", "{\"city\":\"XYZCITY\"}"))
.run(); .run();
} }
@ -1843,7 +1860,8 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
{ {
auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-V3.1.jinja", detailed_debug); auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-V3.1.jinja", detailed_debug);
tst.test("CONTENT").expect(simple_assist_msg("CONTENT", "")).run(); tst.test("CONTENT").enable_thinking(false).reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK).
expect(simple_assist_msg("CONTENT", "")).run();
} }
// GLM-4.6 tests - format: <tool_call>function_name\n<arg_key>...</arg_key>\n<arg_value>...</arg_value>\n</tool_call> // GLM-4.6 tests - format: <tool_call>function_name\n<arg_key>...</arg_key>\n<arg_value>...</arg_value>\n</tool_call>
@ -1906,6 +1924,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
"<arg_key>arg1</arg_key><arg_value>1</arg_value>" "<arg_key>arg1</arg_key><arg_value>1</arg_value>"
"<arg_key>arg2</arg_key><arg_value>2</arg_value>" "<arg_key>arg2</arg_key><arg_value>2</arg_value>"
"</tool_call>") "</tool_call>")
.enable_thinking(false)
.parallel_tool_calls(true) .parallel_tool_calls(true)
.tools({ .tools({
special_function_tool, special_function_tool_with_optional_param special_function_tool, special_function_tool_with_optional_param
@ -1915,6 +1934,24 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} }, { "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
}) })
.run(); .run();
// #20650: tool with no required args, model emits <tool_call>name</tool_call> with no arg tags.
{
static common_chat_tool no_args_tool{
"read_file_diff_md", "Reads a file diff",
R"({"type":"object","properties":{"review_id":{"type":"string"},"file_id":{"type":"string"}}})",
};
tst.test(
"Let me read the diff content."
"</think>"
"<tool_call>read_file_diff_md</tool_call>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.tools({ no_args_tool })
.expect_reasoning("Let me read the diff content.")
.expect_tool_calls({{ "read_file_diff_md", "{}", {} }})
.run();
}
} }
// Kimi-K2-Thinking tests - custom parser // Kimi-K2-Thinking tests - custom parser
@ -2222,10 +2259,11 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
{ {
auto tst = peg_tester("models/templates/MiniMax-M2.jinja", detailed_debug); auto tst = peg_tester("models/templates/MiniMax-M2.jinja", detailed_debug);
tst.test( tst.test(
"<minimax:tool_call>\n<invoke name=\"special_function\">\n<parameter " "</think><minimax:tool_call>\n<invoke name=\"special_function\">\n<parameter "
"name=\"arg1\">1</parameter>\n</invoke>\n</minimax:tool_call>") "name=\"arg1\">1</parameter>\n</invoke>\n</minimax:tool_call>")
.tools({ special_function_tool }) .tools({ special_function_tool })
.expect(message_assist_call) .expect(message_assist_call)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.run(); .run();
} }
@ -2288,8 +2326,8 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// Functionary v3.2 - recipient-based format: >>>recipient\n{content} // Functionary v3.2 - recipient-based format: >>>recipient\n{content}
{ {
auto tst = peg_tester("models/templates/meetkai-functionary-medium-v3.2.jinja", detailed_debug); auto tst = peg_tester("models/templates/meetkai-functionary-medium-v3.2.jinja", detailed_debug);
tst.test(">>>all\nHello, world!\nWhat's up?").expect(message_assist).run(); tst.test("all\nHello, world!\nWhat's up?").expect(message_assist).run();
tst.test(">>>special_function\n{\"arg1\": 1}") tst.test("special_function\n{\"arg1\": 1}")
.tools({ special_function_tool }) .tools({ special_function_tool })
.expect(message_assist_call) .expect(message_assist_call)
.run(); .run();
@ -2309,8 +2347,8 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// Note: Template uses forced-open mode (prompt ends with <think>), so input shouldn't include opening tag // Note: Template uses forced-open mode (prompt ends with <think>), so input shouldn't include opening tag
{ {
auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja", detailed_debug); auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?") tst.test("</think>Hello, world!\nWhat's up?")
.enable_thinking(true) // Forced open .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.expect(message_assist) .expect(message_assist)
.run(); .run();
tst.test("I'm\nthinking</think>Hello, world!\nWhat's up?") tst.test("I'm\nthinking</think>Hello, world!\nWhat's up?")
@ -2322,14 +2360,15 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// llama-cpp DeepSeek R1 template (always forced-open thinking) // llama-cpp DeepSeek R1 template (always forced-open thinking)
{ {
auto tst = peg_tester("models/templates/llama-cpp-deepseek-r1.jinja", detailed_debug); auto tst = peg_tester("models/templates/llama-cpp-deepseek-r1.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); tst.test("</think>Hello, world!\nWhat's up?").expect(message_assist).reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK).run();
tst.test("I'm\nthinking</think>Hello, world!\nWhat's up?") tst.test("I'm\nthinking</think>Hello, world!\nWhat's up?")
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.expect(message_assist_thoughts) .expect(message_assist_thoughts)
.run(); .run();
tst.test( tst.test(
"<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>special_function\n" "</think><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>special_function\n"
"```json\n{\"arg1\": 1}```<tool▁call▁end><tool▁calls▁end>") "```json\n{\"arg1\": 1}```<tool▁call▁end><tool▁calls▁end>")
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.tools({ special_function_tool }) .tools({ special_function_tool })
.parallel_tool_calls(true) .parallel_tool_calls(true)
.expect(message_assist_call) .expect(message_assist_call)
@ -2339,7 +2378,9 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// Note: Template uses forced-open mode (prompt ends with <think>), so input shouldn't include opening tag // Note: Template uses forced-open mode (prompt ends with <think>), so input shouldn't include opening tag
{ {
auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja", detailed_debug); auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").enable_thinking(true).expect(message_assist).run(); tst.test("</think>Hello, world!\nWhat's up?").enable_thinking(true).
reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK).
expect(message_assist).run();
tst.test("I'm\nthinking</think>Hello, world!\nWhat's up?") tst.test("I'm\nthinking</think>Hello, world!\nWhat's up?")
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.expect(message_assist_thoughts) .expect(message_assist_thoughts)
@ -2348,6 +2389,8 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
"<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>special_function\n" "<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>special_function\n"
"```json\n{\"arg1\": 1}```<tool▁call▁end><tool▁calls▁end>") "```json\n{\"arg1\": 1}```<tool▁call▁end><tool▁calls▁end>")
.tools({ special_function_tool }) .tools({ special_function_tool })
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.expect(message_assist_call) .expect(message_assist_call)
.run(); .run();
} }
@ -2377,12 +2420,12 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// Apriel 1.6 Thinker (reasoning-only support) // Apriel 1.6 Thinker (reasoning-only support)
{ {
auto tst = peg_tester("models/templates/Apriel-1.6-15b-Thinker-fixed.jinja", detailed_debug); auto tst = peg_tester("models/templates/Apriel-1.6-15b-Thinker-fixed.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
// Implicit reasoning start (forced open) // Implicit reasoning start (forced open)
tst.test("I'm\nthinking\n[BEGIN FINAL RESPONSE]\nHello, world!\nWhat's up?") tst.test("I'm\nthinking\n[BEGIN FINAL RESPONSE]\nHello, world!\nWhat's up?")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO) .reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.expect(message_assist_thoughts) .enable_thinking(true)
.expect(simple_assist_msg("Hello, world!\nWhat's up?", "Here are my reasoning steps:\nI'm\nthinking"))
.run(); .run();
// Reasoning + Tool calls // Reasoning + Tool calls
@ -2390,8 +2433,9 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
"I'm\nthinking\n[BEGIN FINAL RESPONSE]\n<tool_calls>[{\"name\": \"special_function\", \"arguments\": " "I'm\nthinking\n[BEGIN FINAL RESPONSE]\n<tool_calls>[{\"name\": \"special_function\", \"arguments\": "
"{\"arg1\": 1}}]</tool_calls>") "{\"arg1\": 1}}]</tool_calls>")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO) .reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
.tools({ special_function_tool }) .tools({ special_function_tool })
.expect(message_assist_call_thoughts) .expect(simple_assist_msg("", "Here are my reasoning steps:\nI'm\nthinking", "special_function", "{\"arg1\":1}"))
.run(); .run();
} }

View File

@ -105,7 +105,7 @@ struct cli_context {
llama_get_model(ctx_server.get_llama_context())); llama_get_model(ctx_server.get_llama_context()));
task.params.sampling.reasoning_budget_tokens = reasoning_budget; task.params.sampling.reasoning_budget_tokens = reasoning_budget;
task.params.sampling.reasoning_budget_activate_immediately = chat_params.thinking_forced_open; task.params.sampling.generation_prompt = chat_params.generation_prompt;
if (!chat_params.thinking_start_tag.empty()) { if (!chat_params.thinking_start_tag.empty()) {
task.params.sampling.reasoning_budget_start = task.params.sampling.reasoning_budget_start =

View File

@ -41,6 +41,11 @@ struct clip_graph {
virtual ~clip_graph() = default; virtual ~clip_graph() = default;
virtual ggml_cgraph * build() = 0; virtual ggml_cgraph * build() = 0;
// wrapper around ggml_mul_mat, allow hooking (e.g. LoRA, clamping) depending on the model
// tensor w should be the weight matrix, and tensor x should be the input
virtual ggml_tensor * build_mm(ggml_tensor * w, ggml_tensor * x) const;
// TODO: build_mm(w, b, x) to support bias
// //
// utility functions // utility functions
// //

View File

@ -255,6 +255,10 @@ clip_graph::clip_graph(clip_ctx * ctx, const clip_image_f32 & img) :
gf = ggml_new_graph_custom(ctx0, ctx->max_nodes, false); gf = ggml_new_graph_custom(ctx0, ctx->max_nodes, false);
} }
ggml_tensor * clip_graph::build_mm(ggml_tensor * w, ggml_tensor * x) const {
return ggml_mul_mat(ctx0, w, x);
}
void clip_graph::cb(ggml_tensor * cur, const char * name, int il) const { void clip_graph::cb(ggml_tensor * cur, const char * name, int il) const {
if (il >= 0) { if (il >= 0) {
ggml_format_name(cur, "%s-%d", name, il); ggml_format_name(cur, "%s-%d", name, il);
@ -326,7 +330,7 @@ ggml_tensor * clip_graph::build_vit(
ggml_tensor * Vcur = nullptr; ggml_tensor * Vcur = nullptr;
if (layer.qkv_w != nullptr) { if (layer.qkv_w != nullptr) {
// fused qkv // fused qkv
cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); cur = build_mm(layer.qkv_w, cur);
if (layer.qkv_b != nullptr) { if (layer.qkv_b != nullptr) {
cur = ggml_add(ctx0, cur, layer.qkv_b); cur = ggml_add(ctx0, cur, layer.qkv_b);
} }
@ -360,17 +364,17 @@ ggml_tensor * clip_graph::build_vit(
} else { } else {
// separate q, k, v // separate q, k, v
Qcur = ggml_mul_mat(ctx0, layer.q_w, cur); Qcur = build_mm(layer.q_w, cur);
if (layer.q_b) { if (layer.q_b) {
Qcur = ggml_add(ctx0, Qcur, layer.q_b); Qcur = ggml_add(ctx0, Qcur, layer.q_b);
} }
Kcur = ggml_mul_mat(ctx0, layer.k_w, cur); Kcur = build_mm(layer.k_w, cur);
if (layer.k_b) { if (layer.k_b) {
Kcur = ggml_add(ctx0, Kcur, layer.k_b); Kcur = ggml_add(ctx0, Kcur, layer.k_b);
} }
Vcur = ggml_mul_mat(ctx0, layer.v_w, cur); Vcur = build_mm(layer.v_w, cur);
if (layer.v_b) { if (layer.v_b) {
Vcur = ggml_add(ctx0, Vcur, layer.v_b); Vcur = ggml_add(ctx0, Vcur, layer.v_b);
} }
@ -517,7 +521,7 @@ ggml_tensor * clip_graph::build_ffn(
ffn_op_type type_op, ffn_op_type type_op,
int il) const { int il) const {
ggml_tensor * tmp = up ? ggml_mul_mat(ctx0, up, cur) : cur; ggml_tensor * tmp = up ? build_mm(up, cur) : cur;
cb(tmp, "ffn_up", il); cb(tmp, "ffn_up", il);
if (up_b) { if (up_b) {
@ -526,7 +530,7 @@ ggml_tensor * clip_graph::build_ffn(
} }
if (gate) { if (gate) {
cur = ggml_mul_mat(ctx0, gate, cur); cur = build_mm(gate, cur);
cb(cur, "ffn_gate", il); cb(cur, "ffn_gate", il);
if (gate_b) { if (gate_b) {
@ -580,7 +584,7 @@ ggml_tensor * clip_graph::build_ffn(
} }
if (down) { if (down) {
cur = ggml_mul_mat(ctx0, down, cur); cur = build_mm(down, cur);
} }
if (down_b) { if (down_b) {
@ -646,7 +650,7 @@ ggml_tensor * clip_graph::build_attn(
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
if (wo) { if (wo) {
cur = ggml_mul_mat(ctx0, wo, cur); cur = build_mm(wo, cur);
} }
if (wo_b) { if (wo_b) {

View File

@ -19,7 +19,7 @@ ggml_cgraph * clip_graph_cogvlm::build() {
auto & layer = model.layers[il]; auto & layer = model.layers[il];
ggml_tensor * cur = inpL; ggml_tensor * cur = inpL;
cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); cur = build_mm(layer.qkv_w, cur);
cur = ggml_add(ctx0, cur, layer.qkv_b); cur = ggml_add(ctx0, cur, layer.qkv_b);
@ -67,7 +67,7 @@ ggml_cgraph * clip_graph_cogvlm::build() {
ggml_row_size(inpL->type, n_embd), 0); ggml_row_size(inpL->type, n_embd), 0);
// Multiply with mm_model_proj // Multiply with mm_model_proj
cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur); cur = build_mm(model.mm_model_proj, cur);
// Apply layernorm, weight, bias // Apply layernorm, weight, bias
cur = build_norm(cur, model.mm_post_fc_norm_w, model.mm_post_fc_norm_b, NORM_TYPE_NORMAL, 1e-5, -1); cur = build_norm(cur, model.mm_post_fc_norm_w, model.mm_post_fc_norm_b, NORM_TYPE_NORMAL, 1e-5, -1);
@ -76,16 +76,16 @@ ggml_cgraph * clip_graph_cogvlm::build() {
cur = ggml_gelu_inplace(ctx0, cur); cur = ggml_gelu_inplace(ctx0, cur);
// Branch 1: multiply with mm_h_to_4h_w // Branch 1: multiply with mm_h_to_4h_w
ggml_tensor * h_to_4h = ggml_mul_mat(ctx0, model.mm_h_to_4h_w, cur); ggml_tensor * h_to_4h = build_mm(model.mm_h_to_4h_w, cur);
// Branch 2: multiply with mm_gate_w // Branch 2: multiply with mm_gate_w
ggml_tensor * gate = ggml_mul_mat(ctx0, model.mm_gate_w, cur); ggml_tensor * gate = build_mm(model.mm_gate_w, cur);
// Apply silu // Apply silu
gate = ggml_swiglu_split(ctx0, gate, h_to_4h); gate = ggml_swiglu_split(ctx0, gate, h_to_4h);
// Apply mm_4h_to_h_w // Apply mm_4h_to_h_w
cur = ggml_mul_mat(ctx0, model.mm_4h_to_h_w, gate); cur = build_mm(model.mm_4h_to_h_w, gate);
// Concatenate with boi and eoi // Concatenate with boi and eoi
cur = ggml_concat(ctx0, model.mm_boi, cur, 1); cur = ggml_concat(ctx0, model.mm_boi, cur, 1);

View File

@ -56,7 +56,7 @@ ggml_cgraph * clip_graph_conformer::build() {
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2]); cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2]);
// calculate out // calculate out
cur = ggml_mul_mat(ctx0, model.pre_encode_out_w, cur); cur = build_mm(model.pre_encode_out_w, cur);
cur = ggml_add(ctx0, cur, model.pre_encode_out_b); cur = ggml_add(ctx0, cur, model.pre_encode_out_b);
cb(cur, "conformer.pre_encode.out", -1); cb(cur, "conformer.pre_encode.out", -1);
} }
@ -87,7 +87,7 @@ ggml_cgraph * clip_graph_conformer::build() {
cur = build_norm(residual, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, 1e-5, il); cur = build_norm(residual, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, 1e-5, il);
cb(cur, "conformer.layers.{}.norm_self_att", il); cb(cur, "conformer.layers.{}.norm_self_att", il);
ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur); ggml_tensor * Qcur = build_mm(layer.q_w, cur);
Qcur = ggml_add(ctx0, Qcur, layer.q_b); Qcur = ggml_add(ctx0, Qcur, layer.q_b);
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, Qcur->ne[1]); Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, Qcur->ne[1]);
ggml_tensor * Q_bias_u = ggml_add(ctx0, Qcur, layer.pos_bias_u); ggml_tensor * Q_bias_u = ggml_add(ctx0, Qcur, layer.pos_bias_u);
@ -96,12 +96,12 @@ ggml_cgraph * clip_graph_conformer::build() {
Q_bias_v = ggml_permute(ctx0, Q_bias_v, 0, 2, 1, 3); Q_bias_v = ggml_permute(ctx0, Q_bias_v, 0, 2, 1, 3);
// TODO @ngxson : some cont can/should be removed when ggml_mul_mat support these cases // TODO @ngxson : some cont can/should be removed when ggml_mul_mat support these cases
ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur); ggml_tensor * Kcur = build_mm(layer.k_w, cur);
Kcur = ggml_add(ctx0, Kcur, layer.k_b); Kcur = ggml_add(ctx0, Kcur, layer.k_b);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, Kcur->ne[1]); Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, Kcur->ne[1]);
Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur); ggml_tensor * Vcur = build_mm(layer.v_w, cur);
Vcur = ggml_add(ctx0, Vcur, layer.v_b); Vcur = ggml_add(ctx0, Vcur, layer.v_b);
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, Vcur->ne[1]); Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, Vcur->ne[1]);
Vcur = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 1, 2, 0, 3)); Vcur = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 1, 2, 0, 3));
@ -111,7 +111,7 @@ ggml_cgraph * clip_graph_conformer::build() {
matrix_ac = ggml_cont(ctx0, ggml_permute(ctx0, matrix_ac, 1, 0, 2, 3)); matrix_ac = ggml_cont(ctx0, ggml_permute(ctx0, matrix_ac, 1, 0, 2, 3));
cb(matrix_ac, "conformer.layers.{}.self_attn.id3", il); cb(matrix_ac, "conformer.layers.{}.self_attn.id3", il);
auto * p = ggml_mul_mat(ctx0, layer.linear_pos_w, pos_emb); auto * p = build_mm(layer.linear_pos_w, pos_emb);
cb(p, "conformer.layers.{}.self_attn.linear_pos", il); cb(p, "conformer.layers.{}.self_attn.linear_pos", il);
p = ggml_reshape_3d(ctx0, p, d_head, n_head, p->ne[1]); p = ggml_reshape_3d(ctx0, p, d_head, n_head, p->ne[1]);
p = ggml_permute(ctx0, p, 0, 2, 1, 3); p = ggml_permute(ctx0, p, 0, 2, 1, 3);
@ -143,7 +143,7 @@ ggml_cgraph * clip_graph_conformer::build() {
x = ggml_permute(ctx0, x, 2, 0, 1, 3); x = ggml_permute(ctx0, x, 2, 0, 1, 3);
x = ggml_cont_2d(ctx0, x, x->ne[0] * x->ne[1], x->ne[2]); x = ggml_cont_2d(ctx0, x, x->ne[0] * x->ne[1], x->ne[2]);
ggml_tensor * out = ggml_mul_mat(ctx0, layer.o_w, x); ggml_tensor * out = build_mm(layer.o_w, x);
out = ggml_add(ctx0, out, layer.o_b); out = ggml_add(ctx0, out, layer.o_b);
cb(out, "conformer.layers.{}.self_attn.linear_out", il); cb(out, "conformer.layers.{}.self_attn.linear_out", il);
@ -157,7 +157,7 @@ ggml_cgraph * clip_graph_conformer::build() {
// conv // conv
{ {
auto * x = cur; auto * x = cur;
x = ggml_mul_mat(ctx0, layer.conv_pw1_w, x); x = build_mm(layer.conv_pw1_w, x);
x = ggml_add(ctx0, x, layer.conv_pw1_b); x = ggml_add(ctx0, x, layer.conv_pw1_b);
cb(x, "conformer.layers.{}.conv.pointwise_conv1", il); cb(x, "conformer.layers.{}.conv.pointwise_conv1", il);
@ -181,7 +181,7 @@ ggml_cgraph * clip_graph_conformer::build() {
x = ggml_silu(ctx0, x); x = ggml_silu(ctx0, x);
// pointwise_conv2 // pointwise_conv2
x = ggml_mul_mat(ctx0, layer.conv_pw2_w, x); x = build_mm(layer.conv_pw2_w, x);
x = ggml_add(ctx0, x, layer.conv_pw2_b); x = ggml_add(ctx0, x, layer.conv_pw2_b);
cur = x; cur = x;

View File

@ -97,7 +97,7 @@ ggml_cgraph * clip_graph_glm4v::build() {
// FC projector // FC projector
{ {
cur = ggml_mul_mat(ctx0, model.projection, cur); cur = build_mm(model.projection, cur);
// default LayerNorm (post_projection_norm) // default LayerNorm (post_projection_norm)
cur = build_norm(cur, model.mm_post_norm_w, model.mm_post_norm_b, NORM_TYPE_NORMAL, 1e-5, -1); cur = build_norm(cur, model.mm_post_norm_w, model.mm_post_norm_b, NORM_TYPE_NORMAL, 1e-5, -1);
cur = ggml_gelu_erf(ctx0, cur); cur = ggml_gelu_erf(ctx0, cur);

View File

@ -22,7 +22,7 @@ ggml_cgraph * clip_graph_llama4::build() {
ggml_tensor * kernel = ggml_reshape_4d(ctx0, model.patch_embeddings_0, ggml_tensor * kernel = ggml_reshape_4d(ctx0, model.patch_embeddings_0,
patch_size, patch_size, 3, n_embd); patch_size, patch_size, 3, n_embd);
inp = ggml_im2col(ctx0, kernel, inp, patch_size, patch_size, 0, 0, 1, 1, true, inp->type); inp = ggml_im2col(ctx0, kernel, inp, patch_size, patch_size, 0, 0, 1, 1, true, inp->type);
inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp); inp = build_mm(model.patch_embeddings_0, inp);
inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches); inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches);
cb(inp, "patch_conv", -1); cb(inp, "patch_conv", -1);
} }
@ -78,15 +78,15 @@ ggml_cgraph * clip_graph_llama4::build() {
// based on Llama4VisionMLP2 (always uses GELU activation, no bias) // based on Llama4VisionMLP2 (always uses GELU activation, no bias)
{ {
cur = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, cur); cur = build_mm(model.mm_model_mlp_1_w, cur);
cur = ggml_gelu(ctx0, cur); cur = ggml_gelu(ctx0, cur);
cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur); cur = build_mm(model.mm_model_mlp_2_w, cur);
cur = ggml_gelu(ctx0, cur); cur = ggml_gelu(ctx0, cur);
cb(cur, "adapter_mlp", -1); cb(cur, "adapter_mlp", -1);
} }
// Llama4MultiModalProjector // Llama4MultiModalProjector
cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur); cur = build_mm(model.mm_model_proj, cur);
cb(cur, "projected", -1); cb(cur, "projected", -1);
// build the graph // build the graph

View File

@ -70,17 +70,17 @@ ggml_cgraph * clip_graph_llava::build() {
// self-attention // self-attention
{ {
ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur); ggml_tensor * Qcur = build_mm(layer.q_w, cur);
if (layer.q_b) { if (layer.q_b) {
Qcur = ggml_add(ctx0, Qcur, layer.q_b); Qcur = ggml_add(ctx0, Qcur, layer.q_b);
} }
ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur); ggml_tensor * Kcur = build_mm(layer.k_w, cur);
if (layer.k_b) { if (layer.k_b) {
Kcur = ggml_add(ctx0, Kcur, layer.k_b); Kcur = ggml_add(ctx0, Kcur, layer.k_b);
} }
ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur); ggml_tensor * Vcur = build_mm(layer.v_w, cur);
if (layer.v_b) { if (layer.v_b) {
Vcur = ggml_add(ctx0, Vcur, layer.v_b); Vcur = ggml_add(ctx0, Vcur, layer.v_b);
} }
@ -164,17 +164,17 @@ ggml_cgraph * clip_graph_llava::build() {
// llava projector // llava projector
if (proj_type == PROJECTOR_TYPE_MLP) { if (proj_type == PROJECTOR_TYPE_MLP) {
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); embeddings = build_mm(model.mm_0_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
embeddings = ggml_gelu(ctx0, embeddings); embeddings = ggml_gelu(ctx0, embeddings);
if (model.mm_2_w) { if (model.mm_2_w) {
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); embeddings = build_mm(model.mm_2_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
} }
} }
else if (proj_type == PROJECTOR_TYPE_MLP_NORM) { else if (proj_type == PROJECTOR_TYPE_MLP_NORM) {
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); embeddings = build_mm(model.mm_0_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
// ggml_tensor_printf(embeddings, "mm_0_w",0,true,false); // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
// First LayerNorm // First LayerNorm
@ -186,7 +186,7 @@ ggml_cgraph * clip_graph_llava::build() {
embeddings = ggml_gelu(ctx0, embeddings); embeddings = ggml_gelu(ctx0, embeddings);
// Second linear layer // Second linear layer
embeddings = ggml_mul_mat(ctx0, model.mm_3_w, embeddings); embeddings = build_mm(model.mm_3_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_3_b); embeddings = ggml_add(ctx0, embeddings, model.mm_3_b);
// Second LayerNorm // Second LayerNorm
@ -197,10 +197,10 @@ ggml_cgraph * clip_graph_llava::build() {
else if (proj_type == PROJECTOR_TYPE_LDP) { else if (proj_type == PROJECTOR_TYPE_LDP) {
// MobileVLM projector // MobileVLM projector
int n_patch = 24; int n_patch = 24;
ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings); ggml_tensor * mlp_1 = build_mm(model.mm_model_mlp_1_w, embeddings);
mlp_1 = ggml_add(ctx0, mlp_1, model.mm_model_mlp_1_b); mlp_1 = ggml_add(ctx0, mlp_1, model.mm_model_mlp_1_b);
mlp_1 = ggml_gelu(ctx0, mlp_1); mlp_1 = ggml_gelu(ctx0, mlp_1);
ggml_tensor * mlp_3 = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, mlp_1); ggml_tensor * mlp_3 = build_mm(model.mm_model_mlp_3_w, mlp_1);
mlp_3 = ggml_add(ctx0, mlp_3, model.mm_model_mlp_3_b); mlp_3 = ggml_add(ctx0, mlp_3, model.mm_model_mlp_3_b);
// mlp_3 shape = [1, 576, 2048], ne = [2048, 576, 1, 1] // mlp_3 shape = [1, 576, 2048], ne = [2048, 576, 1, 1]
@ -229,10 +229,10 @@ ggml_cgraph * clip_graph_llava::build() {
// block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1] // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
// pointwise conv // pointwise conv
block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]); block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc1_w, block_1); block_1 = build_mm(model.mm_model_block_1_block_1_fc1_w, block_1);
block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc1_b); block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc1_b);
block_1 = ggml_relu(ctx0, block_1); block_1 = ggml_relu(ctx0, block_1);
block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc2_w, block_1); block_1 = build_mm(model.mm_model_block_1_block_1_fc2_w, block_1);
block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc2_b); block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc2_b);
block_1 = ggml_hardsigmoid(ctx0, block_1); block_1 = ggml_hardsigmoid(ctx0, block_1);
// block_1_hw shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1], block_1 shape = [1, 2048], ne = [2048, 1, 1, 1] // block_1_hw shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1], block_1 shape = [1, 2048], ne = [2048, 1, 1, 1]
@ -244,7 +244,7 @@ ggml_cgraph * clip_graph_llava::build() {
block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3)); block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
// block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1] // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_2_0_w, block_1); block_1 = build_mm(model.mm_model_block_1_block_2_0_w, block_1);
block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]); block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
// block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1] // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
@ -277,10 +277,10 @@ ggml_cgraph * clip_graph_llava::build() {
// block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1] // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
// pointwise conv // pointwise conv
block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]); block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc1_w, block_1); block_1 = build_mm(model.mm_model_block_2_block_1_fc1_w, block_1);
block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc1_b); block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc1_b);
block_1 = ggml_relu(ctx0, block_1); block_1 = ggml_relu(ctx0, block_1);
block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc2_w, block_1); block_1 = build_mm(model.mm_model_block_2_block_1_fc2_w, block_1);
block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc2_b); block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc2_b);
block_1 = ggml_hardsigmoid(ctx0, block_1); block_1 = ggml_hardsigmoid(ctx0, block_1);
@ -292,7 +292,7 @@ ggml_cgraph * clip_graph_llava::build() {
block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]); block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3)); block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
// block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1] // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_2_0_w, block_1); block_1 = build_mm(model.mm_model_block_2_block_2_0_w, block_1);
block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]); block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
@ -307,10 +307,10 @@ ggml_cgraph * clip_graph_llava::build() {
else if (proj_type == PROJECTOR_TYPE_LDPV2) else if (proj_type == PROJECTOR_TYPE_LDPV2)
{ {
int n_patch = 24; int n_patch = 24;
ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings); ggml_tensor * mlp_0 = build_mm(model.mm_model_mlp_0_w, embeddings);
mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b); mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b);
mlp_0 = ggml_gelu(ctx0, mlp_0); mlp_0 = ggml_gelu(ctx0, mlp_0);
ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0); ggml_tensor * mlp_2 = build_mm(model.mm_model_mlp_2_w, mlp_0);
mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b); mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b);
// mlp_2 ne = [2048, 576, 1, 1] // mlp_2 ne = [2048, 576, 1, 1]
// // AVG Pool Layer 2*2, strides = 2 // // AVG Pool Layer 2*2, strides = 2
@ -344,15 +344,15 @@ ggml_cgraph * clip_graph_llava::build() {
embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b); embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
// GLU // GLU
{ {
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings); embeddings = build_mm(model.mm_model_mlp_0_w, embeddings);
embeddings = ggml_norm(ctx0, embeddings, eps); embeddings = ggml_norm(ctx0, embeddings, eps);
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b); embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
embeddings = ggml_gelu_inplace(ctx0, embeddings); embeddings = ggml_gelu_inplace(ctx0, embeddings);
ggml_tensor * x = embeddings; ggml_tensor * x = embeddings;
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings); embeddings = build_mm(model.mm_model_mlp_2_w, embeddings);
x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x); x = build_mm(model.mm_model_mlp_1_w,x);
embeddings = ggml_swiglu_split(ctx0, embeddings, x); embeddings = ggml_swiglu_split(ctx0, embeddings, x);
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings); embeddings = build_mm(model.mm_model_mlp_3_w, embeddings);
} }
// arrangement of BOI/EOI token embeddings // arrangement of BOI/EOI token embeddings
// note: these embeddings are not present in text model, hence we cannot process them as text tokens // note: these embeddings are not present in text model, hence we cannot process them as text tokens

View File

@ -38,7 +38,7 @@ ggml_cgraph * clip_graph_minicpmv::build() {
// resampler projector (it is just another transformer) // resampler projector (it is just another transformer)
ggml_tensor * q = model.mm_model_query; ggml_tensor * q = model.mm_model_query;
ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings); ggml_tensor * v = build_mm(model.mm_model_kv_proj, embeddings);
// norm // norm
q = build_norm(q, model.mm_model_ln_q_w, model.mm_model_ln_q_b, NORM_TYPE_NORMAL, eps, -1); q = build_norm(q, model.mm_model_ln_q_w, model.mm_model_ln_q_b, NORM_TYPE_NORMAL, eps, -1);
@ -77,13 +77,13 @@ ggml_cgraph * clip_graph_minicpmv::build() {
// Use actual config value if available, otherwise fall back to hardcoded values // Use actual config value if available, otherwise fall back to hardcoded values
int num_query = hparams.minicpmv_query_num; int num_query = hparams.minicpmv_query_num;
ggml_tensor * Q = ggml_add(ctx0, ggml_tensor * Q = ggml_add(ctx0,
ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), build_mm(model.mm_model_attn_q_w, q),
model.mm_model_attn_q_b); model.mm_model_attn_q_b);
ggml_tensor * K = ggml_add(ctx0, ggml_tensor * K = ggml_add(ctx0,
ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), build_mm(model.mm_model_attn_k_w, k),
model.mm_model_attn_k_b); model.mm_model_attn_k_b);
ggml_tensor * V = ggml_add(ctx0, ggml_tensor * V = ggml_add(ctx0,
ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), build_mm(model.mm_model_attn_v_w, v),
model.mm_model_attn_v_b); model.mm_model_attn_v_b);
Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_query); Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_query);
@ -105,7 +105,7 @@ ggml_cgraph * clip_graph_minicpmv::build() {
embeddings = build_norm(embeddings, model.mm_model_ln_post_w, model.mm_model_ln_post_b, NORM_TYPE_NORMAL, eps, -1); embeddings = build_norm(embeddings, model.mm_model_ln_post_w, model.mm_model_ln_post_b, NORM_TYPE_NORMAL, eps, -1);
// projection // projection
embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings); embeddings = build_mm(model.mm_model_proj, embeddings);
// build the graph // build the graph
ggml_build_forward_expand(gf, embeddings); ggml_build_forward_expand(gf, embeddings);

View File

@ -429,7 +429,7 @@ ggml_cgraph * clip_graph_mobilenetv5::build() {
// PyTorch: embedding_projection = nn.Linear(vision_hidden, text_hidden, bias=False) // PyTorch: embedding_projection = nn.Linear(vision_hidden, text_hidden, bias=False)
// Weight stored as [out_features, in_features] = [text_hidden_size, vision_hidden_size] // Weight stored as [out_features, in_features] = [text_hidden_size, vision_hidden_size]
if (model.mm_input_proj_w) { if (model.mm_input_proj_w) {
cur = ggml_mul_mat(ctx0, model.mm_input_proj_w, cur); cur = build_mm(model.mm_input_proj_w, cur);
} }
// 5. POST PROJECTION NORM // 5. POST PROJECTION NORM

View File

@ -43,7 +43,7 @@ ggml_cgraph * clip_graph_pixtral::build() {
// project to n_embd // project to n_embd
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]); cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur); cur = build_mm(model.mm_patch_merger_w, cur);
} }
// LlavaMultiModalProjector (always using GELU activation) // LlavaMultiModalProjector (always using GELU activation)

View File

@ -90,11 +90,11 @@ ggml_cgraph * clip_graph_qwen2vl::build() {
// self-attention // self-attention
{ {
ggml_tensor * Qcur = ggml_add(ctx0, ggml_tensor * Qcur = ggml_add(ctx0,
ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b); build_mm(layer.q_w, cur), layer.q_b);
ggml_tensor * Kcur = ggml_add(ctx0, ggml_tensor * Kcur = ggml_add(ctx0,
ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b); build_mm(layer.k_w, cur), layer.k_b);
ggml_tensor * Vcur = ggml_add(ctx0, ggml_tensor * Vcur = ggml_add(ctx0,
ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b); build_mm(layer.v_w, cur), layer.v_b);
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches); Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches); Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches);

View File

@ -85,7 +85,7 @@ ggml_cgraph * clip_graph_qwen3vl::build() {
// self-attention // self-attention
{ {
cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); cur = build_mm(layer.qkv_w, cur);
cur = ggml_add(ctx0, cur, layer.qkv_b); cur = ggml_add(ctx0, cur, layer.qkv_b);
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,

View File

@ -43,7 +43,7 @@ ggml_cgraph * clip_graph_siglip::build() {
// https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578 // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
const int scale_factor = model.hparams.n_merge; const int scale_factor = model.hparams.n_merge;
cur = build_patch_merge_permute(cur, scale_factor); cur = build_patch_merge_permute(cur, scale_factor);
cur = ggml_mul_mat(ctx0, model.projection, cur); cur = build_mm(model.projection, cur);
} else if (proj_type == PROJECTOR_TYPE_LFM2) { } else if (proj_type == PROJECTOR_TYPE_LFM2) {
// pixel unshuffle block // pixel unshuffle block

View File

@ -59,7 +59,7 @@ ggml_cgraph * clip_graph_whisper_enc::build() {
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w); cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
// ffn in // ffn in
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); cur = build_mm(model.mm_1_w, cur);
// swiglu // swiglu
// see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
@ -70,11 +70,11 @@ ggml_cgraph * clip_graph_whisper_enc::build() {
cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w); cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
// ffn out // ffn out
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); cur = build_mm(model.mm_2_w, cur);
} else if (proj_type == PROJECTOR_TYPE_QWEN2A) { } else if (proj_type == PROJECTOR_TYPE_QWEN2A) {
// projector // projector
cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur); cur = build_mm(model.mm_fc_w, cur);
cur = ggml_add(ctx0, cur, model.mm_fc_b); cur = ggml_add(ctx0, cur, model.mm_fc_b);
} else if (proj_type == PROJECTOR_TYPE_VOXTRAL) { } else if (proj_type == PROJECTOR_TYPE_VOXTRAL) {

View File

@ -43,7 +43,7 @@ ggml_cgraph * clip_graph_youtuvl::build() {
ctx0, inp, ctx0, inp,
3*patch_size* patch_size, Hm * Wm * m * m, 1); 3*patch_size* patch_size, Hm * Wm * m * m, 1);
} }
inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp); inp = build_mm(model.patch_embeddings_0, inp);
if (model.patch_bias) { if (model.patch_bias) {
inp = ggml_add(ctx0, inp, model.patch_bias); inp = ggml_add(ctx0, inp, model.patch_bias);
@ -97,11 +97,11 @@ ggml_cgraph * clip_graph_youtuvl::build() {
// self-attention // self-attention
{ {
ggml_tensor * Qcur = ggml_add(ctx0, ggml_tensor * Qcur = ggml_add(ctx0,
ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b); build_mm(layer.q_w, cur), layer.q_b);
ggml_tensor * Kcur = ggml_add(ctx0, ggml_tensor * Kcur = ggml_add(ctx0,
ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b); build_mm(layer.k_w, cur), layer.k_b);
ggml_tensor * Vcur = ggml_add(ctx0, ggml_tensor * Vcur = ggml_add(ctx0,
ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b); build_mm(layer.v_w, cur), layer.v_b);
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches); Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches); Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches);

View File

@ -282,7 +282,7 @@ static void render_scenario(const common_chat_template & tmpl,
LOG_ERR("Messages:\n%s\n", final_messages.dump(2).c_str()); LOG_ERR("Messages:\n%s\n", final_messages.dump(2).c_str());
try { try {
autoparser::templates_params inputs; autoparser::generation_params inputs;
inputs.messages = final_messages; inputs.messages = final_messages;
inputs.add_generation_prompt = add_generation_prompt; inputs.add_generation_prompt = add_generation_prompt;
inputs.extra_context["enable_thinking"] = enable_thinking; inputs.extra_context["enable_thinking"] = enable_thinking;
@ -395,7 +395,7 @@ int main(int argc, char ** argv) {
analysis.analyze_template(chat_template); analysis.analyze_template(chat_template);
// Generate Parser // Generate Parser
autoparser::templates_params params; autoparser::generation_params params;
params.messages = json::array({ build_user_message() }); params.messages = json::array({ build_user_message() });
params.reasoning_format = params.reasoning_format =
opts.enable_reasoning ? COMMON_REASONING_FORMAT_DEEPSEEK : COMMON_REASONING_FORMAT_NONE; opts.enable_reasoning ? COMMON_REASONING_FORMAT_DEEPSEEK : COMMON_REASONING_FORMAT_NONE;

View File

@ -400,12 +400,12 @@ static void analyze_template(const std::string & template_path) {
{ {
json user_msg = make_user_msg(); json user_msg = make_user_msg();
autoparser::templates_params params_no_tools; autoparser::generation_params params_no_tools;
params_no_tools.messages = json::array({ user_msg }); params_no_tools.messages = json::array({ user_msg });
params_no_tools.add_generation_prompt = false; params_no_tools.add_generation_prompt = false;
params_no_tools.tools = json::array(); params_no_tools.tools = json::array();
autoparser::templates_params params_with_tools = params_no_tools; autoparser::generation_params params_with_tools = params_no_tools;
params_with_tools.tools = tools; params_with_tools.tools = tools;
std::string output_no_tools = common_chat_template_direct_apply(chat_template, params_no_tools); std::string output_no_tools = common_chat_template_direct_apply(chat_template, params_no_tools);
@ -419,12 +419,12 @@ static void analyze_template(const std::string & template_path) {
{ {
json user_msg = make_user_msg(); json user_msg = make_user_msg();
autoparser::templates_params params_no_prompt; autoparser::generation_params params_no_prompt;
params_no_prompt.messages = json::array({ user_msg }); params_no_prompt.messages = json::array({ user_msg });
params_no_prompt.add_generation_prompt = false; params_no_prompt.add_generation_prompt = false;
params_no_prompt.tools = json::array(); params_no_prompt.tools = json::array();
autoparser::templates_params params_with_prompt = params_no_prompt; autoparser::generation_params params_with_prompt = params_no_prompt;
params_with_prompt.add_generation_prompt = true; params_with_prompt.add_generation_prompt = true;
std::string output_no_prompt = common_chat_template_direct_apply(chat_template, params_no_prompt); std::string output_no_prompt = common_chat_template_direct_apply(chat_template, params_no_prompt);
@ -438,12 +438,12 @@ static void analyze_template(const std::string & template_path) {
{ {
json user_msg = make_user_msg(); json user_msg = make_user_msg();
autoparser::templates_params params_no_reasoning; autoparser::generation_params params_no_reasoning;
params_no_reasoning.messages = json::array({ user_msg, make_assistant_no_reasoning() }); params_no_reasoning.messages = json::array({ user_msg, make_assistant_no_reasoning() });
params_no_reasoning.add_generation_prompt = false; params_no_reasoning.add_generation_prompt = false;
params_no_reasoning.enable_thinking = true; params_no_reasoning.enable_thinking = true;
autoparser::templates_params params_with_reasoning = params_no_reasoning; autoparser::generation_params params_with_reasoning = params_no_reasoning;
params_with_reasoning.messages = json::array({ user_msg, make_assistant_with_reasoning() }); params_with_reasoning.messages = json::array({ user_msg, make_assistant_with_reasoning() });
std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning); std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning);
@ -458,12 +458,12 @@ static void analyze_template(const std::string & template_path) {
json user_msg = make_user_msg(); json user_msg = make_user_msg();
json user_msg2 = make_user_msg2(); json user_msg2 = make_user_msg2();
autoparser::templates_params params_no_reasoning; autoparser::generation_params params_no_reasoning;
params_no_reasoning.messages = json::array({ user_msg, make_assistant_no_reasoning(), user_msg2 }); params_no_reasoning.messages = json::array({ user_msg, make_assistant_no_reasoning(), user_msg2 });
params_no_reasoning.add_generation_prompt = false; params_no_reasoning.add_generation_prompt = false;
params_no_reasoning.enable_thinking = true; params_no_reasoning.enable_thinking = true;
autoparser::templates_params params_with_reasoning = params_no_reasoning; autoparser::generation_params params_with_reasoning = params_no_reasoning;
params_with_reasoning.messages = json::array({ user_msg, make_assistant_with_reasoning(), user_msg2 }); params_with_reasoning.messages = json::array({ user_msg, make_assistant_with_reasoning(), user_msg2 });
std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning); std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning);
@ -477,12 +477,12 @@ static void analyze_template(const std::string & template_path) {
{ {
json user_msg = make_user_msg(); json user_msg = make_user_msg();
autoparser::templates_params params_no_tool; autoparser::generation_params params_no_tool;
params_no_tool.messages = json::array({ user_msg, make_assistant_no_tool() }); params_no_tool.messages = json::array({ user_msg, make_assistant_no_tool() });
params_no_tool.add_generation_prompt = false; params_no_tool.add_generation_prompt = false;
params_no_tool.tools = tools; params_no_tool.tools = tools;
autoparser::templates_params params_with_tool = params_no_tool; autoparser::generation_params params_with_tool = params_no_tool;
params_with_tool.messages = json::array({ user_msg, make_assistant_one_tool() }); params_with_tool.messages = json::array({ user_msg, make_assistant_one_tool() });
std::string output_no_tool = common_chat_template_direct_apply(chat_template, params_no_tool); std::string output_no_tool = common_chat_template_direct_apply(chat_template, params_no_tool);
@ -497,12 +497,12 @@ static void analyze_template(const std::string & template_path) {
json user_msg = make_user_msg(); json user_msg = make_user_msg();
json user_msg2 = make_user_msg2_continue(); json user_msg2 = make_user_msg2_continue();
autoparser::templates_params params_no_tool; autoparser::generation_params params_no_tool;
params_no_tool.messages = json::array({ user_msg, make_assistant_no_tool(), user_msg2 }); params_no_tool.messages = json::array({ user_msg, make_assistant_no_tool(), user_msg2 });
params_no_tool.add_generation_prompt = false; params_no_tool.add_generation_prompt = false;
params_no_tool.tools = tools; params_no_tool.tools = tools;
autoparser::templates_params params_with_tool = params_no_tool; autoparser::generation_params params_with_tool = params_no_tool;
params_with_tool.messages = json::array({ user_msg, make_assistant_one_tool(), user_msg2 }); params_with_tool.messages = json::array({ user_msg, make_assistant_one_tool(), user_msg2 });
std::string output_no_tool = common_chat_template_direct_apply(chat_template, params_no_tool); std::string output_no_tool = common_chat_template_direct_apply(chat_template, params_no_tool);
@ -516,12 +516,12 @@ static void analyze_template(const std::string & template_path) {
{ {
json user_msg = make_user_msg(); json user_msg = make_user_msg();
autoparser::templates_params params_one_tool; autoparser::generation_params params_one_tool;
params_one_tool.messages = json::array({ user_msg, make_assistant_one_tool() }); params_one_tool.messages = json::array({ user_msg, make_assistant_one_tool() });
params_one_tool.add_generation_prompt = false; params_one_tool.add_generation_prompt = false;
params_one_tool.tools = tools; params_one_tool.tools = tools;
autoparser::templates_params params_two_tools = params_one_tool; autoparser::generation_params params_two_tools = params_one_tool;
params_two_tools.messages = json::array({ user_msg, make_assistant_two_tools() }); params_two_tools.messages = json::array({ user_msg, make_assistant_two_tools() });
std::string output_one_tool = common_chat_template_direct_apply(chat_template, params_one_tool); std::string output_one_tool = common_chat_template_direct_apply(chat_template, params_one_tool);
@ -536,12 +536,12 @@ static void analyze_template(const std::string & template_path) {
json user_msg = make_user_msg(); json user_msg = make_user_msg();
json user_msg2 = make_user_msg2_continue(); json user_msg2 = make_user_msg2_continue();
autoparser::templates_params params_one_tool; autoparser::generation_params params_one_tool;
params_one_tool.messages = json::array({ user_msg, make_assistant_one_tool(), user_msg2 }); params_one_tool.messages = json::array({ user_msg, make_assistant_one_tool(), user_msg2 });
params_one_tool.add_generation_prompt = false; params_one_tool.add_generation_prompt = false;
params_one_tool.tools = tools; params_one_tool.tools = tools;
autoparser::templates_params params_two_tools = params_one_tool; autoparser::generation_params params_two_tools = params_one_tool;
params_two_tools.messages = json::array({ user_msg, make_assistant_two_tools(), user_msg2 }); params_two_tools.messages = json::array({ user_msg, make_assistant_two_tools(), user_msg2 });
std::string output_one_tool = common_chat_template_direct_apply(chat_template, params_one_tool); std::string output_one_tool = common_chat_template_direct_apply(chat_template, params_one_tool);
@ -555,13 +555,13 @@ static void analyze_template(const std::string & template_path) {
{ {
json user_msg = make_user_msg(); json user_msg = make_user_msg();
autoparser::templates_params params_no_reasoning; autoparser::generation_params params_no_reasoning;
params_no_reasoning.messages = json::array({ user_msg, make_assistant_one_tool() }); params_no_reasoning.messages = json::array({ user_msg, make_assistant_one_tool() });
params_no_reasoning.add_generation_prompt = false; params_no_reasoning.add_generation_prompt = false;
params_no_reasoning.tools = tools; params_no_reasoning.tools = tools;
params_no_reasoning.enable_thinking = true; params_no_reasoning.enable_thinking = true;
autoparser::templates_params params_with_reasoning = params_no_reasoning; autoparser::generation_params params_with_reasoning = params_no_reasoning;
params_with_reasoning.messages = json::array({ user_msg, make_assistant_one_tool_with_reasoning() }); params_with_reasoning.messages = json::array({ user_msg, make_assistant_one_tool_with_reasoning() });
std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning); std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning);

View File

@ -215,7 +215,9 @@ For the full list of features, please refer to [server's changelog](https://gith
| `--models-autoload, --no-models-autoload` | for router server, whether to automatically load models (default: enabled)<br/>(env: LLAMA_ARG_MODELS_AUTOLOAD) | | `--models-autoload, --no-models-autoload` | for router server, whether to automatically load models (default: enabled)<br/>(env: LLAMA_ARG_MODELS_AUTOLOAD) |
| `--jinja, --no-jinja` | whether to use jinja template engine for chat (default: enabled)<br/>(env: LLAMA_ARG_JINJA) | | `--jinja, --no-jinja` | whether to use jinja template engine for chat (default: enabled)<br/>(env: LLAMA_ARG_JINJA) |
| `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:<br/>- none: leaves thoughts unparsed in `message.content`<br/>- deepseek: puts thoughts in `message.reasoning_content`<br/>- deepseek-legacy: keeps `<think>` tags in `message.content` while also populating `message.reasoning_content`<br/>(default: auto)<br/>(env: LLAMA_ARG_THINK) | | `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:<br/>- none: leaves thoughts unparsed in `message.content`<br/>- deepseek: puts thoughts in `message.reasoning_content`<br/>- deepseek-legacy: keeps `<think>` tags in `message.content` while also populating `message.reasoning_content`<br/>(default: auto)<br/>(env: LLAMA_ARG_THINK) |
| `--reasoning-budget N` | controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)<br/>(env: LLAMA_ARG_THINK_BUDGET) | | `-rea, --resoning [on\|off\|auto]` | Use reasoning/thinking in the chat ('on', 'off', or 'auto', default: 'auto' (detect from template))<br/>(env: LLAMA_ARG_REASONING) |
| `--reasoning-budget N` | token budget for thinking: -1 for unrestricted, 0 for immediate end, N>0 for token budget (default: -1)<br/>(env: LLAMA_ARG_THINK_BUDGET) |
| `--reasoning-budget-message MESSAGE` | message injected before the end-of-thinking tag when reasoning budget is exhausted (default: none)<br/>(env: LLAMA_ARG_THINK_BUDGET_MESSAGE) |
| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted (unless --jinja is set before this flag):<br/>list of built-in templates:<br/>bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone-moe, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, solar-open, vicuna, vicuna-orca, yandex, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE) | | `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted (unless --jinja is set before this flag):<br/>list of built-in templates:<br/>bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone-moe, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, solar-open, vicuna, vicuna-orca, yandex, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE) |
| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted (unless --jinja is set before this flag):<br/>list of built-in templates:<br/>bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone-moe, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, solar-open, vicuna, vicuna-orca, yandex, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) | | `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted (unless --jinja is set before this flag):<br/>list of built-in templates:<br/>bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone-moe, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, solar-open, vicuna, vicuna-orca, yandex, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) |
| `--prefill-assistant, --no-prefill-assistant` | whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)<br/>when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled<br/><br/>(env: LLAMA_ARG_PREFILL_ASSISTANT) | | `--prefill-assistant, --no-prefill-assistant` | whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)<br/>when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled<br/><br/>(env: LLAMA_ARG_PREFILL_ASSISTANT) |
@ -907,7 +909,7 @@ If query param `?fail_on_no_slot=1` is set, this endpoint will respond with stat
"chat_format": "GPT-OSS", "chat_format": "GPT-OSS",
"reasoning_format": "none", "reasoning_format": "none",
"reasoning_in_content": false, "reasoning_in_content": false,
"thinking_forced_open": false, "generation_prompt": "",
"samplers": [ "samplers": [
"penalties", "penalties",
"dry", "dry",
@ -972,7 +974,7 @@ If query param `?fail_on_no_slot=1` is set, this endpoint will respond with stat
"chat_format": "GPT-OSS", "chat_format": "GPT-OSS",
"reasoning_format": "none", "reasoning_format": "none",
"reasoning_in_content": false, "reasoning_in_content": false,
"thinking_forced_open": false, "generation_prompt": "",
"samplers": [ "samplers": [
"penalties", "penalties",
"dry", "dry",
@ -1193,7 +1195,7 @@ The `response_format` parameter supports both plain JSON output (e.g. `{"type":
`reasoning_format`: The reasoning format to be parsed. If set to `none`, it will output the raw generated text. `reasoning_format`: The reasoning format to be parsed. If set to `none`, it will output the raw generated text.
`thinking_forced_open`: Force a reasoning model to always output the reasoning. Only works on certain models. `generation_prompt`: The generation prompt that was prefilled in by the template. Prepended to model output before parsing.
`parse_tool_calls`: Whether to parse the generated tool call. `parse_tool_calls`: Whether to parse the generated tool call.

Binary file not shown.

View File

@ -1081,20 +1081,21 @@ json oaicompat_chat_params_parse(
} }
} }
llama_params["chat_format"] = static_cast<int>(chat_params.format); llama_params["chat_format"] = static_cast<int>(chat_params.format);
llama_params["prompt"] = chat_params.prompt; llama_params["prompt"] = chat_params.prompt;
if (!chat_params.grammar.empty()) { if (!chat_params.grammar.empty()) {
llama_params["grammar"] = chat_params.grammar; llama_params["grammar"] = chat_params.grammar;
llama_params["grammar_type"] = std::string("tool_calls");
} }
llama_params["grammar_lazy"] = chat_params.grammar_lazy; llama_params["grammar_lazy"] = chat_params.grammar_lazy;
auto grammar_triggers = json::array(); auto grammar_triggers = json::array();
for (const auto & trigger : chat_params.grammar_triggers) { for (const auto & trigger : chat_params.grammar_triggers) {
server_grammar_trigger ct(trigger); server_grammar_trigger ct(trigger);
grammar_triggers.push_back(ct.to_json()); grammar_triggers.push_back(ct.to_json());
} }
llama_params["grammar_triggers"] = grammar_triggers; llama_params["grammar_triggers"] = grammar_triggers;
llama_params["preserved_tokens"] = chat_params.preserved_tokens; llama_params["preserved_tokens"] = chat_params.preserved_tokens;
llama_params["thinking_forced_open"] = chat_params.thinking_forced_open; llama_params["generation_prompt"] = chat_params.generation_prompt;
for (const auto & stop : chat_params.additional_stops) { for (const auto & stop : chat_params.additional_stops) {
llama_params["stop"].push_back(stop); llama_params["stop"].push_back(stop);
} }
@ -1114,7 +1115,6 @@ json oaicompat_chat_params_parse(
llama_params["reasoning_budget_start_tag"] = chat_params.thinking_start_tag; llama_params["reasoning_budget_start_tag"] = chat_params.thinking_start_tag;
llama_params["reasoning_budget_end_tag"] = chat_params.thinking_end_tag; llama_params["reasoning_budget_end_tag"] = chat_params.thinking_end_tag;
llama_params["reasoning_budget_message"] = opt.reasoning_budget_message; llama_params["reasoning_budget_message"] = opt.reasoning_budget_message;
llama_params["reasoning_budget_activate_immediately"] = chat_params.thinking_forced_open;
} }
} }

View File

@ -15,6 +15,7 @@
#include <algorithm> #include <algorithm>
#include <cstddef> #include <cstddef>
#include <cinttypes> #include <cinttypes>
#include <exception>
#include <memory> #include <memory>
#include <filesystem> #include <filesystem>
@ -1152,11 +1153,11 @@ private:
// initialize samplers // initialize samplers
if (task.need_sampling()) { if (task.need_sampling()) {
slot.smpl.reset(common_sampler_init(model, task.params.sampling)); try {
slot.smpl.reset(common_sampler_init(model, task.params.sampling));
if (slot.smpl == nullptr) { } catch (std::exception & e) {
// for now, the only error that may happen here is invalid grammar std::string err_msg = std::string("Failed to initialize samplers: ") + e.what();
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); send_error(task, err_msg, ERROR_TYPE_INVALID_REQUEST);
return false; return false;
} }
@ -1431,9 +1432,10 @@ private:
res->tokens = { tkn.tok }; res->tokens = { tkn.tok };
} }
res->n_decoded = slot.n_decoded; res->n_decoded = slot.n_decoded;
res->n_prompt_tokens = slot.task->n_tokens(); res->n_prompt_tokens = slot.task->n_tokens();
res->post_sampling_probs = slot.task->params.post_sampling_probs; res->n_prompt_tokens_cache = slot.n_prompt_tokens_cache;
res->post_sampling_probs = slot.task->params.post_sampling_probs;
res->verbose = slot.task->params.verbose; res->verbose = slot.task->params.verbose;
res->res_type = slot.task->params.res_type; res->res_type = slot.task->params.res_type;
@ -1478,14 +1480,15 @@ private:
res->prompt = slot.task->tokens.detokenize(ctx, true); res->prompt = slot.task->tokens.detokenize(ctx, true);
res->response_fields = std::move(slot.task->params.response_fields); res->response_fields = std::move(slot.task->params.response_fields);
res->truncated = slot.truncated; res->truncated = slot.truncated;
res->n_decoded = slot.n_decoded; res->n_decoded = slot.n_decoded;
res->n_prompt_tokens = slot.task->n_tokens(); res->n_prompt_tokens = slot.task->n_tokens();
res->n_tokens_cached = slot.prompt.n_tokens(); res->n_prompt_tokens_cache = slot.n_prompt_tokens_cache;
res->has_new_line = slot.has_new_line; res->n_tokens_cached = slot.prompt.n_tokens();
res->stopping_word = slot.stopping_word; res->has_new_line = slot.has_new_line;
res->stop = slot.stop; res->stopping_word = slot.stopping_word;
res->post_sampling_probs = slot.task->params.post_sampling_probs; res->stop = slot.stop;
res->post_sampling_probs = slot.task->params.post_sampling_probs;
res->verbose = slot.task->params.verbose; res->verbose = slot.task->params.verbose;
res->stream = slot.task->params.stream; res->stream = slot.task->params.stream;

Some files were not shown because too many files have changed in this diff Show More