Merge branch 'ggml-org:master' into i8mm-ci

This commit is contained in:
Rohanjames1997 2026-03-09 11:35:30 -05:00 committed by GitHub
commit a3ccf2fb9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
236 changed files with 38946 additions and 26801 deletions

View File

@ -93,7 +93,7 @@ jobs:
id: cmake_test
run: |
cd build
ctest -L main --verbose --timeout 900
ctest -L main -E "test-llama-archs" --verbose --timeout 900
macOS-latest-cmake-x64:
runs-on: macos-15-intel

View File

@ -11,6 +11,8 @@
/common/base64.hpp.* @ggerganov
/common/build-info.* @ggerganov
/common/chat.* @pwilkin
/common/chat-auto*.* @pwilkin
/common/chat-diff-analyzer.* @pwilkin
/common/chat-peg-parser.* @aldehir
/common/common.* @ggerganov
/common/console.* @ggerganov
@ -89,12 +91,13 @@
/src/llama-vocab.* @CISC
/src/models/ @CISC
/tests/ @ggerganov
/tests/test-chat-.* @pwilkin
/tests/test-chat.* @pwilkin
/tools/batched-bench/ @ggerganov
/tools/cli/ @ngxson
/tools/completion/ @ggerganov
/tools/mtmd/ @ngxson
/tools/perplexity/ @ggerganov
/tools/parser/ @pwilkin
/tools/quantize/ @ggerganov
/tools/rpc/ @rgerganov
/tools/server/* @ngxson @ggerganov # no subdir

View File

@ -39,6 +39,7 @@ Before submitting your PR:
- For intricate features, consider opening a feature request first to discuss and align expectations
- When adding support for a new model or feature, focus on **CPU support only** in the initial PR unless you have a good reason not to. Add support for other backends like CUDA in follow-up PRs
- Consider allowing write access to your branch for faster reviews, as reviewers can push commits directly
- If you are a new contributor, limit your open PRs to 1.
After submitting your PR:
- Expect requests for modifications to ensure the code meets llama.cpp's standards for quality and long-term maintainability

View File

@ -259,6 +259,8 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [llama-swap](https://github.com/mostlygeek/llama-swap) - transparent proxy that adds automatic model switching with llama-server
- [Kalavai](https://github.com/kalavai-net/kalavai-client) - Crowdsource end to end LLM deployment at any scale
- [llmaz](https://github.com/InftyAI/llmaz) - ☸️ Easy, advanced inference platform for large language models on Kubernetes.
- [LLMKube](https://github.com/defilantech/llmkube) - Kubernetes operator for llama.cpp with multi-GPU and Apple Silicon Metal support
</details>
<details>

View File

@ -47,10 +47,10 @@ add_library(${TARGET} STATIC
arg.cpp
arg.h
base64.hpp
chat-parser.cpp
chat-parser.h
chat-parser-xml-toolcall.h
chat-parser-xml-toolcall.cpp
chat-auto-parser-generator.cpp
chat-auto-parser-helpers.cpp
chat-auto-parser.h
chat-diff-analyzer.cpp
chat-peg-parser.cpp
chat-peg-parser.h
chat.cpp

View File

@ -2666,7 +2666,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
params.out_file = value;
}
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_FINETUNE}));
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_FINETUNE, LLAMA_EXAMPLE_RESULTS}));
add_opt(common_arg(
{"-ofreq", "--output-frequency"}, "N",
string_format("output the imatrix every N iterations (default: %d)", params.n_out_freq),
@ -3607,6 +3607,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg(
{"--check"},
string_format("check rather than generate results (default: %s)", params.check ? "true" : "false"),
[](common_params & params) {
params.check = true;
}
).set_examples({LLAMA_EXAMPLE_RESULTS}));
add_opt(common_arg(
{"--save-logits"},
string_format("save final logits to files for verification (default: %s)", params.save_logits ? "true" : "false"),

View File

@ -0,0 +1,448 @@
#include "chat-auto-parser.h"
#include "chat-peg-parser.h"
#include "chat.h"
#include "common.h"
#include "json-schema-to-grammar.h"
#include "nlohmann/json.hpp"
#include <stdexcept>
#include <string>
using json = nlohmann::ordered_json;
// Invoke `fn` on every entry of `tools` that is a function tool, i.e. an entry that
// has "type" == "function" and also carries a "function" member. Other entries are ignored.
static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
    for (const auto & entry : tools) {
        const bool is_function_tool =
            entry.contains("type") && entry.at("type") == "function" && entry.contains("function");
        if (is_function_tool) {
            fn(entry);
        }
    }
}
namespace autoparser {
// Bind the PEG builder and the template inputs into a build context.
// reasoning_parser starts out as the empty (epsilon) parser obtained from p.eps().
parser_build_context::parser_build_context(common_chat_peg_builder & p, const templates_params & inputs) :
p(p),
inputs(inputs),
reasoning_parser(p.eps()) {}
// Convenience overload: analyze `tmpl` from scratch, then delegate to the overload
// that takes an already-analyzed autoparser.
common_chat_params peg_generator::generate_parser(const common_chat_template &    tmpl,
                                                  const struct templates_params & inputs) {
    struct autoparser analysis;
    analysis.analyze_template(tmpl);  // differential analysis of the template structure
    return generate_parser(tmpl, inputs, analysis);
}
// Produce the chat params (prompt, serialized parser, optional grammar and triggers)
// for `tmpl` using the results of a completed template analysis.
//
// A grammar is emitted when a JSON response format is requested, or when tools are usable and
// tool_choice is REQUIRED, or AUTO with a non-empty trigger marker. The grammar is lazy only
// for tool_choice AUTO without a response format; in that case the trigger marker
// (section start, falling back to per-call start) is registered as a word trigger.
common_chat_params peg_generator::generate_parser(const common_chat_template &    tmpl,
                                                  const struct templates_params & inputs,
                                                  const autoparser &              autoparser) {
    auto parser = autoparser.build_parser(inputs);

    common_chat_params data;
    data.prompt           = common_chat_template_direct_apply(tmpl, inputs);
    data.format           = COMMON_CHAT_FORMAT_PEG_NATIVE;
    data.preserved_tokens = autoparser.preserved_tokens;
    data.parser           = parser.save();

    const auto & tool_fmt = autoparser.tools.format;

    const bool tools_usable =
        tool_fmt.mode != tool_format::NONE && inputs.tools.is_array() && !inputs.tools.empty();
    const bool wants_response_format = !inputs.json_schema.empty() && inputs.json_schema.is_object();
    const bool choice_is_auto        = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
    const bool choice_is_required    = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;

    // Prefer the section-level marker; fall back to the per-call marker
    const std::string trigger_marker = tool_fmt.section_start.empty() ? tool_fmt.per_call_start : tool_fmt.section_start;

    const bool tools_need_grammar =
        tools_usable && ((choice_is_auto && !trigger_marker.empty()) || choice_is_required);

    if (!wants_response_format && !tools_need_grammar) {
        return data;
    }

    data.grammar_lazy = !wants_response_format && choice_is_auto;
    data.grammar      = build_grammar([&](const common_grammar_builder & builder) {
        foreach_function(inputs.tools, [&](const json & tool) {
            const auto & function = tool.at("function");
            auto         schema   = function.at("parameters");  // copied before being handed to resolve_refs
            builder.resolve_refs(schema);
        });
        parser.build_grammar(builder, data.grammar_lazy);
    });

    if (data.grammar_lazy) {
        data.grammar_triggers = {
            { COMMON_GRAMMAR_TRIGGER_TYPE_WORD, trigger_marker }
        };
    }
    return data;
}
// Build the PEG parser arena for one request from the previously analyzed template.
//
// Parser selection, in priority order:
//   1. inputs.json_schema is a non-empty object -> reasoning + schema-constrained content,
//      accepted either bare or wrapped in a ```json ... ``` fence
//   2. tools are supplied, tool_choice is not NONE, the template supports tool calls and a
//      tool format was detected -> the tool-call parser
//   3. otherwise -> the plain content parser
//
// Throws std::invalid_argument when analyze_template(...) has not been run yet.
common_peg_arena autoparser::build_parser(const templates_params & inputs) const {
    if (!analysis_complete) {
        throw std::invalid_argument("Cannot call build_parser on autoparser without performing analysis first, call analyze_template(...)");
    }
    return build_chat_peg_parser([&](common_chat_peg_builder & p) {
        // If the template uses Python dict format (single-quoted strings in JSON structures),
        // pre-register a json-string rule that accepts both quote styles. This must happen
        // before any call to p.json() so that all JSON parsing inherits the flexible rule.
        if (tools.format.uses_python_dicts) {
            p.rule("json-string", [&]() { return p.choice({ p.double_quoted_string(), p.single_quoted_string() }); });
        }

        parser_build_context ctx(p, inputs);
        bool extract_reasoning   = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
        bool enable_thinking     = inputs.enable_thinking;
        ctx.extracting_reasoning = extract_reasoning && enable_thinking && reasoning.mode != reasoning_mode::NONE;
        ctx.content              = &content;

        // Build reasoning parser
        ctx.reasoning_parser = reasoning.build_parser(ctx);

        bool has_tools           = inputs.tools.is_array() && !inputs.tools.empty();
        bool has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();

        if (has_response_format) {
            auto response_format = p.rule("response-format", p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
            return ctx.reasoning_parser + p.space() + p.choice({
                p.literal("```json") + p.space() + response_format + p.space() + p.literal("```"),
                response_format
            }) + p.end();
        }

        // Only route to the tool parser when a tool format was actually detected.
        // analyze_tools::build_parser() aborts the process for any other mode, and
        // peg_generator::generate_parser() already treats mode NONE as "no tools",
        // so without this guard the two would disagree and a request carrying tools
        // could abort instead of falling back to plain content parsing.
        bool tool_format_known = tools.format.mode != tool_format::NONE;

        if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls &&
            tool_format_known) {
            return tools.build_parser(ctx);
        }
        return content.build_parser(ctx);
    });
}
// Build the parser for the reasoning ("thinking") part of the output.
// Yields the empty parser when reasoning is not being extracted or when the mode
// has no usable markers.
common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) const {
    auto & p = ctx.p;

    if (!ctx.extracting_reasoning) {
        return p.eps();
    }

    switch (mode) {
        case reasoning_mode::FORCED_OPEN:
        case reasoning_mode::FORCED_CLOSED:
            // The opening tag was already emitted by the template (forced open, or forced
            // closed with thinking enabled), so only the closing tag is expected here.
            return p.reasoning(p.until(end)) + end;

        case reasoning_mode::TAG_BASED:
        case reasoning_mode::TOOLS_ONLY:
            // Tag-delimited reasoning; usable only when both markers are known.
            if (start.empty() || end.empty()) {
                break;
            }
            return p.optional(start + p.reasoning(p.until(end)) + end);

        case reasoning_mode::DELIMITER:
            return p.optional(p.reasoning(p.until(end)) + end);

        default:
            break;
    }
    return p.eps();
}
// Build the parser for plain (tool-less) output: optional reasoning followed by content.
// When content is always wrapped in start/end markers, the markers are matched explicitly.
common_peg_parser analyze_content::build_parser(parser_build_context & ctx) const {
    auto & p = ctx.p;

    if (!is_always_wrapped()) {
        // Unwrapped content: everything after the reasoning block is content.
        return ctx.reasoning_parser + p.content(p.rest()) + p.end();
    }
    if (!ctx.extracting_reasoning) {
        // Not extracting reasoning: whatever precedes the start marker is also reported as content.
        return p.content(p.until(start)) + start + p.content(p.until(end)) + end + p.end();
    }
    return ctx.reasoning_parser + start + p.content(p.until(end)) + end + p.end();
}
// Parser for an optional start/end-wrapped content block; the empty parser when
// content is not always wrapped.
common_peg_parser analyze_content::build_optional_wrapped(parser_build_context & ctx) const {
    auto & p = ctx.p;
    if (!is_always_wrapped()) {
        return p.eps();
    }
    return p.optional(start + p.content(p.until(end)) + end);
}
// Dispatch to the tool-call parser builder matching the detected tool format.
// Aborts for any mode that has no builder.
common_peg_parser analyze_tools::build_parser(parser_build_context & ctx) const {
    const auto detected = format.mode;
    if (detected == tool_format::JSON_NATIVE) {
        return build_tool_parser_json_native(ctx);
    }
    if (detected == tool_format::TAG_WITH_JSON) {
        return build_tool_parser_tag_json(ctx);
    }
    if (detected == tool_format::TAG_WITH_TAGGED) {
        return build_tool_parser_tag_tagged(ctx);
    }
    GGML_ABORT("Unable to create tool parser");
}
// Build the parser for templates whose tool calls are plain JSON objects.
// Layout: reasoning, then optional content (or the wrapped content block), then the tool calls.
common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_context & ctx) const {
    auto &       p      = ctx.p;
    const auto & inputs = ctx.inputs;

    const bool tools_required = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;

    // Effective field names: prefixed with "<function_field>." when a non-default
    // function_field is set and the name field is not already dotted.
    std::string effective_name = format.name_field;
    std::string effective_args = format.args_field;
    const bool  needs_prefix   = !format.function_field.empty() && format.function_field != "function" &&
                              effective_name.find('.') == std::string::npos;
    if (needs_prefix) {
        effective_name = format.function_field + "." + effective_name;
        effective_args = format.function_field + "." + effective_args;
    }

    auto tools_parser = p.standard_json_tools(
        format.section_start, format.section_end, inputs.tools, inputs.parallel_tool_calls,
        tools_required, effective_name, effective_args, format.tools_array_wrapped,
        format.fun_name_is_key, format.id_field, format.gen_id_field, format.parameter_order);

    // Content that is always wrapped in markers gets its own optional block before the tools
    if (ctx.content && ctx.content->is_always_wrapped()) {
        auto wrapped_content = ctx.content->build_optional_wrapped(ctx);
        return ctx.reasoning_parser + wrapped_content + tools_parser + p.end();
    }

    // Free-form content runs until the first tool marker: section start, else per-call start, else "{"
    std::string tool_start;
    if (!format.section_start.empty()) {
        tool_start = format.section_start;
    } else if (!format.per_call_start.empty()) {
        tool_start = format.per_call_start;
    } else {
        tool_start = "{";
    }

    return ctx.reasoning_parser + (tools_required ? p.eps() : p.optional(p.content(p.until(tool_start)))) + tools_parser +
           p.end();
}
// Build the parser for templates where each call is a tagged function name followed by
// a JSON arguments object. Produces: reasoning, optional leading content, tool calls, end.
common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context & ctx) const {
auto & p = ctx.p;
const auto & inputs = ctx.inputs;
bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
// One alternative per declared function is OR-ed into this choice
common_peg_parser tool_choice = p.choice();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & func = tool.at("function");
std::string name = func.at("name");
const auto & schema = func.at("parameters");
// Build call_id parser based on position (if supported)
common_peg_parser call_id_section = p.eps();
if (call_id.pos == call_id_position::BETWEEN_FUNC_AND_ARGS && !call_id.prefix.empty() &&
!call_id.suffix.empty()) {
// NOTE(review): call_id.suffix sits outside p.optional(...) here, so it is always required,
// whereas build_tool_parser_tag_tagged places it inside the optional - confirm this is intended
call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix))) + call_id.suffix;
}
// <name_prefix>NAME<name_suffix> [call id] JSON-arguments (validated against the function's schema)
auto func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
call_id_section + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema));
if (!function.close.empty()) {
func_parser = func_parser + function.close;
}
tool_choice |= p.rule("tool-" + name, func_parser);
});
auto require_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
common_peg_parser tool_calls = p.eps();
if (!format.per_call_start.empty()) {
// Every call carries its own start/end markers; parallel calls are separated by optional whitespace
auto wrapped_call = format.per_call_start + tool_choice + format.per_call_end;
if (inputs.parallel_tool_calls) {
tool_calls = p.trigger_rule("tool-call", wrapped_call + p.zero_or_more(p.space() + wrapped_call));
} else {
tool_calls = p.trigger_rule("tool-call", wrapped_call);
}
if (!format.section_start.empty()) {
// The calls are additionally enclosed in a section; without a section end marker, end of input closes it
tool_calls = p.trigger_rule("tool-calls",
p.literal(format.section_start) + p.space() + tool_calls + p.space() +
(format.section_end.empty() ? p.end() : p.literal(format.section_end)));
}
} else {
// No per-call markers: calls sit directly inside the section, parallel calls separated by ", "
std::string separator = ", "; // Default
if (inputs.parallel_tool_calls) {
tool_calls = p.trigger_rule("tool-call", format.section_start + tool_choice +
p.zero_or_more(separator + tool_choice) + format.section_end);
} else {
tool_calls = p.trigger_rule("tool-call", format.section_start + tool_choice + format.section_end);
}
}
if (!require_calls) {
tool_calls = p.optional(tool_calls);
}
// Leading content runs up to the first tool marker (section start, else per-call start)
std::string trigger_marker = !format.section_start.empty() ? format.section_start : format.per_call_start;
auto content_before_tools = trigger_marker.empty() ? p.eps() : p.until(trigger_marker);
return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(content_before_tools))) + tool_calls +
p.end();
}
// Build the parser for templates where both the function name and every argument are
// expressed with tags (no JSON object for the arguments as a whole).
// Produces: reasoning, optional leading content, tool calls, end.
// Functions whose "parameters" schema has no object-valued "properties" get no alternative.
common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_context & ctx) const {
auto & p = ctx.p;
const auto & inputs = ctx.inputs;
bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
// One alternative per declared function is OR-ed into this choice
common_peg_parser tool_choice = p.choice();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & func = tool.at("function");
std::string name = func.at("name");
const auto & params = func.at("parameters");
if (!params.contains("properties") || !params.at("properties").is_object()) {
return;
}
const auto & properties = params.at("properties");
std::set<std::string> required;
if (params.contains("required") && params.at("required").is_array()) {
params.at("required").get_to(required);
}
// Build parser for each argument, separating required and optional
std::vector<common_peg_parser> required_parsers;
std::vector<common_peg_parser> optional_parsers;
for (const auto & [param_name, param_schema] : properties.items()) {
bool is_required = required.find(param_name) != required.end();
// Resolve the parameter type: a string "type", or a nested object with a string "type";
// anything else keeps the default "object"
std::string type = "object";
auto type_obj = param_schema.contains("type") ? param_schema.at("type") : json::object();
if (type_obj.is_string()) {
type_obj.get_to(type);
} else if (type_obj.is_object()) {
if (type_obj.contains("type") && type_obj.at("type").is_string()) {
type_obj.at("type").get_to(type);
}
}
// String arguments are taken raw up to the value suffix; every other type is parsed as JSON
auto arg = p.tool_arg(
p.tool_arg_open(arguments.name_prefix + p.tool_arg_name(p.literal(param_name)) +
arguments.name_suffix) +
arguments.value_prefix +
(type == "string" ? p.tool_arg_string_value(p.schema(p.until(arguments.value_suffix),
"tool-" + name + "-arg-" + param_name + "-schema",
param_schema, true)) :
p.tool_arg_json_value(p.schema(
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, format.uses_python_dicts)) +
p.space()) +
p.tool_arg_close(p.literal(arguments.value_suffix)));
auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg);
if (is_required) {
required_parsers.push_back(named_arg);
} else {
optional_parsers.push_back(named_arg);
}
}
// Build required arg sequence in definition order
common_peg_parser args_seq = p.eps();
for (size_t i = 0; i < required_parsers.size(); i++) {
if (i > 0) {
args_seq = args_seq + p.space();
}
args_seq = args_seq + required_parsers[i];
}
// Build optional args with flexible ordering
if (!optional_parsers.empty()) {
common_peg_parser any_opt = p.choice();
for (const auto & opt : optional_parsers) {
any_opt |= opt;
}
// Any optional argument may appear, up to as many times as there are optional arguments
args_seq = args_seq + p.repeat(p.space() + any_opt, 0, (int) optional_parsers.size());
}
// Build call_id parser based on position (if supported)
common_peg_parser call_id_section = p.eps();
bool have_call_id = false;
if (call_id.pos == call_id_position::BETWEEN_FUNC_AND_ARGS && !call_id.prefix.empty() &&
!call_id.suffix.empty()) {
have_call_id = true;
call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix)) + call_id.suffix);
}
// matched_atomic records whether atomicity has been dealt with; when it stays false the
// whole function parser is wrapped in p.atomic(...) at the end
bool matched_atomic = false;
common_peg_parser func_parser = p.eps();
if (!function.name_suffix.empty()) {
// NOTE(review): this branch sets matched_atomic without wrapping anything in p.atomic(...);
// presumably a non-empty name suffix already delimits the name - confirm
func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
call_id_section + p.space() + args_seq;
matched_atomic = true;
} else if (have_call_id) {
// Tool open plus call id form the atomic unit
func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
call_id_section) + p.space() + args_seq;
matched_atomic = true;
} else if (!arguments.name_prefix.empty() && properties.size() > 0) {
// Tool open is atomic together with a lookahead for the first argument prefix
func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
call_id_section + p.space() + p.peek(p.literal(arguments.name_prefix))) + args_seq;
matched_atomic = true;
} else {
func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
call_id_section + p.space() + args_seq;
}
if (!function.close.empty()) {
func_parser = func_parser + p.space() + p.tool_close(p.literal(function.close));
} else if (!format.per_call_end.empty()) {
// When there's no func_close but there is a per_call_end marker, use peek() to ensure
// we only emit tool_close when we can actually see the closing marker. This prevents
// premature closing during partial parsing when we've seen e.g. "</" which could be
// either "</tool_call>" (end) or "<arg_key>" prefix that failed to match.
func_parser = func_parser + p.tool_close(p.peek(p.literal(format.per_call_end)));
} else {
func_parser =
func_parser + p.tool_close(p.space()); // force this to process tool closing callbacks in mapper
}
if (!matched_atomic) {
func_parser = p.atomic(func_parser);
}
tool_choice |= p.rule("tool-" + name, func_parser);
});
auto require_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
common_peg_parser tool_calls = p.eps();
if (!format.per_call_start.empty()) {
// Every call carries its own start/end markers; parallel calls are separated by optional whitespace
auto wrapped_call = format.per_call_start + p.space() + tool_choice + p.space() + format.per_call_end;
if (inputs.parallel_tool_calls) {
tool_calls = p.trigger_rule("tool-call", wrapped_call + p.zero_or_more(p.space() + wrapped_call));
} else {
tool_calls = p.trigger_rule("tool-call", wrapped_call);
}
if (!format.section_start.empty()) {
// The calls are additionally enclosed in a section; without a section end marker, end of input closes it
tool_calls = p.trigger_rule("tool-calls",
p.literal(format.section_start) + p.space() + tool_calls + p.space() +
(format.section_end.empty() ? p.end() : p.literal(format.section_end)));
}
} else {
// No per-call markers: calls sit directly inside the section, parallel calls separated by ", "
std::string separator = ", "; // Default
if (inputs.parallel_tool_calls) {
tool_calls = p.trigger_rule("tool-call", format.section_start + p.space() + tool_choice +
p.zero_or_more(separator + tool_choice) + p.space() +
format.section_end);
} else {
tool_calls = p.trigger_rule(
"tool-call", format.section_start + p.space() + tool_choice + p.space() + format.section_end);
}
}
if (!require_tools) {
tool_calls = p.optional(tool_calls);
}
// Leading content runs up to the first tool marker (section start, else per-call start)
std::string trigger_marker = !format.section_start.empty() ? format.section_start : format.per_call_start;
auto content_before_tools = trigger_marker.empty() ? p.eps() : p.until(trigger_marker);
return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(content_before_tools))) + tool_calls +
p.end();
}
} // namespace autoparser

View File

@ -0,0 +1,347 @@
#include "chat-auto-parser-helpers.h"
#include "chat-auto-parser.h"
#include "chat.h"
#include "log.h"
#include "nlohmann/json.hpp"
#include <cctype>
#include <numeric>
using json = nlohmann::ordered_json;
// Return a copy of `str` with leading and trailing whitespace (per std::isspace) removed.
std::string trim_whitespace(const std::string & str) {
    const auto is_ws = [](char c) { return std::isspace(static_cast<unsigned char>(c)) != 0; };

    const size_t len   = str.size();
    size_t       first = 0;
    while (first < len && is_ws(str[first])) {
        ++first;
    }

    // one-past-last kept character; never moves below `first`
    size_t stop = len;
    while (stop > first && is_ws(str[stop - 1])) {
        --stop;
    }
    return str.substr(first, stop - first);
}
// Return a copy of `str` without its leading whitespace (per std::isspace).
std::string trim_leading_whitespace(const std::string & str) {
    for (size_t i = 0; i < str.size(); ++i) {
        if (!std::isspace(static_cast<unsigned char>(str[i]))) {
            return str.substr(i);
        }
    }
    // empty or whitespace-only input
    return std::string();
}
// Return a copy of `str` without its trailing whitespace (per std::isspace).
std::string trim_trailing_whitespace(const std::string & str) {
    // `keep` is the number of leading characters to retain
    size_t keep = str.size();
    while (keep > 0 && std::isspace(static_cast<unsigned char>(str[keep - 1]))) {
        --keep;
    }
    return str.substr(0, keep);
}
// Return a copy of `str` with every trailing '\n' removed (other whitespace is kept).
std::string trim_trailing_newlines(const std::string & str) {
    const size_t last_kept = str.find_last_not_of('\n');
    if (last_kept == std::string::npos) {
        // empty, or nothing but newlines
        return std::string();
    }
    return str.substr(0, last_kept + 1);
}
// Length of the longest common prefix of `left` and `right`.
static size_t common_prefix_len(const std::string & left, const std::string & right) {
    const size_t limit   = std::min(left.length(), right.length());
    size_t       matched = 0;
    for (; matched < limit; ++matched) {
        if (left[matched] != right[matched]) {
            break;
        }
    }
    return matched;
}
// Length of the longest common suffix of `left` and `right`.
static size_t common_suffix_len(const std::string & left, const std::string & right) {
    size_t li      = left.length();   // one past the next character to compare in `left`
    size_t ri      = right.length();  // one past the next character to compare in `right`
    size_t matched = 0;
    while (li > 0 && ri > 0 && left[li - 1] == right[ri - 1]) {
        --li;
        --ri;
        ++matched;
    }
    return matched;
}
// Split `left` and `right` into a shared prefix, a shared suffix and the differing middle parts.
// Matching is done first on whole segments produced by segmentize_markers(), then character-wise
// on the leftover text, but only when the boundary segments on both sides are TEXT.
// If one input yields no segments, the other input is returned whole as its own side of the diff.
// If no difference is found, the result is { prefix = left, suffix = "", left = "", right = "" }.
diff_split calculate_diff_split(const std::string & left, const std::string & right) {
diff_split result;
auto left_seg = segmentize_markers(left);
auto right_seg = segmentize_markers(right);
if (left_seg.empty()) {
result.right = right;
return result;
}
if (right_seg.empty()) {
result.left = left;
return result;
}
// *_start / *_end are inclusive cursors over the not-yet-matched segments of each side
auto left_start = left_seg.begin();
auto left_end = --left_seg.end();
auto right_start = right_seg.begin();
auto right_end = --right_seg.end();
// Keep looping while both sides still have more than one unmatched segment
auto test = [&] () {
return left_start != left_end && right_start != right_end;
};
// Set when the segment under the (inclusive) end cursor has already been matched,
// so that it is left out of the remainder computed below
bool left_fully_consumed = false;
bool right_fully_consumed = false;
while (test()) {
bool advanced = false;
// Consume one equal segment from the front ...
if (*left_start == *right_start) {
result.prefix.append(left_start->value);
left_start++;
right_start++;
advanced = true;
}
// ... and one equal segment from the back; an end cursor never moves past its start cursor
if (*left_end == *right_end) {
result.suffix = left_end->value + result.suffix;
if (left_start != left_end) {
left_end--;
} else {
left_fully_consumed = true;
}
if (right_start != right_end) {
right_end--;
} else {
right_fully_consumed = true;
}
advanced = true;
}
if (!advanced) {
break;
}
}
// One side is down to a single segment: try to match it against either end of the other side
if (left_start == left_end && right_start != right_end) {
if (*left_start == *right_end) {
result.suffix = right_end->value + result.suffix;
right_end--;
left_fully_consumed = true;
} else if (*left_start == *right_start) {
result.prefix.append(right_start->value);
right_start++;
left_fully_consumed = true;
}
} else if (right_start == right_end && left_start != left_end) {
if (*left_end == *right_start) {
result.suffix = left_end->value + result.suffix;
left_end--;
right_fully_consumed = true;
} else if (*left_start == *right_start) {
result.prefix.append(left_start->value);
left_start++;
right_fully_consumed = true;
}
} else if (left_start == left_end && right_start == right_end && *left_start == *right_start && left_start->type == segment_type::MARKER) {
// Both sides are down to the same single marker: it belongs to the prefix
result.prefix.append(right_start->value);
left_fully_consumed = true;
right_fully_consumed = true;
}
// Concatenates segment values back into a string
auto eat_segment = [](std::string str, const segment & seg) -> std::string { return std::move(str) + seg.value; };
// Character-level prefix/suffix matching is only attempted between TEXT segments
bool can_have_text_suffix = left_end->type == segment_type::TEXT && right_end->type == segment_type::TEXT;
bool can_have_text_prefix = right_start->type == segment_type::TEXT && left_start->type == segment_type::TEXT;
// Remainders are the unmatched segments; the end cursor is advanced to form an exclusive bound
// unless its segment was already consumed. The type checks above must run before these lines
// because ++left_end / ++right_end modify the cursors.
std::string remainder_left = std::accumulate(left_start, left_fully_consumed ? left_end : ++left_end, std::string(), eat_segment);
std::string remainder_right = std::accumulate(right_start, right_fully_consumed ? right_end : ++right_end, std::string(), eat_segment);
size_t suffix_len = can_have_text_suffix ? common_suffix_len(remainder_left, remainder_right) : 0;
// avoid overlaps between prefix and suffix
size_t prefix_len = can_have_text_prefix ? common_prefix_len(remainder_left.substr(0, remainder_left.size() - suffix_len),
remainder_right.substr(0, remainder_right.size() - suffix_len)) : 0;
result.prefix.append(remainder_left.substr(0, prefix_len));
result.suffix = remainder_left.substr(remainder_left.length() - suffix_len, suffix_len) + result.suffix;
result.left = remainder_left.substr(prefix_len, remainder_left.length() - prefix_len - suffix_len);
result.right = remainder_right.substr(prefix_len, remainder_right.length() - prefix_len - suffix_len);
if (result.left == "" && result.right == "") {
// degenerate case, no diff
result.prefix = left;
result.suffix = "";
// pick prefix = all as representation
}
return result;
}
// Returns the part of `full` that precedes the first occurrence of the common prefix of
// `left` and `right`. Returns an empty string when `left` and `right` share no prefix,
// or when their common prefix does not occur in `full`.
std::string until_common_prefix(const std::string & full, const std::string & left, const std::string & right) {
    // Measure the shared prefix of left/right
    const size_t limit  = std::min(left.length(), right.length());
    size_t       shared = 0;
    while (shared < limit && left[shared] == right[shared]) {
        ++shared;
    }
    if (shared == 0) {
        return "";
    }

    // Locate that prefix inside `full` and keep what comes before it
    const size_t where = full.find(left.substr(0, shared));
    if (where == std::string::npos) {
        return "";
    }
    return full.substr(0, where);
}
// Returns the part of `full` that follows the last occurrence of the common suffix of
// `left` and `right`. Returns an empty string when `left` and `right` share no suffix,
// or when their common suffix does not occur in `full`.
std::string after_common_suffix(const std::string & full, const std::string & left, const std::string & right) {
    // Measure the shared suffix of left/right, walking backwards from both ends
    size_t li     = left.length();
    size_t ri     = right.length();
    size_t shared = 0;
    while (li > 0 && ri > 0 && left[li - 1] == right[ri - 1]) {
        --li;
        --ri;
        ++shared;
    }
    if (shared == 0) {
        return "";
    }

    // Locate the last occurrence of that suffix in `full` and keep what comes after it
    const size_t where = full.rfind(left.substr(left.length() - shared));
    if (where == std::string::npos) {
        return "";
    }
    return full.substr(where + shared);
}
// TODO: segmentize will treat a JSON array inside tags as a tag: <calls>[{ "fun": { ... } }]</calls> will be three markers
// not too worried about that because it hasn't turned out as a problem anywhere, but noting here in case it will
// Might have to put some restrictions on tag contents as well (like "no { }")
std::vector<segment> segmentize_markers(const std::string & text) {
std::vector<segment> retval;
bool in_marker = false;
char marker_opener = '\0';
auto is_marker_opener = [](char c) -> bool { return c == '<' || c == '['; };
auto is_marker_closer = [](char op, char c) -> bool { return (op == '<' && c == '>') || (op == '[' && c == ']'); };
size_t last_border = 0;
for (size_t cur_pos = 0; cur_pos < text.length(); cur_pos++) {
if (!in_marker && is_marker_opener(text[cur_pos])) {
if (last_border < cur_pos) {
retval.push_back(segment(segment_type::TEXT, text.substr(last_border, cur_pos - last_border)));
}
last_border = cur_pos;
in_marker = true;
marker_opener = text[cur_pos];
} else if (in_marker && is_marker_closer(marker_opener, text[cur_pos])) {
// no need to check because last_border will always be smaller
retval.push_back(segment(segment_type::MARKER, text.substr(last_border, cur_pos - last_border + 1)));
last_border = cur_pos + 1;
in_marker = false;
marker_opener = '\0';
}
}
if (last_border < text.length()) {
retval.push_back(segment(segment_type::TEXT, text.substr(last_border)));
}
return retval;
}
std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segments) {
std::vector<segment> result;
for (const auto & seg : segments) {
if (!trim_whitespace(seg.value).empty()) {
result.push_back(seg);
}
}
return result;
}
namespace autoparser {
std::string apply_template(const common_chat_template & tmpl, const template_params & params) {
templates_params tmpl_params;
tmpl_params.messages = params.messages;
tmpl_params.tools = params.tools;
tmpl_params.add_generation_prompt = params.add_generation_prompt;
tmpl_params.enable_thinking = params.enable_thinking;
if (params.extra_context) {
tmpl_params.extra_context = *params.extra_context;
}
tmpl_params.extra_context["enable_thinking"] = params.enable_thinking;
try {
return common_chat_template_direct_apply(tmpl, tmpl_params);
} catch (const std::exception & e) {
LOG_DBG("Template application failed: %s\n", e.what());
return "";
}
}
// Render the template for variant A (`params_A`) and for variant B (a copy of A passed through
// `params_modifier`, when one is given), then diff the two outputs.
// Returns std::nullopt when either rendering comes back empty.
std::optional<compare_variants_result> compare_variants(
    const common_chat_template &                   tmpl,
    const template_params &                        params_A,
    const std::function<void(template_params &)> & params_modifier) {
    // Variant B starts as a copy of A and is then adjusted by the caller
    template_params params_B(params_A);
    if (params_modifier) {
        params_modifier(params_B);
    }

    const std::string output_A = apply_template(tmpl, params_A);
    const std::string output_B = apply_template(tmpl, params_B);

    // apply_template reports failure as an empty string
    const bool any_failed = output_A.empty() || output_B.empty();
    if (any_failed) {
        return std::nullopt;
    }

    compare_variants_result result;
    result.output_A = output_A;
    result.output_B = output_B;
    result.diff     = calculate_diff_split(output_A, output_B);
    return result;
}
} // namespace autoparser

View File

@ -0,0 +1,73 @@
#pragma once
#include "chat-auto-parser.h"
#include <functional>
#include <optional>
#include <string>
std::string trim_whitespace(const std::string & str);
std::string trim_leading_whitespace(const std::string & str);
std::string trim_trailing_whitespace(const std::string & str);
std::string trim_trailing_newlines(const std::string & str);
// calculate a diff split (longest common prefix, longest common suffix excluding prefix,
// mismatched part on the left, mismatched part on the right) between two strings
// account for markers - align prefix and suffix endings so that they end on markers
// * eg.:
// calculate_diff_split("<html><body><div></div></body></html>", "<html><body><p>Something</p></body></html>") ->
// { "prefix": "<html><body>" (not: "<html><body><"), "suffix": "</body></html>", "left": "<div></div>", "right": "<p>Something</p>" }
// calculate_diff_split("<html><body>Something</body></html>", "<html><body></body></html>") ->
// { "prefix": "<html><body>", "suffix": "</body></html>", "left": "Something", "right": "" }
diff_split calculate_diff_split(const std::string & left, const std::string & right);
// Returns the prefix of `full` up until the first occurrence of the common prefix of `left` and `right`
// Returns empty string if there's no common prefix
// * eg.:
// until_common_prefix("really want a FUNCTION call", "FUNCTION alpha", "FUNCTION beta") -> "really want a "
// until_common_prefix("<tool_call>", "<something>", "<something_else>") -> ""
// until_common_prefix("some text", "1234", "abcd") -> ""
// until_common_prefix("one arg two args three args four", "argument alpha", "argument beta") -> "one "
// NOTE(review): in the last example the raw common prefix ("argument ") does not occur in `full`, yet the
// documented result is "one " - presumably the search is not a plain longest-common-prefix lookup; confirm
// against the implementation.
std::string until_common_prefix(const std::string & full, const std::string & left, const std::string & right);
// Returns the suffix of `full` after the last occurrence of the common suffix of `left` and `right`
// Returns empty string if there's no common suffix
// Mirror function of `until_common_prefix`
// * eg.:
// after_common_suffix("really want a FUNCTION call", "first FUNCTION", "second FUNCTION") -> " call"
// after_common_suffix("one arg two-args three args four", "alpha-args", "beta-args") -> " three args four"
// NOTE(review): in the second example the raw common suffix ("a-args") does not occur in `full`, yet the
// documented result is " three args four" - presumably the search is not a plain longest-common-suffix
// lookup; confirm against the implementation.
std::string after_common_suffix(const std::string & full, const std::string & left, const std::string & right);
// Segmentize text into markers and non-marker fragments
// The result is returned by value, in the order the segments appear in `text` (as the examples show)
// * eg.:
// segmentize_markers("<html><head><title>The site title</title><body><div>Here's some <b>content</b></div></body></html>") ->
// [ (MARKER, "<html>"), (MARKER, "<head>"), (MARKER, "<title>"), (TEXT, "The site title"), (MARKER, "</title>"),
// (MARKER, "<body>"), (MARKER, "<div>"), (TEXT, "Here's some "), (MARKER, "<b>"), (TEXT, "content"), (MARKER, "</b>"),
// (MARKER, "</div>"), (MARKER, "</body>"), (MARKER, "</html>")
// ]
// segmentize_markers("<|tool_call|>[args]{ are here }[/args]<|tool_call_end|>") ->
// [ (MARKER, "<|tool_call|>"), (MARKER, "[args]"), (TEXT, "{ are here }"), (MARKER, "[/args]"), (MARKER, "<|tool_call_end|>") ]
std::vector<segment> segmentize_markers(const std::string & text);
// Prune whitespace-only segments from a vector of segments
// The input is taken by const reference and a new vector is returned
// * eg.:
// segmentize_markers("<tool_call>\n<function=foo>\n<arg=bar>\n \n</arg>\n</function>\n</tool_call>") ->
// X = [ (MARKER, "<tool_call>"), (TEXT, "\n"), (MARKER, "<function=foo>"), (TEXT, "\n"), (MARKER, "<arg=bar>"), (TEXT, "\n \n"),
// (MARKER, "</arg>"), (TEXT, "\n"), (MARKER, "</function>"), (TEXT, "\n"), (MARKER, "</tool_call>") ]
// prune_whitespace_segments(X) -> [ (MARKER, "<tool_call>"), (MARKER, "<function=foo>"), (MARKER, "<arg=bar>"), (MARKER, "</arg>"),
// (MARKER, "</function>"), (MARKER, "</tool_call>") ]
std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segments);
namespace autoparser {
// Apply a template with the given parameters, returning the rendered string (empty on failure)
// Note: an empty result is the only failure signal, so callers cannot tell a failed render
// apart from a template that renders to nothing
std::string apply_template(const common_chat_template & tmpl, const template_params & params);
// Factorized differential comparison function
// Takes base params and a single modifier lambda to create variant B
// (variant B is a copy of params_A passed through params_modifier; an empty std::function leaves the copy unchanged)
// Returns compare_variants_result containing diff and both outputs, or std::nullopt on failure
// (failure means that either variant rendered to an empty string)
std::optional<compare_variants_result> compare_variants(
const common_chat_template & tmpl,
const template_params & params_A,
const std::function<void(template_params &)> & params_modifier);
} // namespace autoparser

433
common/chat-auto-parser.h Normal file
View File

@ -0,0 +1,433 @@
#pragma once
#include "chat.h"
#include "common.h"
#include "jinja/caps.h"
#include "peg-parser.h"
#include <chrono>
#include <optional>
#include <string>
#include <utility>
#include <vector>
using json = nlohmann::ordered_json;
class common_chat_peg_builder;
// ============================================================================
// Parameters for template application (low-level, used by diff analysis)
// ============================================================================
// Inputs for a single low-level template rendering
struct template_params {
// Conversation and tool definitions as JSON (json is nlohmann::ordered_json)
json messages;
json tools;
// Both flags are presumably forwarded to the template as-is - confirm in apply_template
bool add_generation_prompt = false;
bool enable_thinking = true;
// Optional additional context; unset by default
std::optional<json> extra_context = std::nullopt;
};
// Result of splitting two strings around their differing middle part:
// `prefix` / `suffix` hold the shared leading / trailing text, `left` / `right` hold the
// mismatched remainder of the first / second input.
struct diff_split {
    std::string prefix;
    std::string suffix;
    std::string left;
    std::string right;

    // Member-wise equality. `other` is taken by const reference so that const objects and
    // temporaries can be compared as well.
    bool operator==(const diff_split & other) const {
        return prefix == other.prefix && suffix == other.suffix && left == other.left && right == other.right;
    }

    // Negation of operator==
    bool operator!=(const diff_split & other) const {
        return !(*this == other);
    }
};
// Result of compare_variants containing diff and original outputs
struct compare_variants_result {
// Diff split computed between output_A and output_B
diff_split diff;
// Rendering for the base params (A) and for the modified params (B)
std::string output_A;
std::string output_B;
};
namespace autoparser {
// ============================================================================
// High-level params for parser generation
// ============================================================================
// Inputs for parser generation. Carries the same messages/tools/generation-prompt/thinking fields as
// the low-level template_params, plus request-level settings.
// NOTE(review): field meanings below are only documented where the declaration itself shows them;
// semantics of is_inference / add_inference are not visible here - confirm at the use sites.
struct templates_params {
json messages;
json tools;
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
json json_schema;
bool parallel_tool_calls = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
bool stream = true;
std::string grammar;
bool add_generation_prompt = false;
bool enable_thinking = true;
// Defaults to the wall-clock time at which this struct is constructed
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
json extra_context;
bool add_bos = false;
bool add_eos = false;
bool is_inference = true;
bool add_inference = false;
bool mark_input = true; // whether to mark input strings in the jinja context
};
// ============================================================================
// Analysis Result Enums
// ============================================================================
// Reasoning handling mode (derived from R1-R3 comparisons)
enum class reasoning_mode {
    NONE,           // No reasoning markers detected
    TAG_BASED,      // Standard tag-based: <think>...</think>
    DELIMITER,      // Delimiter-based: [BEGIN FINAL RESPONSE] (reasoning ends at delimiter)
    FORCED_OPEN,    // Template ends with open reasoning tag (empty start, non-empty end)
    FORCED_CLOSED,  // Template ends with open reasoning tag on enabled thinking but
                    // with both opened and closed tag for disabled thinking
    TOOLS_ONLY      // Only reason on tool calls, not on normal content
};

// Streams the enumerator name; any value outside the enum is written as "UNKNOWN"
inline std::ostream & operator<<(std::ostream & os, const reasoning_mode & mode) {
    const char * label = "UNKNOWN";
    switch (mode) {
        case reasoning_mode::NONE:          label = "NONE";          break;
        case reasoning_mode::TAG_BASED:     label = "TAG_BASED";     break;
        case reasoning_mode::DELIMITER:     label = "DELIMITER";     break;
        case reasoning_mode::FORCED_OPEN:   label = "FORCED_OPEN";   break;
        case reasoning_mode::FORCED_CLOSED: label = "FORCED_CLOSED"; break;
        case reasoning_mode::TOOLS_ONLY:    label = "TOOLS_ONLY";    break;
        default:                                                     break;
    }
    return os << label;
}
// Content wrapping mode (derived from C1 comparison)
enum class content_mode {
    PLAIN,                   // No content markers
    ALWAYS_WRAPPED,          // Content always wrapped with markers
    WRAPPED_WITH_REASONING,  // Content wrapped only when reasoning present
};

// Streams the enumerator name; any value outside the enum is written as "UNKNOWN"
inline std::ostream & operator<<(std::ostream & os, const content_mode & mode) {
    const char * label = "UNKNOWN";
    switch (mode) {
        case content_mode::PLAIN:                  label = "PLAIN";                  break;
        case content_mode::ALWAYS_WRAPPED:         label = "ALWAYS_WRAPPED";         break;
        case content_mode::WRAPPED_WITH_REASONING: label = "WRAPPED_WITH_REASONING"; break;
        default:                                                                     break;
    }
    return os << label;
}
// Call ID position in tool calls (for non-JSON formats)
enum class call_id_position {
    NONE,                   // No call ID support detected
    PRE_FUNC_NAME,          // Call ID before function name: [CALL_ID]id[FUNC]name{args}
    BETWEEN_FUNC_AND_ARGS,  // Call ID between function and args: [FUNC]name[CALL_ID]id{args}
    POST_ARGS,              // Call ID after arguments: [FUNC]name{args}[CALL_ID]id
};

// Streams the enumerator name; any value outside the enum is written as "UNKNOWN"
inline std::ostream & operator<<(std::ostream & os, const call_id_position & pos) {
    const char * label = "UNKNOWN";
    switch (pos) {
        case call_id_position::NONE:                  label = "NONE";                  break;
        case call_id_position::PRE_FUNC_NAME:         label = "PRE_FUNC_NAME";         break;
        case call_id_position::BETWEEN_FUNC_AND_ARGS: label = "BETWEEN_FUNC_AND_ARGS"; break;
        case call_id_position::POST_ARGS:             label = "POST_ARGS";             break;
        default:                                                                       break;
    }
    return os << label;
}
// Tool call format classification (derived from T1-T5, A1-A3 comparisons)
enum class tool_format {
    NONE,             // No tool support detected
    JSON_NATIVE,      // Pure JSON: {"name": "X", "arguments": {...}}
    TAG_WITH_JSON,    // Tag-based with JSON args: <function=X>{...}</function>
    TAG_WITH_TAGGED,  // Tag-based with tagged args: <param=key>value</param>
};

// Streams the enumerator name; any value outside the enum is written as "UNKNOWN"
inline std::ostream & operator<<(std::ostream & os, const tool_format & format) {
    const char * label = "UNKNOWN";
    switch (format) {
        case tool_format::NONE:            label = "NONE";            break;
        case tool_format::JSON_NATIVE:     label = "JSON_NATIVE";     break;
        case tool_format::TAG_WITH_JSON:   label = "TAG_WITH_JSON";   break;
        case tool_format::TAG_WITH_TAGGED: label = "TAG_WITH_TAGGED"; break;
        default:                                                      break;
    }
    return os << label;
}
// ============================================================================
// Sub-structs for tool analysis
// ============================================================================
// Overall shape of a template's tool-call output: detected format plus the markers and
// field names found around the calls. All markers default to empty strings.
struct tool_format_analysis {
tool_format mode = tool_format::NONE;
std::string section_start; // e.g., "<tool_call>", "[TOOL_CALLS]", ""
std::string section_end; // e.g., "</tool_call>", ""
std::string per_call_start; // e.g., "<|tool_call_begin|>", "" (for multi-call templates)
std::string per_call_end; // e.g., "<|tool_call_end|>", ""
bool fun_name_is_key = false; // In JSON format function name is JSON key, i.e. { "<funname>": { ... arguments ... } }
bool tools_array_wrapped = false; // Tool calls wrapped in JSON array [...]
bool uses_python_dicts = false; // Tool call args use Python dict format (single-quoted strings)
// Field names with their default values; presumably the JSON keys looked up in a tool call - TODO confirm
std::string function_field = "function";
std::string name_field = "name";
std::string args_field = "arguments";
// NOTE(review): meaning of id_field / gen_id_field / parameter_order is not visible from this header
std::string id_field;
std::string gen_id_field;
std::vector<std::string> parameter_order;
};
// Markers surrounding the function name of a tool call (all empty by default)
struct tool_function_analysis {
std::string name_prefix; // e.g., "<function=", "\"name\": \"", "functions."
std::string name_suffix; // e.g., ">", "\"", ":0"
std::string close; // e.g., "</function>", "" (for tag-based)
};
// Markers surrounding the argument list of a tool call and each argument's name and value
// (all empty by default)
struct tool_arguments_analysis {
std::string start; // e.g., "<|tool_call_argument_begin|>", "<args>"
std::string end; // e.g., "<|tool_call_argument_end|>", "</args>"
std::string name_prefix; // e.g., "<param=", "<arg_key>", "\""
std::string name_suffix; // e.g., ">", "</arg_key>", "\":"
std::string value_prefix; // e.g., "", "<arg_value>", ""
std::string value_suffix; // e.g., "</param>", "</arg_value>", ""
std::string separator; // e.g., "", "\n", ","
};
// Where the call ID sits inside a tool call and which markers delimit it
// (pos defaults to call_id_position::NONE, markers to empty strings)
struct tool_id_analysis {
call_id_position pos = call_id_position::NONE;
std::string prefix; // e.g., "[CALL_ID]" (marker before call ID value)
std::string suffix; // e.g., "" (marker after call ID value, before next section)
};
// ============================================================================
// Parser build context (shared interface for build_parser methods)
// ============================================================================
// Forward declaration: parser_build_context only stores a pointer to the content analyzer
struct analyze_content;
// State passed to the build_parser methods of the analyzers.
// Holds references to the builder and to the inputs, so both must outlive the context.
struct parser_build_context {
common_chat_peg_builder & p;
const templates_params & inputs;
common_peg_parser reasoning_parser;
bool extracting_reasoning = false;
// nullptr by default; presumably non-owning - TODO confirm who sets it
const analyze_content * content = nullptr;
parser_build_context(common_chat_peg_builder & p, const templates_params & inputs);
};
// ============================================================================
// Base class for analyzers with parser building
// ============================================================================
// Abstract base of the analyzers: every analyzer can build a parser from a build context.
struct analyze_base {
virtual ~analyze_base() = default;
// Implemented by each analyzer
virtual common_peg_parser build_parser(parser_build_context & ctx) const = 0;
protected:
// Address of the template passed to the constructor (nullptr when default-constructed);
// only the pointer is stored, so the template must outlive the analyzer
const common_chat_template * tmpl = nullptr;
analyze_base() = default;
explicit analyze_base(const common_chat_template & tmpl) : tmpl(&tmpl) {}
};
// ============================================================================
// Reasoning analyzer
// ============================================================================
// Holds the detected reasoning mode and its start/end markers.
// A default-constructed instance reports reasoning_mode::NONE with empty markers.
struct analyze_reasoning : analyze_base {
reasoning_mode mode = reasoning_mode::NONE;
std::string start; // e.g., "<think>", "[THINK]", "<|START_THINKING|>", ""
std::string end; // e.g., "</think>", "[BEGIN FINAL RESPONSE]", "<|END_THINKING|>"
analyze_reasoning() = default;
// NOTE(review): presumably runs the private comparison steps below - the definition is not visible here
analyze_reasoning(const common_chat_template & tmpl, bool supports_tools);
common_peg_parser build_parser(parser_build_context & ctx) const override;
private:
// Look for reasoning markers in rendered content
void compare_reasoning_presence();
// Compare generation prompt with enable_thinking=true vs false
void compare_thinking_enabled();
// Check if reasoning is always possible or only in tool calls
void compare_reasoning_scope();
};
// ============================================================================
// Content analyzer
// ============================================================================
// Holds the detected content wrapping mode and its start/end markers.
// A default-constructed instance reports content_mode::PLAIN with empty markers.
struct analyze_content : analyze_base {
content_mode mode = content_mode::PLAIN;
std::string start; // e.g., "<response>", ">>>all\n", ""
std::string end; // e.g., "</response>", ""
// NOTE(review): presumably set when the template cannot render a null message content - confirm
bool requires_nonnull_content = false;
analyze_content() = default;
// Takes the already computed reasoning analysis as an input
analyze_content(const common_chat_template & tmpl, const analyze_reasoning & reasoning);
common_peg_parser build_parser(parser_build_context & ctx) const override;
bool is_always_wrapped() const;
common_peg_parser build_optional_wrapped(parser_build_context & ctx) const;
};
// ============================================================================
// Tool analyzer
// ============================================================================
// Holds the tool-call analysis results (overall format, function name markers, argument markers,
// call ID markers) and builds the matching tool-call parser.
// A default-constructed instance reports tool_format::NONE.
struct analyze_tools : analyze_base {
tool_format_analysis format;
tool_function_analysis function;
tool_arguments_analysis arguments;
tool_id_analysis call_id;
analyze_tools() = default;
// Takes the template capabilities and the already computed reasoning analysis as inputs
analyze_tools(const common_chat_template & tmpl,
const jinja::caps & caps,
const analyze_reasoning & reasoning);
common_peg_parser build_parser(parser_build_context & ctx) const override;
private:
// Extract tool calling 'haystack' for further analysis and delegate further analysis based on format
void analyze_tool_calls(const analyze_reasoning & reasoning);
// Analyze format based on position of function and argument name in needle
void analyze_tool_call_format(const std::string & haystack,
const std::string & fun_name_needle,
const std::string & arg_name_needle,
const analyze_reasoning & reasoning);
// Analyze specifics of JSON native format (entire tool call is a JSON object)
void analyze_tool_call_format_json_native(const std::string & clean_haystack,
const std::string & fun_name_needle,
const std::string & arg_name_needle);
// Analyze specifics of non-JSON native format (tags for function name or for function name and arguments)
void analyze_tool_call_format_non_json(const std::string & clean_haystack,
const std::string & fun_name_needle);
// Check for and extract specific per-call markers for non-native-JSON templates with parallel call support
void check_per_call_markers();
// Extract function name markers
void extract_function_markers();
// Delegates to separate functions for: separator analysis, argument name analysis, argument value analysis
void analyze_arguments();
// Extract argument name markers
void extract_argument_name_markers();
// Extract argument value markers
void extract_argument_value_markers();
// Extract argument separator, if specified (eg. <arg=foo>...</arg><sep><arg=bar>...</arg>)
void extract_argument_separator();
// Extract argument wrapper markers, if present (eg. '<args><arg=foo>...</arg><arg=bar>...</arg></args>')
void extract_args_markers();
// Extract call ID markers, if present
void extract_call_id_markers();
// Per-format tool parser builders (one per non-NONE tool_format value)
common_peg_parser build_tool_parser_json_native(parser_build_context & ctx) const;
common_peg_parser build_tool_parser_tag_json(parser_build_context & ctx) const;
common_peg_parser build_tool_parser_tag_tagged(parser_build_context & ctx) const;
};
// ============================================================================
// Main autoparser class
// ============================================================================
// Bundles the three analyzers (reasoning, content, tools) with the template capabilities and
// exposes the two entry points: analyze a template, then build a parser from the analysis.
struct autoparser {
jinja::caps jinja_caps;
analyze_reasoning reasoning;
analyze_content content;
analyze_tools tools;
// false until set; presumably flipped by analyze_template - TODO confirm
bool analysis_complete = false;
// Preserved tokens for tokenizer (union of all non-empty markers)
std::vector<std::string> preserved_tokens;
autoparser() = default;
// Run full differential analysis on a template
void analyze_template(const common_chat_template & tmpl);
// Build the PEG parser for this template
common_peg_arena build_parser(const templates_params & inputs) const;
private:
// Collect tokens from entire analysis to preserve
void collect_preserved_tokens();
};
// ============================================================================
// Parser generator
// ============================================================================
// Static-only entry points that produce common_chat_params for a template and a set of inputs.
class peg_generator {
public:
// Overload without a ready-made analysis
// NOTE(review): presumably analyzes `tmpl` itself before delegating to the overload below - confirm
static common_chat_params generate_parser(const common_chat_template & tmpl,
const struct templates_params & inputs);
// Overload taking an existing analysis. The parameter name `autoparser` shadows the
// struct of the same name inside this declaration.
static common_chat_params generate_parser(const common_chat_template & tmpl,
const struct templates_params & inputs,
const autoparser & autoparser);
};
} // namespace autoparser
// Kind of a text segment: plain text or a marker
enum segment_type { TEXT, MARKER };

// Streams the enumerator name; any other value is written as "UNKNOWN"
inline std::ostream & operator<<(std::ostream & os, const segment_type & type) {
    const char * label = "UNKNOWN";
    switch (type) {
        case segment_type::TEXT:   label = "TEXT";   break;
        case segment_type::MARKER: label = "MARKER"; break;
        default:                                     break;
    }
    return os << label;
}
struct segment {
segment_type type;
std::string value;
segment(segment_type type, std::string value) : type(type), value(std::move(value)) {}
bool operator==(const segment & other) const {
return type == other.type && value == other.value;
}
bool operator!=(const segment & other) const {
return !(*this == other);
}
};

File diff suppressed because it is too large Load Diff

View File

@ -1,879 +0,0 @@
#include "chat.h"
#include "chat-parser.h"
#include "common.h"
#include "json-partial.h"
#include "json-schema-to-grammar.h"
#include "log.h"
#include "regex-partial.h"
using json = nlohmann::ordered_json;
// Raised when an XML-style tool call is malformed; carries the message through std::runtime_error.
class xml_toolcall_syntax_exception : public std::runtime_error {
  public:
    // explicit: a plain string must never convert silently into an exception object
    explicit xml_toolcall_syntax_exception(const std::string & message) : std::runtime_error(message) {}
};
// Sorts `vec` in ascending order and drops duplicate elements, in place.
template<typename T>
inline void sort_uniq(std::vector<T> &vec) {
    std::sort(vec.begin(), vec.end());
    // std::unique shifts the distinct elements to the front and returns the new logical end
    const auto distinct_end = std::unique(vec.begin(), vec.end());
    vec.erase(distinct_end, vec.end());
}
// True when every character of `str` is whitespace according to std::isspace
// (an empty range therefore yields true).
template<typename T>
inline bool all_space(const T &str) {
    // Each element is converted to unsigned char before being passed to std::isspace
    for (const unsigned char ch : str) {
        if (!std::isspace(ch)) {
            return false;
        }
    }
    return true;
}
// Returns the length to which `s` can be cut without ending inside an incomplete UTF-8 sequence.
// Only the last (up to) four bytes are inspected, scanning backwards from the end.
static size_t utf8_truncate_safe(const std::string_view s) {
    const size_t len = s.size();
    if (len == 0) {
        return 0;
    }

    size_t pos = len;
    for (size_t steps = 0; steps < 4 && pos > 0; ++steps) {
        --pos;
        const unsigned char byte = s[pos];

        // Single-byte (ASCII) character: keep everything
        if ((byte & 0x80) == 0) {
            return len;
        }
        // Continuation byte (10xxxxxx): keep walking back towards the lead byte
        if ((byte & 0xC0) != 0xC0) {
            continue;
        }

        // Lead byte: derive the sequence length it announces
        size_t needed = 0;
        if ((byte & 0xE0) == 0xC0) {
            needed = 2;
        } else if ((byte & 0xF0) == 0xE0) {
            needed = 3;
        } else if ((byte & 0xF8) == 0xF0) {
            needed = 4;
        } else {
            // Not a valid lead byte: cut right before it
            return pos;
        }

        // Keep everything if the sequence is complete, otherwise cut before its lead byte
        return (len - pos >= needed) ? len : pos;
    }

    // No lead or ASCII byte among the inspected bytes: drop up to three trailing bytes
    return len - std::min(len, size_t(3));
}
// Shrinks `s` in place to the length reported by utf8_truncate_safe.
inline void utf8_truncate_safe_resize(std::string &s) {
    const size_t safe_len = utf8_truncate_safe(s);
    s.resize(safe_len);
}
// Returns the leading part of `s` whose length is reported by utf8_truncate_safe; no copy is made.
inline std::string_view utf8_truncate_safe_view(const std::string_view s) {
    const size_t safe_len = utf8_truncate_safe(s);
    return s.substr(0, safe_len);
}
// Searches forward from the current parser position for `literal1`, followed by optional spaces,
// followed by `literal2`. On success the parser is left right after the matched part of `literal2`
// and the returned result spans from the start of `literal1` to that position, with `prelude`
// covering the text between the initial position and the match. On failure the parser position
// is restored and std::nullopt is returned.
static std::optional<common_chat_msg_parser::find_regex_result> try_find_2_literal_splited_by_spaces(common_chat_msg_parser & builder, const std::string & literal1, const std::string & literal2) {
// With no first literal this degenerates into a plain search for literal2
if (literal1.size() == 0) return builder.try_find_literal(literal2);
const auto saved_pos = builder.pos();
while (auto res = builder.try_find_literal(literal1)) {
builder.consume_spaces();
// Compare only as much of literal2 as the remaining input can hold, so a literal2 that is
// cut off by the end of the input still counts as a match
const auto match_len = std::min(literal2.size(), builder.input().size() - builder.pos());
if (builder.input().compare(builder.pos(), match_len, literal2, 0, match_len) == 0) {
// After a retry the search restarted past saved_pos, so rebuild the prelude from saved_pos
if (res->prelude.size() != res->groups[0].begin - saved_pos) {
res->prelude = builder.str({saved_pos, res->groups[0].begin});
}
// Extend the match over the spaces and the matched part of literal2
builder.move_to(builder.pos() + match_len);
res->groups[0].end = builder.pos();
GGML_ASSERT(res->groups[0].begin != res->groups[0].end);
return res;
}
// literal2 did not follow: retry one character after the start of this literal1 match
builder.move_to(res->groups[0].begin + 1);
}
// Nothing found: leave the parser where it started
builder.move_to(saved_pos);
return std::nullopt;
}
/**
 * make a GBNF that accept any strings except those containing any of the forbidden strings.
 *
 * The forbidden strings are sorted and walked as a prefix tree: at every depth the rule accepts
 * either a byte that starts none of the remaining forbidden strings, or one of their bytes
 * followed by the rule generated for the strings sharing that byte.
 * An empty `forbids` yields "( . )*".
 */
std::string make_gbnf_excluding(std::vector<std::string> forbids) {
    // Renders a byte as a "\xNN" escape
    const auto hex_escape = [](unsigned char byte) -> std::string {
        char buf[16];
        snprintf(buf, 15, "\\x%02X", byte);
        return std::string(buf);
    };

    // Escapes one byte for use inside a character class
    const auto charclass_escape = [hex_escape](unsigned char c) -> std::string {
        switch (c) {
            case '\\':
            case ']':
            case '^':
            case '-':
                return std::string("\\") + (char) c;
            default:
                break;
        }
        return isprint(c) ? std::string(1, (char) c) : hex_escape(c);
    };

    // Renders one byte as a double-quoted literal
    const auto quoted_literal = [hex_escape](unsigned char c) -> std::string {
        std::string out = "\"";
        if (c == '\\') {
            out += "\\\\";
        } else if (c == '"') {
            out += "\\\"";
        } else if (isprint(c)) {
            out.push_back(c);
        } else {
            out += hex_escape(c);
        }
        out += "\"";
        return out;
    };

    // Builds the alternation for sorted[lo, hi), all of which share their first `depth` bytes.
    // Returns an empty string when no string in the range extends past `depth`.
    const auto build_expr = [&](auto self, const std::vector<std::string> & sorted, int lo, int hi, int depth) -> std::string {
        struct byte_group {
            unsigned char byte;
            int           first;
            int           last;
        };

        // Partition the range into runs that share the byte at position `depth`
        std::vector<byte_group> groups;
        for (int i = lo; i < hi;) {
            const std::string & cur = sorted[i];
            if ((int) cur.size() == depth) {
                // this string ends here and contributes no further byte
                ++i;
                continue;
            }
            const unsigned char byte = (unsigned char) cur[depth];
            int                 j    = i;
            while (j < hi && (int) sorted[j].size() > depth && (unsigned char) sorted[j][depth] == byte) {
                ++j;
            }
            groups.push_back({ byte, i, j });
            i = j;
        }
        if (groups.empty()) {
            return std::string();
        }

        // First alternative: any byte that starts none of the groups
        std::string out = "( [^";
        for (const auto & g : groups) {
            out += charclass_escape(g.byte);
        }
        out += "]";

        // Further alternatives: a group's byte followed by the rule for the rest of that group
        for (const auto & g : groups) {
            const std::string rest = self(self, sorted, g.first, g.last, depth + 1);
            if (rest.empty()) {
                continue;
            }
            out += " | ";
            out += quoted_literal(g.byte);
            out += " ";
            out += rest;
        }
        out += " )";
        return out;
    };

    if (forbids.empty()) {
        return "( . )*";
    }
    std::sort(forbids.begin(), forbids.end());

    std::string expr = build_expr(build_expr, forbids, 0, (int) forbids.size(), 0);
    if (expr.empty()) {
        // Fallback: exclude only the first byte of every non-empty forbidden string
        std::string cls;
        for (const auto & s : forbids) {
            if (!s.empty()) {
                cls += charclass_escape((unsigned char) s[0]);
            }
        }
        expr = std::string("( [^") + cls + "] )";
    }

    return forbids.size() == 1 ? expr + "*" : std::string("( ") + expr + " )*";
}
/**
 * Build grammar for xml-style tool call
 * form.scope_start and form.scope_end can be empty.
 * Requires data.format for model-specific hacks.
 *
 * Does nothing unless `tools` is a non-empty JSON array. Otherwise sets data.grammar and appends one
 * WORD grammar trigger made of form.scope_start + form.tool_start.
 * Tools that are not of type "function", or whose name / parameters / properties are missing or of
 * the wrong JSON type, are skipped with a warning.
 *
 * Generated root rule: [scope_start] ( call ( tool_end call )* (last_tool_end | tool_end) )? [scope_end]
 */
void build_grammar_xml_tool_call(common_chat_params & data, const json & tools, const struct xml_tool_call_format & form) {
    GGML_ASSERT(!form.tool_start.empty());
    GGML_ASSERT(!form.tool_sep.empty());
    GGML_ASSERT(!form.key_start.empty());
    GGML_ASSERT(!form.val_end.empty());
    GGML_ASSERT(!form.tool_end.empty());

    // The separator between key and value may consist of two parts joined by a newline
    std::string key_val_sep = form.key_val_sep;
    if (form.key_val_sep2) {
        key_val_sep += "\n";
        key_val_sep += *form.key_val_sep2;
    }
    GGML_ASSERT(!key_val_sep.empty());

    if (tools.is_array() && !tools.empty()) {
        data.grammar = build_grammar([&](const common_grammar_builder &builder) {
            // Raw string values: anything that does not contain a value terminator
            auto string_arg_val = form.last_val_end ?
                    builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end, *form.last_val_end})) :
                    builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end}));

            std::vector<std::string> tool_rules;
            for (const auto & tool : tools) {
                if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
                    LOG_WRN("Skipping tool without function: %s", tool.dump(2).c_str());
                    continue;
                }
                const auto & function = tool.at("function");
                if (!function.contains("name") || !function.at("name").is_string()) {
                    LOG_WRN("Skipping invalid function (invalid name): %s", function.dump(2).c_str());
                    continue;
                }
                if (!function.contains("parameters") || !function.at("parameters").is_object()) {
                    LOG_WRN("Skipping invalid function (invalid parameters): %s", function.dump(2).c_str());
                    continue;
                }
                std::string name = function.at("name");
                auto parameters = function.at("parameters");
                builder.resolve_refs(parameters);

                struct parameter_rule {
                    std::string symbol_name;
                    bool is_required;
                };
                std::vector<parameter_rule> arg_rules;
                if (!parameters.contains("properties") || !parameters.at("properties").is_object()) {
                    LOG_WRN("Skipping invalid function (invalid properties): %s", function.dump(2).c_str());
                    continue;
                } else {
                    std::vector<std::string> requiredParameters;
                    if (parameters.contains("required")) {
                        // nlohmann::json exceptions derive from std::exception (not std::runtime_error),
                        // so catch the base class; "required" lives on `parameters`, not on `function`
                        try { parameters.at("required").get_to(requiredParameters); }
                        catch (const std::exception &) {
                            LOG_WRN("Invalid function required parameters, ignoring: %s", parameters.at("required").dump(2).c_str());
                        }
                    }
                    // Sorted so that std::binary_search below is valid
                    sort_uniq(requiredParameters);
                    for (const auto & [key, value] : parameters.at("properties").items()) {
                        std::string quoted_key = key;
                        bool required = std::binary_search(requiredParameters.begin(), requiredParameters.end(), key);
                        // When the key sits between double quotes, escape it and strip the outer quotes
                        // that gbnf_format_literal adds
                        if (form.key_start.back() == '"' && key_val_sep[0] == '"') {
                            quoted_key = gbnf_format_literal(key);
                            quoted_key = quoted_key.substr(1, quoted_key.size() - 2);
                        }
                        arg_rules.push_back(parameter_rule {builder.add_rule("func-" + name + "-kv-" + key,
                            gbnf_format_literal(form.key_start) + " " +
                            gbnf_format_literal(quoted_key) + " " +
                            gbnf_format_literal(key_val_sep) + " " +
                            ((value.contains("type") && value["type"].is_string() && value["type"] == "string" && (!form.raw_argval || *form.raw_argval)) ?
                                (form.raw_argval ?
                                    string_arg_val :
                                    "( " + string_arg_val + " | " + builder.add_schema(name + "-arg-" + key, value) + " )"
                                ) :
                                builder.add_schema(name + "-arg-" + key, value)
                            )
                        ), required});
                    }
                }

                // Chain the argument rules from the last argument to the first, so that optional
                // arguments can be left out while the given order is kept
                auto next_arg_with_sep = builder.add_rule(name + "-last-arg-end", form.last_val_end ? gbnf_format_literal(*form.last_val_end) : gbnf_format_literal(form.val_end));
                decltype(next_arg_with_sep) next_arg = "\"\"";
                // `i` is unsigned: decrementing past zero wraps around and ends the loop
                for (auto i = arg_rules.size() - 1; /* i >= 0 && */ i < arg_rules.size(); --i) {
                    std::string include_this_arg = arg_rules[i].symbol_name + " " + next_arg_with_sep;
                    next_arg = builder.add_rule(name + "-arg-after-" + std::to_string(i), arg_rules[i].is_required ?
                        include_this_arg : "( " + include_this_arg + " ) | " + next_arg
                    );
                    include_this_arg = gbnf_format_literal(form.val_end) + " " + include_this_arg;
                    next_arg_with_sep = builder.add_rule(name + "-arg-after-" + std::to_string(i) + "-with-sep", arg_rules[i].is_required ?
                        include_this_arg : "( " + include_this_arg + " ) | " + next_arg_with_sep
                    );
                }

                std::string quoted_name = name;
                if (form.tool_start.back() == '"' && form.tool_sep[0] == '"') {
                    quoted_name = gbnf_format_literal(name);
                    quoted_name = quoted_name.substr(1, quoted_name.size() - 2);
                }
                quoted_name = gbnf_format_literal(quoted_name);
                // Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
                if (data.format == COMMON_CHAT_FORMAT_KIMI_K2) {
                    quoted_name = "\"functions.\" " + quoted_name + " \":\" [0-9]+";
                }
                tool_rules.push_back(builder.add_rule(name + "-call",
                    gbnf_format_literal(form.tool_start) + " " +
                    quoted_name + " " +
                    gbnf_format_literal(form.tool_sep) + " " +
                    next_arg
                ));
            }

            auto tool_call_once = builder.add_rule("root-tool-call-once", string_join(tool_rules, " | "));
            auto tool_call_more = builder.add_rule("root-tool-call-more", gbnf_format_literal(form.tool_end) + " " + tool_call_once);
            auto call_end = builder.add_rule("root-call-end", form.last_tool_end ? gbnf_format_literal(*form.last_tool_end) : gbnf_format_literal(form.tool_end));
            auto tool_call_multiple_with_end = builder.add_rule("root-tool-call-multiple-with-end", tool_call_once + " " + tool_call_more + "* " + call_end);
            builder.add_rule("root",
                (form.scope_start.empty() ? "" : gbnf_format_literal(form.scope_start) + " ") +
                tool_call_multiple_with_end + "?" +
                (form.scope_end.empty() ? "" : " " + gbnf_format_literal(form.scope_end))
            );
        });

        // grammar trigger for tool call
        data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, form.scope_start + form.tool_start });
    }
}
/**
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
* Throws xml_toolcall_syntax_exception if there is invalid syntax and cannot recover the original status for common_chat_msg_parser.
* form.scope_start, form.tool_sep and form.scope_end can be empty.
*/
inline bool parse_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form) {
GGML_ASSERT(!form.tool_start.empty());
GGML_ASSERT(!form.key_start.empty());
GGML_ASSERT(!form.key_val_sep.empty());
GGML_ASSERT(!form.val_end.empty());
GGML_ASSERT(!form.tool_end.empty());
// Helper to choose return false or throw error
constexpr auto return_error = [](common_chat_msg_parser & builder, auto &start_pos, const bool &recovery) {
LOG_DBG("Failed to parse XML-Style tool call at position: %s\n", gbnf_format_literal(builder.consume_rest().substr(0, 20)).c_str());
if (recovery) {
builder.move_to(start_pos);
return false;
} else throw xml_toolcall_syntax_exception("Tool call parsing failed with unrecoverable errors. Try using a grammar to constrain the models output.");
};
// Drop substring from needle to end from a JSON
constexpr auto partial_json = [](std::string &json_str, std::string_view needle = "XML_TOOL_CALL_PARTIAL_FLAG") {
auto pos = json_str.rfind(needle);
if (pos == std::string::npos) {
return false;
}
for (auto i = pos + needle.size(); i < json_str.size(); ++i) {
unsigned char ch = static_cast<unsigned char>(json_str[i]);
if (ch != '\'' && ch != '"' && ch != '}' && ch != ':' && !std::isspace(ch)) {
return false;
}
}
if (pos != 0 && json_str[pos - 1] == '"') {
--pos;
}
json_str.resize(pos);
return true;
};
// Helper to generate a partial argument JSON
constexpr auto gen_partial_json = [partial_json](auto set_partial_arg, auto &arguments, auto &builder, auto &function_name) {
auto rest = builder.consume_rest();
utf8_truncate_safe_resize(rest);
set_partial_arg(rest, "XML_TOOL_CALL_PARTIAL_FLAG");
auto tool_str = arguments.dump();
if (partial_json(tool_str)) {
if (builder.add_tool_call(function_name, "", tool_str)) {
return;
}
}
LOG_DBG("Failed to parse partial XML-Style tool call, fallback to non-partial: %s\n", tool_str.c_str());
};
// Helper to find a close (because there may be form.last_val_end or form.last_tool_end)
constexpr auto try_find_close = [](
common_chat_msg_parser & builder,
const std::string & end,
const std::optional<std::string> & alt_end,
const std::string & end_next,
const std::optional<std::string> & alt_end_next
) {
auto saved_pos = builder.pos();
auto tc = builder.try_find_literal(end);
auto val_end_size = end.size();
if (alt_end) {
auto pos_1 = builder.pos();
builder.move_to(saved_pos);
auto tc2 = try_find_2_literal_splited_by_spaces(builder, *alt_end, end_next);
if (alt_end_next) {
builder.move_to(saved_pos);
auto tc3 = try_find_2_literal_splited_by_spaces(builder, *alt_end, *alt_end_next);
if (tc3 && (!tc2 || tc2->prelude.size() > tc3->prelude.size())) {
tc2 = tc3;
}
}
if (tc2 && (!tc || tc->prelude.size() > tc2->prelude.size())) {
tc = tc2;
tc->groups[0].end = std::min(builder.input().size(), tc->groups[0].begin + alt_end->size());
builder.move_to(tc->groups[0].end);
val_end_size = alt_end->size();
} else {
builder.move_to(pos_1);
}
}
return std::make_pair(val_end_size, tc);
};
// Helper to find a val_end or last_val_end, returns matched pattern size
const auto try_find_val_end = [try_find_close, &builder, &form]() {
return try_find_close(builder, form.val_end, form.last_val_end, form.tool_end, form.last_tool_end);
};
// Helper to find a tool_end or last_tool_end, returns matched pattern size
const auto try_find_tool_end = [try_find_close, &builder, &form]() {
return try_find_close(builder, form.tool_end, form.last_tool_end, form.scope_end, std::nullopt);
};
bool recovery = true;
const auto start_pos = builder.pos();
if (!all_space(form.scope_start)) {
if (auto tc = builder.try_find_literal(form.scope_start)) {
if (all_space(tc->prelude)) {
if (form.scope_start.size() != tc->groups[0].end - tc->groups[0].begin)
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.scope_start));
} else {
builder.move_to(start_pos);
return false;
}
} else return false;
}
while (auto tc = builder.try_find_literal(form.tool_start)) {
if (!all_space(tc->prelude)) {
LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
gbnf_format_literal(form.tool_start).c_str(),
gbnf_format_literal(tc->prelude).c_str()
);
builder.move_to(tc->groups[0].begin - tc->prelude.size());
break;
}
// Find tool name
auto func_name = builder.try_find_literal(all_space(form.tool_sep) ? form.key_start : form.tool_sep);
if (!func_name) {
auto [sz, tc] = try_find_tool_end();
func_name = tc;
}
if (!func_name) {
// Partial tool name not supported
throw common_chat_msg_partial_exception("incomplete tool_call");
}
// If the model generate multiple tool call and the first tool call has no argument
if (func_name->prelude.find(form.tool_end) != std::string::npos || (form.last_tool_end ? func_name->prelude.find(*form.last_tool_end) != std::string::npos : false)) {
builder.move_to(func_name->groups[0].begin - func_name->prelude.size());
auto [sz, tc] = try_find_tool_end();
func_name = tc;
}
// Parse tool name
builder.move_to(all_space(form.tool_sep) ? func_name->groups[0].begin : func_name->groups[0].end);
std::string function_name = string_strip(func_name->prelude);
// Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
if (builder.syntax().format == COMMON_CHAT_FORMAT_KIMI_K2) {
if (string_starts_with(function_name, "functions.")) {
static const std::regex re(":\\d+$");
if (std::regex_search(function_name, re)) {
function_name = function_name.substr(10, function_name.rfind(":") - 10);
}
}
}
// Argument JSON
json arguments = json::object();
// Helper to generate a partial argument JSON
const auto gen_partial_args = [&](auto set_partial_arg) {
gen_partial_json(set_partial_arg, arguments, builder, function_name);
};
// Parse all arg_key/arg_value pairs
while (auto tc = builder.try_find_literal(form.key_start)) {
if (!all_space(tc->prelude)) {
LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
gbnf_format_literal(form.key_start).c_str(),
gbnf_format_literal(tc->prelude).c_str()
);
builder.move_to(tc->groups[0].begin - tc->prelude.size());
break;
}
if (tc->groups[0].end - tc->groups[0].begin != form.key_start.size()) {
auto tool_call_arg = arguments.dump();
if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
tool_call_arg.resize(tool_call_arg.size() - 1);
}
builder.add_tool_call(function_name, "", tool_call_arg);
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_start));
}
// Parse arg_key
auto key_res = builder.try_find_literal(form.key_val_sep);
if (!key_res) {
gen_partial_args([&](auto &rest, auto &needle) {arguments[rest + needle] = "";});
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.key_val_sep) + " after " + gbnf_format_literal(form.key_start));
}
if (key_res->groups[0].end - key_res->groups[0].begin != form.key_val_sep.size()) {
gen_partial_args([&](auto &, auto &needle) {arguments[key_res->prelude + needle] = "";});
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_val_sep));
}
auto &key = key_res->prelude;
recovery = false;
// Parse arg_value
if (form.key_val_sep2) {
if (auto tc = builder.try_find_literal(*form.key_val_sep2)) {
if (!all_space(tc->prelude)) {
LOG_DBG("Failed to parse XML-Style tool call: Unexcepted %s between %s and %s\n",
gbnf_format_literal(tc->prelude).c_str(),
gbnf_format_literal(form.key_val_sep).c_str(),
gbnf_format_literal(*form.key_val_sep2).c_str()
);
return return_error(builder, start_pos, false);
}
if (tc->groups[0].end - tc->groups[0].begin != form.key_val_sep2->size()) {
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(*form.key_val_sep2));
}
} else {
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(*form.key_val_sep2) + " after " + gbnf_format_literal(form.key_val_sep));
}
}
auto val_start = builder.pos();
// Test if arg_val is a partial JSON
std::optional<common_json> value_json = std::nullopt;
if (!form.raw_argval || !*form.raw_argval) {
try { value_json = builder.try_consume_json(); }
catch (const std::runtime_error&) { builder.move_to(val_start); }
// TODO: Delete this when json_partial adds top-level support for null/true/false
if (builder.pos() == val_start) {
const static std::regex number_regex(R"([0-9-][0-9]*(\.\d*)?([eE][+-]?\d*)?)");
builder.consume_spaces();
std::string_view sv = utf8_truncate_safe_view(builder.input());
sv.remove_prefix(builder.pos());
std::string rest = "a";
if (sv.size() < 6) rest = sv;
if (string_starts_with("null", rest) || string_starts_with("true", rest) || string_starts_with("false", rest) || std::regex_match(sv.begin(), sv.end(), number_regex)) {
value_json = {123, {"123", "123"}};
builder.consume_rest();
} else {
builder.move_to(val_start);
}
}
}
// If it is a JSON and followed by </arg_value>, parse as json
// cannot support streaming because it may be a plain text starting with JSON
if (value_json) {
auto json_end = builder.pos();
builder.consume_spaces();
if (builder.pos() == builder.input().size()) {
if (form.raw_argval && !*form.raw_argval && (value_json->json.is_string() || value_json->json.is_object() || value_json->json.is_array())) {
arguments[key] = value_json->json;
auto json_str = arguments.dump();
if (!value_json->healing_marker.json_dump_marker.empty()) {
GGML_ASSERT(std::string::npos != json_str.rfind(value_json->healing_marker.json_dump_marker));
json_str.resize(json_str.rfind(value_json->healing_marker.json_dump_marker));
} else {
GGML_ASSERT(json_str.back() == '}');
json_str.resize(json_str.size() - 1);
}
builder.add_tool_call(function_name, "", json_str);
} else {
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
}
LOG_DBG("Possible JSON arg_value: %s\n", value_json->json.dump().c_str());
throw common_chat_msg_partial_exception("JSON arg_value detected. Waiting for more tokens for validations.");
}
builder.move_to(json_end);
auto [val_end_size, tc] = try_find_val_end();
if (tc && all_space(tc->prelude) && value_json->healing_marker.marker.empty()) {
if (tc->groups[0].end - tc->groups[0].begin != val_end_size) {
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
LOG_DBG("Possible terminated JSON arg_value: %s\n", value_json->json.dump().c_str());
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.val_end) + (form.last_val_end ? gbnf_format_literal(*form.last_val_end) : ""));
} else arguments[key] = value_json->json;
} else builder.move_to(val_start);
}
// If not, parse as plain text
if (val_start == builder.pos()) {
if (auto [val_end_size, value_plain] = try_find_val_end(); value_plain) {
auto &value_str = value_plain->prelude;
if (form.trim_raw_argval) value_str = string_strip(value_str);
if (value_plain->groups[0].end - value_plain->groups[0].begin != val_end_size) {
gen_partial_args([&](auto &, auto &needle) {arguments[key] = value_str + needle;});
throw common_chat_msg_partial_exception(
"Expected " + gbnf_format_literal(form.val_end) +
" after " + gbnf_format_literal(form.key_val_sep) +
(form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
);
}
arguments[key] = value_str;
} else {
if (form.trim_raw_argval) {
gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = string_strip(rest) + needle;});
} else {
gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = rest + needle;});
}
throw common_chat_msg_partial_exception(
"Expected " + gbnf_format_literal(form.val_end) +
" after " + gbnf_format_literal(form.key_val_sep) +
(form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
);
}
}
}
// Consume closing tag
if (auto [tool_end_size, tc] = try_find_tool_end(); tc) {
if (!all_space(tc->prelude)) {
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
gbnf_format_literal(form.tool_end).c_str(),
gbnf_format_literal(tc->prelude).c_str()
);
return return_error(builder, start_pos, recovery);
}
if (tc->groups[0].end - tc->groups[0].begin == tool_end_size) {
// Add the parsed tool call
if (!builder.add_tool_call(function_name, "", arguments.dump())) {
throw common_chat_msg_partial_exception("Failed to add XML-Style tool call");
}
recovery = false;
continue;
}
}
auto tool_call_arg = arguments.dump();
if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
tool_call_arg.resize(tool_call_arg.size() - 1);
}
builder.add_tool_call(function_name, "", tool_call_arg);
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.tool_end) + " after " + gbnf_format_literal(form.val_end));
}
if (auto tc = builder.try_find_literal(form.scope_end)) {
if (!all_space(tc->prelude)) {
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
gbnf_format_literal(form.scope_end).c_str(),
gbnf_format_literal(tc->prelude).c_str()
);
return return_error(builder, start_pos, recovery);
}
} else {
if (all_space(form.scope_end)) return true;
builder.consume_spaces();
if (builder.pos() == builder.input().size())
throw common_chat_msg_partial_exception("incomplete tool calls");
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
gbnf_format_literal(form.scope_end).c_str(),
gbnf_format_literal(builder.consume_rest()).c_str()
);
return return_error(builder, start_pos, recovery);
}
return true;
}
/**
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
* May cause std::runtime_error if there is invalid syntax because partial valid tool call is already sent out to client.
* form.scope_start, form.tool_sep and form.scope_end can be empty.
*/
bool common_chat_msg_parser::try_consume_xml_tool_calls(const struct xml_tool_call_format & form) {
auto pos = pos_;
auto tsize = result_.tool_calls.size();
try { return parse_xml_tool_calls(*this, form); }
catch (const xml_toolcall_syntax_exception&) {}
move_to(pos);
result_.tool_calls.resize(tsize);
return false;
}
/**
 * Parse a message that mixes plain content, reasoning blocks and XML-Style tool calls.
 *
 * The input of `builder` is scanned from its current position to the end. Text is emitted through
 * builder.add_content() / builder.add_reasoning_content(), tool calls through
 * builder.try_consume_xml_tool_calls(form).
 *
 * @param builder     parser whose input is consumed and whose result is filled in
 * @param form        literals describing the tool call syntax
 * @param start_think literal that opens a reasoning block
 * @param end_think   literal that closes a reasoning block
 *
 * TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
 */
inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>") {
// Remove trailing whitespace from s in place
constexpr auto rstrip = [](std::string &s) {
s.resize(std::distance(s.begin(), std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base()));
};
// Erase substring from l to r (both inclusive), along with additional spaces nearby.
// Unless the erased region starts at index 0, its first two characters are overwritten with '\n' and kept
// as a separator. Returns the index at which the erased region started (after the kept separator).
constexpr auto erase_spaces = [](auto &str, size_t l, size_t r) {
// Walk left over whitespace. When l reaches 0 the decrement wraps around (size_t is unsigned), which makes
// `--l < str.size()` false and ends the loop; the ++l below then brings it back to 0.
while (/* l > -1 && */ --l < str.size() && std::isspace(static_cast<unsigned char>(str[l])));
++l;
// Walk right over whitespace; r ends on the first non-space character or on str.size()
while (++r < str.size() && std::isspace(static_cast<unsigned char>(str[r])));
if (l < r) str[l] = '\n';
if (l + 1 < r) str[l + 1] = '\n';
if (l != 0) l += 2;
str.erase(l, r - l);
return l;
};
// Erase the longest suffix of content that equals a prefix (or the whole) of any non-empty pattern in list
constexpr auto trim_suffix = [](std::string &content, std::initializer_list<std::string_view> list) {
// Smallest start index of a matching suffix found so far; content.size() means "no match"
auto best_match = content.size();
for (auto pattern: list) {
if (pattern.size() == 0) continue;
for (auto match_idx = content.size() - std::min(pattern.size(), content.size()); content.size() > match_idx; match_idx++) {
auto match_len = content.size() - match_idx;
if (content.compare(match_idx, match_len, pattern.data(), match_len) == 0 && best_match > match_idx) {
best_match = match_idx;
}
}
}
if (content.size() > best_match) {
content.erase(best_match);
}
};
// Apply trim_suffix with the think markers and every literal of form, so that the beginning of a marker
// that is still being streamed is not emitted as text
const auto trim_potential_partial_word = [&start_think, &end_think, &form, trim_suffix](std::string &content) {
return trim_suffix(content, {
start_think, end_think, form.scope_start, form.tool_start, form.tool_sep, form.key_start,
form.key_val_sep, form.key_val_sep2 ? form.key_val_sep2->c_str() : "",
form.val_end, form.last_val_end ? form.last_val_end->c_str() : "",
form.tool_end, form.last_tool_end ? form.last_tool_end->c_str() : "",
form.scope_end
});
};
// Trim leading spaces without affecting keyword matching:
// the part of the consumed spaces that could be the start of a marker is given back to the input
static const common_regex spaces_regex("\\s*");
{
auto tc = builder.consume_regex(spaces_regex);
auto spaces = builder.str(tc.groups[0]);
auto s1 = spaces.size();
trim_potential_partial_word(spaces);
auto s2 = spaces.size();
builder.move_to(builder.pos() - (s1 - s2));
}
// Parse content
// reasoning_unclosed: a reasoning block is currently open (initially taken from thinking_forced_open)
// unclosed_reasoning_content: reasoning text collected so far for the block that is still open
bool reasoning_unclosed = builder.syntax().thinking_forced_open;
std::string unclosed_reasoning_content("");
for (;;) {
// Look for the next tool call start; on a match, content is the text in front of it,
// otherwise content is the whole remaining input
auto tc = try_find_2_literal_splited_by_spaces(builder, form.scope_start, form.tool_start);
std::string content;
std::string tool_call_start;
if (tc) {
content = std::move(tc->prelude);
tool_call_start = builder.str(tc->groups[0]);
LOG_DBG("Matched tool start: %s\n", gbnf_format_literal(tool_call_start).c_str());
} else {
content = builder.consume_rest();
utf8_truncate_safe_resize(content);
}
// Handle unclosed think block
if (reasoning_unclosed) {
if (auto pos = content.find(end_think); pos == std::string::npos && builder.pos() != builder.input().size()) {
// No end_think in this chunk and more input follows: the chunk still belongs to the reasoning
unclosed_reasoning_content += content;
if (!(form.allow_toolcall_in_think && tc)) {
// The matched tool call start is kept as reasoning text and the tool call is not parsed
unclosed_reasoning_content += tool_call_start;
continue;
}
} else {
// end_think was found, or the input is exhausted: the reasoning block ends here
reasoning_unclosed = false;
std::string reasoning_content;
if (pos == std::string::npos) {
reasoning_content = std::move(content);
} else {
reasoning_content = content.substr(0, pos);
content.erase(0, pos + end_think.size());
}
// Nothing but spaces follows the reasoning at the end of the input:
// drop a possibly incomplete marker from the end of the reasoning text
if (builder.pos() == builder.input().size() && all_space(content)) {
rstrip(reasoning_content);
trim_potential_partial_word(reasoning_content);
rstrip(reasoning_content);
if (reasoning_content.empty()) {
rstrip(unclosed_reasoning_content);
trim_potential_partial_word(unclosed_reasoning_content);
rstrip(unclosed_reasoning_content);
if (unclosed_reasoning_content.empty()) continue;
}
}
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
// Reasoning stays inside the content, wrapped in the think markers;
// end_think is only written when something other than spaces follows
builder.add_content(start_think);
builder.add_content(unclosed_reasoning_content);
builder.add_content(reasoning_content);
if (builder.pos() != builder.input().size() || !all_space(content))
builder.add_content(end_think);
} else {
builder.add_reasoning_content(unclosed_reasoning_content);
builder.add_reasoning_content(reasoning_content);
}
unclosed_reasoning_content.clear();
}
}
// Handle multiple think block
bool toolcall_in_think = false;
for (auto think_start = content.find(start_think); think_start != std::string::npos; think_start = content.find(start_think, think_start)) {
if (auto think_end = content.find(end_think, think_start + start_think.size()); think_end != std::string::npos) {
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
// Closed think block: move its text to the reasoning content and cut it out of content
auto reasoning_content = content.substr(think_start + start_think.size(), think_end - think_start - start_think.size());
builder.add_reasoning_content(reasoning_content);
think_start = erase_spaces(content, think_start, think_end + end_think.size() - 1);
} else {
// Reasoning is kept in content: continue searching from the end of this block
think_start = think_end + end_think.size() - 1;
}
} else {
// This <tool_call> start is in thinking block: the block is opened in content but not closed.
// Unless allow_toolcall_in_think is set, the tool call start itself is kept as reasoning text.
if (form.allow_toolcall_in_think) {
unclosed_reasoning_content = content.substr(think_start + start_think.size());
} else {
unclosed_reasoning_content = content.substr(think_start + start_think.size()) + tool_call_start;
}
reasoning_unclosed = true;
content.resize(think_start);
toolcall_in_think = true;
}
}
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
rstrip(content);
// Handle unclosed </think> token from content: delete all </think> token
if (auto pos = content.rfind(end_think); pos != std::string::npos) {
while (pos != std::string::npos) {
pos = erase_spaces(content, pos, pos + end_think.size() - 1);
pos = content.rfind(end_think, pos);
}
}
// Strip if needed
if (content.size() > 0 && std::isspace(static_cast<unsigned char>(content[0]))) {
content = string_strip(content);
}
}
// remove potential partial suffix (only at the end of a partial input)
if (builder.pos() == builder.input().size() && builder.is_partial()) {
if (unclosed_reasoning_content.empty()) {
rstrip(content);
trim_potential_partial_word(content);
rstrip(content);
} else {
rstrip(unclosed_reasoning_content);
trim_potential_partial_word(unclosed_reasoning_content);
rstrip(unclosed_reasoning_content);
}
}
// consume unclosed_reasoning_content if allow_toolcall_in_think is set
if (form.allow_toolcall_in_think && !unclosed_reasoning_content.empty()) {
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
builder.add_reasoning_content(unclosed_reasoning_content);
} else {
if (content.empty()) {
content = start_think + unclosed_reasoning_content;
} else {
content += "\n\n" + start_think;
content += unclosed_reasoning_content;
}
}
unclosed_reasoning_content.clear();
}
// Add content
if (!content.empty()) {
// If there are multiple content blocks, separate them with an empty line
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content && builder.result().content.size() != 0) {
builder.add_content("\n\n");
}
builder.add_content(content);
}
// This <tool_call> start is in thinking block and toolcall_in_think not set, skip this tool call
if (toolcall_in_think && !form.allow_toolcall_in_think) {
continue;
}
// There is no tool call and all content is parsed
if (!tc) {
GGML_ASSERT(builder.pos() == builder.input().size());
GGML_ASSERT(unclosed_reasoning_content.empty());
if (!form.allow_toolcall_in_think) GGML_ASSERT(!reasoning_unclosed);
break;
}
// Go back to the start of the matched tool call and try to parse it
builder.move_to(tc->groups[0].begin);
if (builder.try_consume_xml_tool_calls(form)) {
auto end_of_tool = builder.pos();
builder.consume_spaces();
// More non-space input follows the tool call: restore the position and separate the following
// content from the existing content with an empty line
if (builder.pos() != builder.input().size()) {
builder.move_to(end_of_tool);
if (!builder.result().content.empty()) {
builder.add_content("\n\n");
}
}
} else {
// Not a tool call after all: consume one character matched by "." and add it as content
// (after rstrip), then search again from the next position
static const common_regex next_char_regex(".");
auto c = builder.str(builder.consume_regex(next_char_regex).groups[0]);
rstrip(c);
builder.add_content(c);
}
}
}
/**
* Parse content uses reasoning and XML-Style tool call
*/
void common_chat_msg_parser::consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think, const std::string & end_think) {
parse_msg_with_xml_tool_calls(*this, form, start_think, end_think);
}

View File

@ -1,45 +0,0 @@
#pragma once
#include "chat.h"
#include <nlohmann/json.hpp>
#include <optional>
#include <string>
#include <vector>
// Literal markers that describe one XML-style tool call syntax.
//
// Sample configurations (values are shown as they would be written in a C string):
//   MiniMax-M2: <minimax:tool_call>\n<invoke name="tool-name">\n<parameter name="key">value</parameter>\n...</invoke>\n...</minimax:tool_call>
//   GLM 4.5:    <tool_call>function_name\n<arg_key>key</arg_key>\n<arg_value>value</arg_value>\n</tool_call>
struct xml_tool_call_format {
    // Opens the tool call section. Can be empty.
    //   MiniMax-M2: <minimax:tool_call>\n    GLM 4.5: \n
    std::string scope_start;
    // Opens a single tool call; the tool name follows.
    //   MiniMax-M2: <invoke name=\"          GLM 4.5: <tool_call>
    std::string tool_start;
    // Separates the tool name from its arguments. Can be empty only for parse_xml_tool_calls.
    //   MiniMax-M2: \">\n                    GLM 4.5: \n
    std::string tool_sep;
    // Opens an argument key.
    //   MiniMax-M2: <parameter name=\"       GLM 4.5: <arg_key>
    std::string key_start;
    // Separates an argument key from its value.
    //   MiniMax-M2: \">                      GLM 4.5: </arg_key>\n<arg_value>
    std::string key_val_sep;
    // Closes an argument value.
    //   MiniMax-M2: </parameter>\n           GLM 4.5: </arg_value>\n
    std::string val_end;
    // Closes a single tool call.
    //   MiniMax-M2: </invoke>\n              GLM 4.5: </tool_call>\n
    std::string tool_end;
    // Closes the tool call section. Can be empty.
    //   MiniMax-M2: </minimax:tool_call>     GLM 4.5: (empty)
    std::string scope_end;

    // Set this if there can be dynamic spaces inside key_val_sep.
    // e.g. key_val_sep=</arg_key> key_val_sep2=<arg_value> for GLM4.5
    std::optional<std::string> key_val_sep2{};

    // true:         argval is only a raw string,  e.g. Hello "world" hi
    // false:        argval is only a json string, e.g. "Hello \"world\" hi"
    // std::nullopt: (default) both are allowed.
    std::optional<bool> raw_argval{};

    // Optional alternative to val_end; the parser tries it next to val_end when looking for the end of a value.
    std::optional<std::string> last_val_end{};
    // Optional alternative to tool_end; the parser tries it next to tool_end when looking for the end of a tool call.
    std::optional<std::string> last_tool_end{};

    // When set, raw (plain text) argument values are stripped before they are stored.
    bool trim_raw_argval{false};
    // When set, tool calls found inside an unclosed reasoning block are parsed instead of being skipped.
    bool allow_toolcall_in_think{false};
};
// Make a GBNF grammar that accepts any string except those containing any of the forbidden strings.
// `forbids` is the list of forbidden strings; the grammar is returned as text.
std::string make_gbnf_excluding(std::vector<std::string> forbids);
/**
 * Build the grammar for the XML-style tool call syntax described by `form`.
 * form.scope_start and form.scope_end can be empty.
 * Requires data.format for model-specific hacks.
 * NOTE(review): the function returns nothing, so the grammar is presumably stored in `data` - confirm in the definition.
 */
void build_grammar_xml_tool_call(common_chat_params & data, const nlohmann::ordered_json & tools, const struct xml_tool_call_format & form);

File diff suppressed because it is too large Load Diff

View File

@ -1,133 +0,0 @@
#pragma once
#include "chat.h"
#include "chat-parser-xml-toolcall.h"
#include "json-partial.h"
#include "regex-partial.h"
#include <nlohmann/json_fwd.hpp>
#include <optional>
#include <string>
#include <vector>
// Exception type the parsers throw when the input stops in the middle of a construct
// (e.g. "Partial literal: ..." or "incomplete tool_call"). It is a std::runtime_error,
// so handlers for std::runtime_error / std::exception catch it as well.
class common_chat_msg_partial_exception : public std::runtime_error {
  public:
    // `message` is forwarded unchanged and is available through what().
    common_chat_msg_partial_exception(const std::string & message)
        : std::runtime_error{ message } {
    }
};
// Cursor-based parser over one chat message. It owns the input text, a read position
// into it and the common_chat_msg that the consume_* / add_* members fill in.
class common_chat_msg_parser {
    std::string               input_;
    bool                      is_partial_;
    common_chat_parser_params syntax_;  // TODO: rename to params
    std::string               healing_marker_;

    size_t          pos_ = 0;
    common_chat_msg result_;

  public:
    common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_parser_params & syntax);

    // Read-only access to the parser state
    const std::string & input() const { return input_; }
    size_t pos() const { return pos_; }
    const std::string & healing_marker() const { return healing_marker_; }
    const bool & is_partial() const { return is_partial_; }
    const common_chat_msg & result() const { return result_; }
    const common_chat_parser_params & syntax() const { return syntax_; }

    // Set the read position. Positions up to and including input().size() are valid;
    // anything larger throws std::runtime_error.
    void move_to(size_t pos) {
        const bool in_range = pos <= input_.size();
        if (!in_range) {
            throw std::runtime_error("Invalid position!");
        }
        pos_ = pos;
    }

    // Move the read position n characters towards the start of the input.
    // Throws std::runtime_error when that would go past the beginning.
    void move_back(size_t n) {
        if (n > pos_) {
            throw std::runtime_error("Can't move back that far!");
        }
        pos_ = pos_ - n;
    }

    // Get the substring of the input at the given range
    std::string str(const common_string_range & rng) const;

    // Appends to the result.content field
    void add_content(const std::string & content);

    // Appends to the result.reasoning_content field
    void add_reasoning_content(const std::string & reasoning_content);

    // Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything.
    bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);

    // Adds a tool call using the "name", "id" and "arguments" fields of the json object
    bool add_tool_call(const nlohmann::ordered_json & tool_call);

    // Adds an array of tool calls using their "name", "id" and "arguments" fields.
    bool add_tool_calls(const nlohmann::ordered_json & arr);

    // Adds a tool call using the short form: { "tool_name": { "arg1": val, "arg2": val } }
    bool add_tool_call_short_form(const nlohmann::ordered_json & tool_call);

    void finish();

    bool consume_spaces();

    void consume_literal(const std::string & literal);

    bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);

    std::string consume_rest();

    // Result of a find/consume operation: the text in front of the match and the matched ranges
    struct find_regex_result {
        std::string                      prelude;
        std::vector<common_string_range> groups;
    };

    std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);

    bool try_consume_literal(const std::string & literal);

    std::optional<find_regex_result> try_find_literal(const std::string & literal);

    find_regex_result consume_regex(const common_regex & regex);

    std::optional<find_regex_result> try_consume_regex(const common_regex & regex);

    std::optional<common_json> try_consume_json();
    common_json consume_json();

    // A consumed JSON value together with a flag telling whether it was partial
    struct consume_json_result {
        nlohmann::ordered_json value;
        bool                   is_partial;
    };

    /*
        Consume (possibly partial) json and convert specific subtrees to (possibly truncated) JSON strings.

        By default, object keys can't be truncated, nor can string values (their corresponding key is removed),
        e.g. `{"foo": "bar", "baz": "b` -> `{"foo": "bar"}`

        But one can allow subpaths to be kept truncated, and possibly json-dumped to truncated json strings:
        - with `content_paths={{"foo"}}` -> `{"foo": "b` -> `{"foo": "b"}`
        - with `args_paths={{"foo"}}` -> `{"foo": {"b` -> `{"foo": "{b"}`
    */
    consume_json_result consume_json_with_dumped_args(
        const std::vector<std::vector<std::string>> & args_paths = {},
        const std::vector<std::vector<std::string>> & content_paths = {}
    );
    std::optional<consume_json_result> try_consume_json_with_dumped_args(
        const std::vector<std::vector<std::string>> & args_paths = {},
        const std::vector<std::vector<std::string>> & content_paths = {}
    );

    /**
     * Parse XML-Style tool call for given xml_tool_call_format. Returns false for invalid syntax and leaves the position untouched.
     * form.scope_start, form.tool_sep and form.scope_end can be empty.
     */
    bool try_consume_xml_tool_calls(const struct xml_tool_call_format & form);

    // Parse content that uses reasoning and XML-Style tool calls
    void consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>");

    void clear_tools();
};

View File

@ -1,13 +1,17 @@
#include "chat-peg-parser.h"
#include "chat-auto-parser.h"
#include "ggml.h"
#include "peg-parser.h"
#include <nlohmann/json.hpp>
using json = nlohmann::json;
using json = nlohmann::ordered_json;
static std::string_view trim_trailing_space(std::string_view sv, int max = -1) {
int count = 0;
while (!sv.empty() && std::isspace(static_cast<unsigned char>(sv.back()))) {
if (max != -1 && count <= max) {
if (max != -1 && count >= max) {
break;
}
sv.remove_suffix(1);
@ -16,109 +20,820 @@ static std::string_view trim_trailing_space(std::string_view sv, int max = -1) {
return sv;
}
void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result) {
// Drop whitespace from the front of sv and return the shortened view.
// max == -1 removes all leading whitespace; any other value caps the number of removed characters at max.
static std::string_view trim_leading_space(std::string_view sv, int max = -1) {
    const bool unlimited = (max == -1);
    for (int removed = 0; !sv.empty() && std::isspace(static_cast<unsigned char>(sv.front())); ++removed) {
        if (!unlimited && removed >= max) {
            break;
        }
        sv.remove_prefix(1);
    }
    return sv;
}
// Trim sv on both sides: at most one whitespace character is removed from the front
// (trim_leading_space with max = 1), then trim_trailing_space is applied with its default limit.
static std::string_view trim(std::string_view sv) {
    const std::string_view without_lead = trim_leading_space(sv, 1);
    return trim_trailing_space(without_lead);
}
// Net nesting depth of '{' / '}' in a JSON-like string: every '{' outside a quoted string
// adds one, every '}' outside a quoted string subtracts one. Braces inside double-quoted
// strings are ignored, and a backslash inside a string hides the character that follows it.
// The result is negative when there are more closing than opening braces.
static int json_brace_depth(const std::string & s) {
    int  open_braces = 0;
    bool inside_str  = false;
    bool skip_next   = false;

    for (const char ch : s) {
        if (skip_next) {
            // character escaped by the preceding backslash
            skip_next = false;
            continue;
        }
        switch (ch) {
            case '\\':
                // a backslash only starts an escape while inside a string
                if (inside_str) {
                    skip_next = true;
                }
                break;
            case '"':
                inside_str = !inside_str;
                break;
            case '{':
                if (!inside_str) {
                    ++open_braces;
                }
                break;
            case '}':
                if (!inside_str) {
                    --open_braces;
                }
                break;
            default:
                break;
        }
    }
    return open_braces;
}
// JSON-escape s and return the escaped text without the enclosing double quotes.
// If the dumped value is not wrapped in quotes, it is returned unchanged.
static std::string escape_json_string_inner(const std::string & s) {
    const std::string dumped = json(s).dump();
    const bool        quoted = dumped.size() >= 2 && dumped.front() == '"' && dumped.back() == '"';
    return quoted ? dumped.substr(1, dumped.size() - 2) : dumped;
}
// Rewrite Python-style single-quoted strings as JSON double-quoted strings.
// Only the string delimiters are converted; the string bodies are adjusted so that they stay valid:
//   {'key': 'value'}                -> {"key": "value"}
//   {'code': 'print(\'hello\')'}    -> {"code": "print('hello')"}
//   {'msg': 'He said "hi"'}         -> {"msg": "He said \"hi\""}
// Text that is already double-quoted is copied unchanged.
static std::string normalize_quotes_to_json(const std::string & input) {
    // Which kind of string literal the scanner is currently inside of
    enum class quote_state { none, single, dbl };

    std::string out;
    out.reserve(input.size() + 16);  // escaping may need a little extra room

    quote_state  state = quote_state::none;
    const size_t len   = input.size();
    size_t       idx   = 0;

    while (idx < len) {
        const char ch = input[idx];

        // Backslash with at least one character after it
        if (ch == '\\' && idx + 1 < len) {
            const char following = input[idx + 1];
            switch (state) {
                case quote_state::single:
                    if (following == '\'') {
                        // \' needs no escaping once the string is double-quoted
                        out += '\'';
                    } else if (following == '"') {
                        // \" is already valid inside a double-quoted string
                        out += "\\\"";
                    } else {
                        // any other escape (\n, \\, ...) is copied as both characters
                        out += ch;
                        out += following;
                    }
                    idx += 2;
                    break;
                case quote_state::dbl:
                    // escape sequences of double-quoted strings are copied as-is
                    out += ch;
                    out += following;
                    idx += 2;
                    break;
                case quote_state::none:
                    // outside of strings only the backslash itself is copied
                    out += ch;
                    idx += 1;
                    break;
            }
            continue;
        }

        if (ch == '"') {
            switch (state) {
                case quote_state::single:
                    // bare double quote inside a single-quoted string has to be escaped for JSON
                    out += "\\\"";
                    break;
                case quote_state::dbl:
                    state = quote_state::none;
                    out += ch;
                    break;
                case quote_state::none:
                    state = quote_state::dbl;
                    out += ch;
                    break;
            }
        } else if (ch == '\'') {
            switch (state) {
                case quote_state::dbl:
                    // single quote inside a double-quoted string is ordinary text
                    out += ch;
                    break;
                case quote_state::single:
                    // closing delimiter
                    state = quote_state::none;
                    out += '"';
                    break;
                case quote_state::none:
                    // opening delimiter
                    state = quote_state::single;
                    out += '"';
                    break;
            }
        } else {
            out += ch;
        }
        ++idx;
    }
    return out;
}
void tag_based_peg_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result) {
arena.visit(result, [this](const common_peg_ast_node & node) {
map(node);
if (!node.tag.empty()) {
tags[node.tag] = std::string(node.text);
}
});
}
void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
bool is_reasoning = node.tag == common_chat_peg_builder::REASONING;
bool is_content = node.tag == common_chat_peg_builder::CONTENT;
tagged_parse_result tagged_peg_parser::parse_and_extract(const std::string & input, common_peg_parse_flags extra_flags) const {
common_peg_parse_context ctx(input, flags | extra_flags);
auto parse_result = arena.parse(ctx);
if (is_reasoning) {
result.reasoning_content = std::string(trim_trailing_space(node.text));
tag_based_peg_mapper mapper;
mapper.from_ast(ctx.ast, parse_result);
return { std::move(parse_result), std::move(mapper.tags) };
}
// Try to parse `input` starting at every offset in turn and return the first successful
// parse together with the tags collected from its AST. If no offset succeeds, the result
// of the attempt at the last offset is returned. An empty input is handed to parse_and_extract().
tagged_parse_result tagged_peg_parser::parse_anywhere_and_extract(const std::string & input) const {
    if (input.empty()) {
        return parse_and_extract(input);
    }

    const size_t last_start = input.size() - 1;
    for (size_t start = 0; start <= last_start; ++start) {
        // every attempt gets its own parse context
        common_peg_parse_context ctx(input, flags);
        auto attempt = arena.parse(ctx, start);

        const bool is_last_attempt = (start == last_start);
        if (!attempt.success() && !is_last_attempt) {
            continue;
        }

        tag_based_peg_mapper mapper;
        mapper.from_ast(ctx.ast, attempt);
        return { std::move(attempt), std::move(mapper.tags) };
    }
    GGML_ABORT("Should not happen");
}
if (is_content) {
result.content = std::string(trim_trailing_space(node.text));
// Create a tagged_peg_parser: `fn` receives a fresh common_peg_parser_builder and returns
// the parser that becomes the root; the built result is wrapped in a tagged_peg_parser.
tagged_peg_parser build_tagged_peg_parser(
    const std::function<common_peg_parser(common_peg_parser_builder & builder)> & fn) {
    common_peg_parser_builder pb;
    pb.set_root(fn(pb));
    return { pb.build() };
}
// Build a parser that repeats, zero or more times, a choice between `p` and a content chunk
// wrapped in a rule named `tag_name`.
//  - empty marker:     the chunk is a single any() wrapped in content()
//  - non-empty marker: the chunk must not begin with `marker`; it takes one any() and then
//                      continues with until(marker)
common_peg_parser common_chat_peg_builder::tag_with_safe_content(const std::string &       tag_name,
                                                                 const std::string &       marker,
                                                                 const common_peg_parser & p) {
    auto chunk = marker.empty()
        ? rule(tag_name, content(any()))
        : rule(tag_name, content(negate(literal(marker)) + any() + until(marker)));
    return zero_or_more(choice({ p, chunk }));
}
// Return the string that argument text should be written to: the arguments of the current
// tool call once that call has a non-empty name, otherwise the temporary args_buffer.
std::string & common_chat_peg_mapper::args_target() {
    if (current_tool && !current_tool->name.empty()) {
        return current_tool->arguments;
    }
    return args_buffer;
}
// Map every AST node of the parse result via map(), then flush a tool call that is still pending.
void common_chat_peg_mapper::from_ast(const common_peg_ast_arena &    arena,
                                      const common_peg_parse_result & parse_result_arg) {
    arena.visit(parse_result_arg, [this](const common_peg_ast_node & node) { map(node); });

    // A pending tool call is only flushed when it already has a name; a pending call without
    // a name is left untouched. The original comment notes that a pending call shows up during
    // partial parsing, when the tool call is incomplete.
    const bool has_named_pending = pending_tool_call.has_value() && !pending_tool_call->name.empty();
    if (!has_named_pending) {
        return;
    }

    auto & pending = pending_tool_call.value();
    // Arguments collected in the buffer take precedence over the ones stored in the call
    if (!args_buffer.empty()) {
        pending.arguments = args_buffer;
    }
    // Append the closing quote that is still outstanding, unless there are no arguments at all
    if (closing_quote_pending && !pending.arguments.empty()) {
        pending.arguments += "\"";
    }
    result.tool_calls.push_back(pending);
    pending_tool_call.reset();
}
void common_chat_peg_native_mapper::map(const common_peg_ast_node & node) {
common_chat_peg_mapper::map(node);
void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
// Handle reasoning/content tags
bool is_reasoning = node.tag == common_chat_peg_builder::REASONING;
bool is_content = node.tag == common_chat_peg_builder::CONTENT;
bool is_tool_open = node.tag == common_chat_peg_native_builder::TOOL_OPEN;
bool is_tool_name = node.tag == common_chat_peg_native_builder::TOOL_NAME;
bool is_tool_id = node.tag == common_chat_peg_native_builder::TOOL_ID;
bool is_tool_args = node.tag == common_chat_peg_native_builder::TOOL_ARGS;
if (is_reasoning) { // GPT OSS can have more than 1 reasoning block, so concatenate here
result.reasoning_content += std::string(node.text);
}
if (is_content) {
// Concatenate content from multiple content nodes (e.g., when reasoning markers
// are preserved before content markers in reasoning_format=NONE mode)
result.content += std::string(node.text);
}
// Handle tool-related tags (supporting both JSON and tagged formats)
bool is_tool_open = node.tag == common_chat_peg_builder::TOOL_OPEN;
bool is_tool_close = node.tag == common_chat_peg_builder::TOOL_CLOSE;
bool is_tool_name = node.tag == common_chat_peg_builder::TOOL_NAME;
bool is_tool_id = node.tag == common_chat_peg_builder::TOOL_ID;
bool is_tool_args = node.tag == common_chat_peg_builder::TOOL_ARGS;
bool is_arg_open = node.tag == common_chat_peg_builder::TOOL_ARG_OPEN;
bool is_arg_close = node.tag == common_chat_peg_builder::TOOL_ARG_CLOSE;
bool is_arg_name = node.tag == common_chat_peg_builder::TOOL_ARG_NAME;
bool is_arg_value = node.tag == common_chat_peg_builder::TOOL_ARG_VALUE;
bool is_arg_string_value = node.tag == common_chat_peg_builder::TOOL_ARG_STRING_VALUE;
if (is_tool_open) {
result.tool_calls.emplace_back();
current_tool = &result.tool_calls.back();
pending_tool_call = common_chat_tool_call();
current_tool = &pending_tool_call.value();
arg_count = 0;
args_buffer.clear();
closing_quote_pending = false;
}
if (is_tool_id && current_tool) {
current_tool->id = std::string(trim_trailing_space(node.text));
auto text = trim_trailing_space(node.text);
if (text.size() >= 2 && text.front() == '"' && text.back() == '"') {
text = text.substr(1, text.size() - 2);
}
current_tool->id = std::string(text);
}
if (is_tool_name && current_tool) {
current_tool->name = std::string(trim_trailing_space(node.text));
// Now that we have the name, populate the arguments from the buffer
if (!args_buffer.empty()) {
current_tool->arguments = args_buffer;
args_buffer.clear();
} else if (current_tool->arguments.empty()) {
current_tool->arguments = "{";
}
// Add the tool call to results so streaming can see it
if (pending_tool_call.has_value()) {
result.tool_calls.push_back(pending_tool_call.value());
pending_tool_call.reset();
current_tool = &result.tool_calls.back();
}
}
if (is_tool_args && current_tool) {
current_tool->arguments = std::string(trim_trailing_space(node.text));
}
}
void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
common_chat_peg_mapper::map(node);
bool is_tool_open = node.tag == common_chat_peg_constructed_builder::TOOL_OPEN;
bool is_tool_name = node.tag == common_chat_peg_constructed_builder::TOOL_NAME;
bool is_tool_close = node.tag == common_chat_peg_constructed_builder::TOOL_CLOSE;
bool is_arg_open = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_OPEN;
bool is_arg_close = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_CLOSE;
bool is_arg_name = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_NAME;
bool is_arg_string = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_STRING_VALUE;
bool is_arg_json = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_JSON_VALUE;
if (is_tool_open) {
result.tool_calls.emplace_back();
current_tool = &result.tool_calls.back();
arg_count = 0;
}
if (is_tool_name) {
current_tool->name = std::string(node.text);
current_tool->arguments = "{";
// For JSON format: arguments come as a complete JSON object
// For tagged format: built up from individual arg_name/arg_value nodes
auto text = trim_trailing_space(node.text);
if (!text.empty() && text.front() == '{') {
args_target() = std::string(text);
}
}
if (is_arg_open) {
needs_closing_quote = false;
closing_quote_pending = false;
}
if (is_arg_name && current_tool) {
std::string arg_entry;
if (arg_count > 0) {
current_tool->arguments += ",";
arg_entry = ",";
}
current_tool->arguments += json(trim_trailing_space(node.text)).dump() + ":";
arg_entry += json(trim(node.text)).dump() + ":";
++arg_count;
auto & target = args_target();
if (target.empty()) {
target = "{";
}
target += arg_entry;
}
if (is_arg_string && current_tool) {
// Serialize to JSON, but exclude the end quote
std::string dumped = json(trim_trailing_space(node.text)).dump();
current_tool->arguments += dumped.substr(0, dumped.size() - 1);
needs_closing_quote = true;
if ((is_arg_value || is_arg_string_value) && current_tool) {
std::string value_content = std::string(trim_trailing_space(trim_leading_space(node.text, 1), 1));
std::string value_to_add;
if (value_content.empty() && is_arg_string_value) {
// Empty string value - arg_close will add the closing quote
value_to_add = "\"";
closing_quote_pending = true;
} else if (!value_content.empty() && is_arg_string_value) {
// Schema declares this as string type - always treat as literal string value
if (!closing_quote_pending) {
value_to_add = "\"";
closing_quote_pending = true;
}
value_to_add += escape_json_string_inner(value_content);
} else if (!value_content.empty()) {
// For potential containers, normalize Python-style single quotes to JSON double quotes
bool is_potential_container = value_content[0] == '[' || value_content[0] == '{';
if (is_potential_container) {
value_content = normalize_quotes_to_json(value_content);
}
// Try to parse as JSON value (number, bool, null, object, array)
try {
json parsed = json::parse(value_content);
if (parsed.is_string()) {
// Don't add closing quote yet (added by arg_close) for monotonic streaming
std::string escaped = parsed.dump();
if (!escaped.empty() && escaped.back() == '"') {
escaped.pop_back();
}
value_to_add = escaped;
closing_quote_pending = true;
} else {
// Non-string values: use raw content to preserve whitespace for monotonicity
value_to_add = value_content;
}
} catch (...) {
if (node.is_partial && is_potential_container) {
// Partial container: pass through the already-normalized content
value_to_add = value_content;
} else {
// Not valid JSON - treat as string value
if (!closing_quote_pending) {
value_to_add = "\"";
closing_quote_pending = true;
}
value_to_add += escape_json_string_inner(value_content);
}
}
}
args_target() += value_to_add;
}
if (is_arg_close && current_tool) {
if (needs_closing_quote) {
current_tool->arguments += "\"";
needs_closing_quote = false;
if (closing_quote_pending) {
args_target() += "\"";
closing_quote_pending = false;
}
}
if (is_arg_json && current_tool) {
current_tool->arguments += std::string(trim_trailing_space(node.text));
}
if (is_tool_close && current_tool) {
if (needs_closing_quote) {
current_tool->arguments += "\"";
needs_closing_quote = false;
// Flush buffer to arguments if tool name was never seen
if (current_tool->name.empty() && !args_buffer.empty()) {
current_tool->arguments = args_buffer;
args_buffer.clear();
}
// Close any pending string quote
if (closing_quote_pending) {
current_tool->arguments += "\"";
closing_quote_pending = false;
}
// Close any unclosed braces (accounts for nested objects)
for (int d = json_brace_depth(current_tool->arguments); d > 0; d--) {
current_tool->arguments += "}";
}
// Add tool call to results if named; otherwise discard
if (pending_tool_call.has_value()) {
if (!current_tool->name.empty()) {
result.tool_calls.push_back(pending_tool_call.value());
}
pending_tool_call.reset();
}
current_tool->arguments += "}";
}
}
// Builds the grammar for tagged tool calls which, with the default markers, have
// the shape
//   <tool_call><function=NAME><param=KEY>VALUE</param>...</function></tool_call>
//
// markers:             optional overrides for the eight marker strings (keys listed below);
//                      any key that is missing falls back to the default shown
// tools:               JSON array of tool definitions; entries without "function" are ignored
// parallel_tool_calls: when true the section holds one_or_more tool calls, otherwise exactly one
// force_tool_calls:    when true the section is returned as-is, otherwise wrapped in optional()
//
// Returns eps() when `tools` is not an array or is empty.
common_peg_parser common_chat_peg_builder::standard_constructed_tools(
    const std::map<std::string, std::string> & markers,
    const nlohmann::json &                     tools,
    bool                                       parallel_tool_calls,
    bool                                       force_tool_calls) {
    if (!tools.is_array() || tools.empty()) {
        return eps();
    }

    // Marker lookup with a fallback for keys that were not supplied
    auto marker_or = [&markers](const std::string & key, const std::string & fallback = "") -> std::string {
        const auto found = markers.find(key);
        if (found == markers.end()) {
            return fallback;
        }
        return found->second;
    };

    const std::string section_start    = marker_or("tool_call_start_marker", "<tool_call>");
    const std::string section_end      = marker_or("tool_call_end_marker", "</tool_call>");
    const std::string func_opener      = marker_or("function_opener", "<function=");
    const std::string func_name_suffix = marker_or("function_name_suffix", ">");
    const std::string func_closer      = marker_or("function_closer", "</function>");
    const std::string param_key_prefix = marker_or("parameter_key_prefix", "<param=");
    const std::string param_key_suffix = marker_or("parameter_key_suffix", ">");
    const std::string param_closer     = marker_or("parameter_closer", "</param>");

    // One alternative per tool definition
    auto tool_choices = choice();
    for (const auto & tool_def : tools) {
        if (!tool_def.contains("function")) {
            continue;
        }
        const auto & function = tool_def.at("function");
        std::string  name     = function.at("name");

        nlohmann::json params = nlohmann::json::object();
        if (function.contains("parameters")) {
            params = function.at("parameters");
        }

        // Arguments: any number of <prefix>KEY<suffix>VALUE<closer>, one alternative per declared property
        auto       args           = eps();
        const bool has_properties = params.contains("properties") && !params["properties"].empty();
        if (has_properties) {
            auto arg_choice = choice();
            for (const auto & el : params["properties"].items()) {
                const std::string & prop_name = el.key();
                // The key is accepted bare, double-quoted or single-quoted
                auto key_forms =
                    choice({ literal(prop_name), literal("\"" + prop_name + "\""), literal("'" + prop_name + "'") });
                auto one_arg = tool_arg(tool_arg_open(literal(param_key_prefix)) + tool_arg_name(key_forms) +
                                        literal(param_key_suffix) + tool_arg_value(until(param_closer)) +
                                        tool_arg_close(literal(param_closer)));
                arg_choice |= one_arg;
            }
            args = zero_or_more(arg_choice + space());
        }

        // Function call: <opener>NAME<suffix> args <closer>
        auto call = tool(tool_open(literal(func_opener) + tool_name(literal(name)) + literal(func_name_suffix)) +
                         space() + tool_args(args) + space() + tool_close(literal(func_closer)));
        tool_choices |= rule("tool-" + name, call);
    }

    // Wrap the call(s) in the section markers
    if (parallel_tool_calls) {
        auto section = trigger_rule("tool-call", literal(section_start) + space() +
                                                     one_or_more(tool_choices + space()) + literal(section_end));
        if (force_tool_calls) {
            return section;
        }
        return optional(section);
    }

    auto section =
        trigger_rule("tool-call", literal(section_start) + space() + tool_choices + space() + literal(section_end));
    if (force_tool_calls) {
        return section;
    }
    return optional(section);
}
// Python-style tool calls: [name(arg1="value1", arg2=123), ...]
// Used only by LFM2 for now, so we don't merge it into autoparser
//
// tools:               JSON array of tool definitions; entries without "function" are ignored
// parallel_tool_calls: when true the bracketed list may hold several comma-separated calls,
//                      otherwise exactly one
//
// Returns eps() when `tools` is not an array or is empty.
common_peg_parser common_chat_peg_builder::python_style_tool_calls(
    const nlohmann::json & tools,
    bool                   parallel_tool_calls) {
    if (!tools.is_array() || tools.empty()) {
        return eps();
    }

    auto tool_choices = choice();
    for (const auto & tool_def : tools) {
        if (!tool_def.contains("function")) {
            continue;
        }
        const auto & function = tool_def.at("function");
        std::string  name     = function.at("name");

        nlohmann::json params = nlohmann::json::object();
        if (function.contains("parameters")) {
            params = function.at("parameters");
        }

        auto       args           = eps();
        const bool has_properties = params.contains("properties") && !params["properties"].empty();
        if (has_properties) {
            auto arg_choice = choice();
            for (const auto & el : params["properties"].items()) {
                const std::string & prop_name = el.key();
                const auto &        prop_def  = el.value();

                // Properties whose schema says "type": "string" take a quoted literal
                const bool declared_string = prop_def.contains("type") && prop_def.at("type") == "string";

                auto             key_parser   = literal(prop_name);
                common_peg_parser value_parser = eps();
                // Quoted value, delimited by either double or single quotes
                auto quoted_value = choice({
                    literal("\"") + tool_arg_string_value(json_string_content()) + literal("\""),
                    literal("'") + tool_arg_string_value(json_string_content()) + literal("'")
                });
                value_parser = declared_string ? quoted_value : tool_arg_value(python_value());

                // Full argument: name="value" or name=value
                auto one_arg = tool_arg(
                    tool_arg_open(eps()) +
                    tool_arg_name(key_parser) +
                    literal("=") +
                    value_parser +
                    tool_arg_close(eps())
                );
                arg_choice |= one_arg;
            }
            // First argument, then any number of ", argument"
            args = arg_choice + zero_or_more("," + space() + arg_choice);
        }

        // Call: NAME( args )
        auto call = tool(tool_open(tool_name(literal(name)) + literal("(")) +
                         space() + tool_args(args) + space() + tool_close(literal(")"))
        );
        tool_choices |= rule("tool-" + name, call);
    }

    if (!parallel_tool_calls) {
        return "[" + space() + tool_choices + space() + "]";
    }
    return "[" + space() + tool_choices + zero_or_more("," + space() + tool_choices) + space() + "]";
}
// Helper: split a dot-notation key at its FIRST '.' into (prefix, field).
//   "function.name" -> {"function", "name"}
//   "a.b.c"         -> {"a", "b.c"}
//   "name"          -> {"", "name"}   (no dot: top-level field, empty prefix)
static std::pair<std::string, std::string> parse_key_spec(const std::string & key) {
    const std::string::size_type sep = key.find('.');
    if (sep != std::string::npos) {
        return std::make_pair(key.substr(0, sep), key.substr(sep + 1));
    }
    return std::make_pair(std::string(), key);
}
// Mode 1: function_is_key - parse {"function_name": {...}}
//
// For each tool definition that has a "function" entry, builds
//   { "NAME" : <inner> }
// where <inner> is either the bare schema-checked arguments (when `args_key` is
// empty and no ID keys are configured) or an object holding, in this order, the
// optional call-ID field, the optional generated-call-ID field and the arguments
// (keyed by `effective_args_key` when `args_key` is non-empty).
//
// call_id_key:     ID value must be a JSON string
// gen_call_id_key: ID value may be a JSON string or a JSON number
//
// Returns a choice over all tools.
common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key(
    const nlohmann::json & tools,
    const std::string &    args_key,
    const std::string &    effective_args_key,
    const std::string &    call_id_key,
    const std::string &    gen_call_id_key) {
    auto tool_choices = choice();

    // "<key>" : (with whitespace allowed around the colon)
    const auto quoted_key_colon = [this](const std::string & key) {
        return literal("\"" + key + "\"") + space() + literal(":") + space();
    };

    const bool has_call_id     = !call_id_key.empty();
    const bool has_gen_call_id = !gen_call_id_key.empty();
    const bool args_are_bare   = args_key.empty();

    for (const auto & tool_def : tools) {
        if (!tool_def.contains("function")) {
            continue;
        }
        const auto & function = tool_def.at("function");
        std::string  name     = function.at("name");

        nlohmann::json params = nlohmann::json::object();
        if (function.contains("parameters")) {
            params = function.at("parameters");
        }

        // Fields of the inner object, in parse order
        std::vector<common_peg_parser> inner_fields;

        if (has_call_id) {
            auto id_parser = atomic(
                quoted_key_colon(call_id_key) +
                literal("\"") + tool_id(json_string_content()) + literal("\"")
            );
            inner_fields.push_back(optional(id_parser + space() + optional(literal(",") + space())));
        }

        if (has_gen_call_id) {
            auto gen_id_parser = atomic(
                quoted_key_colon(gen_call_id_key) +
                choice({
                    literal("\"") + tool_id(json_string_content()) + literal("\""),
                    tool_id(json_number())
                })
            );
            inner_fields.push_back(optional(gen_id_parser + space() + optional(literal(",") + space())));
        }

        // Arguments - either parsed directly or wrapped under the arguments key
        common_peg_parser args_parser = eps();
        if (args_are_bare) {
            args_parser = tool_args(schema(json(), "tool-" + name + "-schema", params));
        } else {
            args_parser = quoted_key_colon(effective_args_key) +
                          tool_args(schema(json(), "tool-" + name + "-schema", params));
        }
        inner_fields.push_back(args_parser);

        // Inner object: the bare arguments when that is the only field, else { field field ... }
        common_peg_parser inner_object = eps();
        if (args_are_bare && inner_fields.size() == 1) {
            inner_object = inner_fields[0];
        } else {
            inner_object = literal("{") + space();
            const size_t last_field = inner_fields.size() - 1;
            for (size_t idx = 0; idx <= last_field; ++idx) {
                inner_object = inner_object + inner_fields[idx];
                if (idx != last_field) {
                    inner_object = inner_object + space();
                }
            }
            inner_object = inner_object + space() + literal("}");
        }

        auto call = tool(
            tool_open(literal("{")) + space() +
            literal("\"") + tool_name(literal(name)) + literal("\"") +
            space() + literal(":") + space() +
            inner_object +
            space() + tool_close(literal("}"))
        );
        tool_choices |= rule("tool-" + name, call);
    }

    return tool_choices;
}
// Mode 2: Nested keys (dot notation like "function.name")
//
// For each tool definition that has a "function" entry, builds
//   { id?, gen_id?, "<prefix>": { "<name_field>": "NAME", "<args_field>": <args> } }
// The prefix comes from the name key when that key is dotted, otherwise from the
// arguments key. A key without a dot is used unchanged as its field name.
//
// call_id_key / gen_call_id_key are only honoured when they contain no dot; a
// dotted ID key adds nothing to the grammar. When present, an ID field is
// optional but must be followed by a comma.
//
// Returns a choice over all tools.
common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
    const nlohmann::json & tools,
    const std::string &    effective_name_key,
    const std::string &    effective_args_key,
    const std::string &    call_id_key,
    const std::string &    gen_call_id_key) {
    auto tool_choices = choice();

    const auto name_spec      = parse_key_spec(effective_name_key);
    const auto args_spec      = parse_key_spec(effective_args_key);
    const bool name_is_nested = !name_spec.first.empty();
    const bool args_is_nested = !args_spec.first.empty();

    const std::string nested_prefix     = name_is_nested ? name_spec.first : args_spec.first;
    const std::string nested_name_field = name_is_nested ? name_spec.second : effective_name_key;
    const std::string nested_args_field = args_is_nested ? args_spec.second : effective_args_key;

    // ID fields are only emitted for non-empty, top-level (undotted) keys
    const bool accept_call_id     = !call_id_key.empty() && parse_key_spec(call_id_key).first.empty();
    const bool accept_gen_call_id = !gen_call_id_key.empty() && parse_key_spec(gen_call_id_key).first.empty();

    // "<key>" : (with whitespace allowed around the colon)
    const auto quoted_key_colon = [this](const std::string & key) {
        return literal("\"" + key + "\"") + space() + literal(":") + space();
    };

    for (const auto & tool_def : tools) {
        if (!tool_def.contains("function")) {
            continue;
        }
        const auto & function = tool_def.at("function");
        std::string  name     = function.at("name");

        nlohmann::json params = nlohmann::json::object();
        if (function.contains("parameters")) {
            params = function.at("parameters");
        }

        auto nested_name = quoted_key_colon(nested_name_field) +
                           literal("\"") + tool_name(literal(name)) + literal("\"");
        auto nested_args = quoted_key_colon(nested_args_field) +
                           tool_args(schema(json(), "tool-" + name + "-schema", params));
        auto nested_object = literal("{") + space() +
                             nested_name + space() + literal(",") + space() +
                             nested_args +
                             space() + literal("}");

        // Format: { id?, "function": {...} }
        auto body = tool_open(literal("{")) + space();

        if (accept_call_id) {
            auto id_parser = atomic(
                quoted_key_colon(call_id_key) +
                literal("\"") + tool_id(json_string_content()) + literal("\"")
            );
            body = body + optional(id_parser + space() + literal(",") + space());
        }

        if (accept_gen_call_id) {
            auto gen_id_parser = atomic(
                quoted_key_colon(gen_call_id_key) +
                choice({
                    literal("\"") + tool_id(json_string_content()) + literal("\""),
                    tool_id(json_number())
                })
            );
            body = body + optional(gen_id_parser + space() + literal(",") + space());
        }

        auto nested_field = quoted_key_colon(nested_prefix) + nested_object;
        body = body + nested_field + space() + tool_close(literal("}"));

        tool_choices |= rule("tool-" + name, tool(body));
    }

    return tool_choices;
}
// Mode 3: Flat keys with optional ID fields and parameter ordering
//
// For each tool definition that has a "function" entry, builds a single JSON
// object whose fields are the name, the arguments and (when their keys are
// non-empty) the call ID and generated call ID:
//   { "<name_key>": "NAME", "<args_key>": <args>, "<id_key>": ID, ... }
//
// parameters_order: keys listed here are parsed in the listed order. Keys that
//                   are not listed come after all listed ones and keep their
//                   construction order (name, arguments, call ID, generated
//                   call ID). A stable sort is used so that this tie order is
//                   deterministic on every standard library.
//
// ID values may be a JSON string or a JSON number.
//
// NOTE(review): the "," separator between consecutive fields is appended
// unconditionally, also next to the optional() ID fields, so an input that
// omits an ID presumably still has to carry that comma - confirm this is the
// intended grammar.
//
// Returns a choice over all tools.
common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
    const nlohmann::json &           tools,
    const std::string &              effective_name_key,
    const std::string &              effective_args_key,
    const std::string &              call_id_key,
    const std::string &              gen_call_id_key,
    const std::vector<std::string> & parameters_order) {
    auto tool_choices    = choice();
    auto name_key_parser = literal("\"" + effective_name_key + "\"");
    auto args_key_parser = literal("\"" + effective_args_key + "\"");

    // Index of `key` in parameters_order; unlisted keys get parameters_order.size()
    // (std::find returns end(), whose distance from begin() is the size).
    const auto order_rank = [&parameters_order](const std::string & key) -> size_t {
        const auto pos = std::find(parameters_order.begin(), parameters_order.end(), key);
        return static_cast<size_t>(std::distance(parameters_order.begin(), pos));
    };

    for (const auto & tool_def : tools) {
        if (!tool_def.contains("function")) {
            continue;
        }
        const auto &   function = tool_def.at("function");
        std::string    name     = function.at("name");
        nlohmann::json params   = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();

        auto tool_name_ = name_key_parser + space() + literal(":") + space() +
                          literal("\"") + tool_name(literal(name)) + literal("\"");
        auto tool_args_ = args_key_parser + space() + literal(":") + space() +
                          tool_args(schema(json(), "tool-" + name + "-schema", params));

        // Build ID parsers if keys are provided
        common_peg_parser id_parser = eps();
        if (!call_id_key.empty()) {
            id_parser = atomic(
                literal("\"" + call_id_key + "\"") + space() + literal(":") + space() +
                choice({
                    literal("\"") + tool_id(json_string_content()) + literal("\""),
                    tool_id(json_number())
                })
            );
        }

        common_peg_parser gen_id_parser = eps();
        if (!gen_call_id_key.empty()) {
            gen_id_parser = atomic(
                literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() +
                choice({
                    literal("\"") + tool_id(json_string_content()) + literal("\""),
                    tool_id(json_number())
                })
            );
        }

        // Create (parser, key) pairs for all fields, then order them by parameters_order
        std::vector<std::pair<common_peg_parser, std::string>> parser_pairs;
        parser_pairs.emplace_back(tool_name_, effective_name_key);
        parser_pairs.emplace_back(tool_args_, effective_args_key);
        if (!call_id_key.empty()) {
            parser_pairs.emplace_back(optional(id_parser), call_id_key);
        }
        if (!gen_call_id_key.empty()) {
            parser_pairs.emplace_back(optional(gen_id_parser), gen_call_id_key);
        }

        // stable_sort (not sort): fields with equal rank must keep their insertion
        // order, which std::sort does not guarantee.
        std::stable_sort(parser_pairs.begin(), parser_pairs.end(),
                         [&order_rank](const auto & a, const auto & b) {
                             return order_rank(a.second) < order_rank(b.second);
                         });

        // { field , field , ... }
        auto ordered_body = tool_open(literal("{")) + space();
        for (size_t i = 0; i < parser_pairs.size(); i++) {
            ordered_body = ordered_body + parser_pairs[i].first;
            if (i < parser_pairs.size() - 1) {
                ordered_body = ordered_body + space() + literal(",") + space();
            }
        }
        ordered_body = ordered_body + space() + tool_close(literal("}"));

        tool_choices |= rule("tool-" + name, tool(ordered_body));
    }

    return tool_choices;
}
// Builds the grammar for JSON-formatted tool calls enclosed in section markers.
//
// section_start / section_end: literals that open and close the tool-call section
// tools:               JSON array of tool definitions; eps() is returned when it
//                      is not an array or is empty
// parallel_tool_calls: when true, further comma-separated calls may follow the first
// force_tool_calls:    when true the section is returned as-is, otherwise wrapped in optional()
// name_key / args_key: JSON keys for the function name and arguments; an empty
//                      value means "name" / "arguments"
// array_wrapped:       when true the call(s) are enclosed in [ ... ]
// function_is_key:     selects the {"function_name": {...}} layout
// call_id_key / gen_call_id_key / parameters_order: forwarded to the layout builders
//
// Layout selection: function_is_key first; otherwise the nested-key layout when
// either effective key contains a '.', else the flat-key layout.
common_peg_parser common_chat_peg_builder::standard_json_tools(
    const std::string &              section_start,
    const std::string &              section_end,
    const nlohmann::json &           tools,
    bool                             parallel_tool_calls,
    bool                             force_tool_calls,
    const std::string &              name_key,
    const std::string &              args_key,
    bool                             array_wrapped,
    bool                             function_is_key,
    const std::string &              call_id_key,
    const std::string &              gen_call_id_key,
    const std::vector<std::string> & parameters_order) {
    if (!tools.is_array() || tools.empty()) {
        return eps();
    }

    std::string effective_name_key = "name";
    if (!name_key.empty()) {
        effective_name_key = name_key;
    }
    std::string effective_args_key = "arguments";
    if (!args_key.empty()) {
        effective_args_key = args_key;
    }

    // Dispatch to the appropriate builder based on the JSON layout mode
    common_peg_parser tool_choices = eps();
    if (function_is_key) {
        tool_choices = build_json_tools_function_is_key(tools, args_key, effective_args_key,
                                                        call_id_key, gen_call_id_key);
    } else {
        const bool uses_dotted_key = !parse_key_spec(effective_name_key).first.empty() ||
                                     !parse_key_spec(effective_args_key).first.empty();
        if (uses_dotted_key) {
            tool_choices = build_json_tools_nested_keys(tools, effective_name_key, effective_args_key,
                                                        call_id_key, gen_call_id_key);
        } else {
            tool_choices = build_json_tools_flat_keys(tools, effective_name_key, effective_args_key,
                                                      call_id_key, gen_call_id_key, parameters_order);
        }
    }

    // One call, optionally followed by ", call" repetitions
    auto tool_calls = tool_choices;
    if (parallel_tool_calls) {
        tool_calls = tool_calls + zero_or_more(space() + literal(",") + space() + tool_choices);
    }

    if (array_wrapped) {
        tool_calls = literal("[") + space() + tool_calls + space() + literal("]");
    }

    // Build the section with markers
    auto section =
        trigger_rule("tool-call", literal(section_start) + space() + tool_calls + space() + literal(section_end));

    if (force_tool_calls) {
        return section;
    }
    return optional(section);
}

View File

@ -3,22 +3,9 @@
#include "chat.h"
#include "peg-parser.h"
class common_chat_peg_builder : public common_peg_parser_builder {
public:
static constexpr const char * REASONING_BLOCK = "reasoning-block";
static constexpr const char * REASONING = "reasoning";
static constexpr const char * CONTENT = "content";
common_peg_parser reasoning_block(const common_peg_parser & p) { return tag(REASONING_BLOCK, p); }
common_peg_parser reasoning(const common_peg_parser & p) { return tag(REASONING, p); }
common_peg_parser content(const common_peg_parser & p) { return tag(CONTENT, p); }
};
inline common_peg_arena build_chat_peg_parser(const std::function<common_peg_parser(common_chat_peg_builder & builder)> & fn) {
common_chat_peg_builder builder;
builder.set_root(fn(builder));
return builder.build();
}
#include <map>
#include <optional>
#include <vector>
class common_chat_peg_mapper {
public:
@ -26,80 +13,169 @@ class common_chat_peg_mapper {
common_chat_peg_mapper(common_chat_msg & msg) : result(msg) {}
virtual ~common_chat_peg_mapper() = default;
virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result);
virtual void map(const common_peg_ast_node & node);
private:
// Tool call handling state
std::optional<common_chat_tool_call> pending_tool_call; // Tool call waiting for name
common_chat_tool_call * current_tool = nullptr;
int arg_count = 0;
bool closing_quote_pending = false;
std::string args_buffer; // Buffer to delay arguments until tool name is known
// Returns a reference to the active argument destination string.
// Before tool_name is known, writes go to args_buffer; after, to current_tool->arguments.
std::string & args_target();
};
class common_chat_peg_native_builder : public common_chat_peg_builder {
public:
static constexpr const char * TOOL = "tool";
static constexpr const char * TOOL_OPEN = "tool-open";
static constexpr const char * TOOL_CLOSE = "tool-close";
static constexpr const char * TOOL_ID = "tool-id";
static constexpr const char * TOOL_NAME = "tool-name";
static constexpr const char * TOOL_ARGS = "tool-args";
struct content_structure;
struct tool_call_structure;
class common_chat_peg_builder : public common_peg_parser_builder {
public:
// Tag constants (from former common_chat_peg_base_builder)
static constexpr const char * REASONING_BLOCK = "reasoning-block";
static constexpr const char * REASONING = "reasoning";
static constexpr const char * CONTENT = "content";
// Tag constants
static constexpr const char * TOOL = "tool";
static constexpr const char * TOOL_OPEN = "tool-open";
static constexpr const char * TOOL_CLOSE = "tool-close";
static constexpr const char * TOOL_ID = "tool-id";
static constexpr const char * TOOL_NAME = "tool-name";
static constexpr const char * TOOL_ARGS = "tool-args";
static constexpr const char * TOOL_ARG = "tool-arg";
static constexpr const char * TOOL_ARG_OPEN = "tool-arg-open";
static constexpr const char * TOOL_ARG_CLOSE = "tool-arg-close";
static constexpr const char * TOOL_ARG_NAME = "tool-arg-name";
static constexpr const char * TOOL_ARG_VALUE = "tool-arg-value";
static constexpr const char * TOOL_ARG_STRING_VALUE = "tool-arg-string-value"; // For schema-declared string types
// Low-level tag methods (from former common_chat_peg_base_builder)
common_peg_parser reasoning_block(const common_peg_parser & p) { return tag(REASONING_BLOCK, p); }
common_peg_parser reasoning(const common_peg_parser & p) { return tag(REASONING, p); }
common_peg_parser content(const common_peg_parser & p) { return tag(CONTENT, p); }
common_peg_parser tag_with_safe_content(const std::string & tag_name,
const std::string & marker,
const common_peg_parser & p);
// Low-level tag methods
common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); }
common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); }
common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); }
common_peg_parser tool_id(const common_peg_parser & p) { return atomic(tag(TOOL_ID, p)); }
common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); }
common_peg_parser tool_args(const common_peg_parser & p) { return tag(TOOL_ARGS, p); }
};
class common_chat_peg_native_mapper : public common_chat_peg_mapper {
common_chat_tool_call * current_tool;
public:
common_chat_peg_native_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
void map(const common_peg_ast_node & node) override;
};
inline common_peg_arena build_chat_peg_native_parser(const std::function<common_peg_parser(common_chat_peg_native_builder & builder)> & fn) {
common_chat_peg_native_builder builder;
builder.set_root(fn(builder));
return builder.build();
}
class common_chat_peg_constructed_builder : public common_chat_peg_builder {
public:
static constexpr const char * TOOL = "tool";
static constexpr const char * TOOL_OPEN = "tool-open";
static constexpr const char * TOOL_CLOSE = "tool-close";
static constexpr const char * TOOL_NAME = "tool-name";
static constexpr const char * TOOL_ARG = "tool-arg";
static constexpr const char * TOOL_ARG_OPEN = "tool-arg-open";
static constexpr const char * TOOL_ARG_CLOSE = "tool-arg-close";
static constexpr const char * TOOL_ARG_NAME = "tool-arg-name";
static constexpr const char * TOOL_ARG_STRING_VALUE = "tool-arg-string-value";
static constexpr const char * TOOL_ARG_JSON_VALUE = "tool-arg-json-value";
common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); }
common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); }
common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); }
common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); }
common_peg_parser tool_arg(const common_peg_parser & p) { return tag(TOOL_ARG, p); }
common_peg_parser tool_arg_open(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_OPEN, p)); }
common_peg_parser tool_arg_close(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_CLOSE, p)); }
common_peg_parser tool_arg_name(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_NAME, p)); }
common_peg_parser tool_arg_value(const common_peg_parser & p) { return tag(TOOL_ARG_VALUE, p); }
// Use for schema-declared string types - won't be treated as potential JSON container
common_peg_parser tool_arg_string_value(const common_peg_parser & p) { return tag(TOOL_ARG_STRING_VALUE, p); }
common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return tag(TOOL_ARG_JSON_VALUE, p); }
common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_VALUE, p)); }
// Legacy-compatible helper for building standard JSON tool calls
// Used by tests and manual parsers
// name_key/args_key: JSON key names for function name and arguments
// Empty or "name"/"arguments" will accept both common variations
// Supports dot notation for nested objects (e.g., "function.name")
// array_wrapped: if true, tool calls are wrapped in JSON array [...]
// function_is_key: if true, function name is the JSON key (e.g., {"func_name": {...}})
// call_id_key: JSON key for string call ID (e.g., "id")
// gen_call_id_key: JSON key for generated integer call ID (e.g., "tool_call_id")
// parameters_order: order in which JSON fields should be parsed
common_peg_parser standard_json_tools(const std::string & section_start,
const std::string & section_end,
const nlohmann::json & tools,
bool parallel_tool_calls,
bool force_tool_calls,
const std::string & name_key = "",
const std::string & args_key = "",
bool array_wrapped = false,
bool function_is_key = false,
const std::string & call_id_key = "",
const std::string & gen_call_id_key = "",
const std::vector<std::string> & parameters_order = {});
// Legacy-compatible helper for building XML/tagged style tool calls
// Used by tests and manual parsers
common_peg_parser standard_constructed_tools(const std::map<std::string, std::string> & markers,
const nlohmann::json & tools,
bool parallel_tool_calls,
bool force_tool_calls);
// Helper for Python-style function call format: name(arg1="value1", arg2=123)
// Used by LFM2 and similar templates
common_peg_parser python_style_tool_calls(const nlohmann::json & tools,
bool parallel_tool_calls);
private:
// Implementation helpers for standard_json_tools — one per JSON tool call layout mode
//
// Layout mode "function is key": judging by the name, this covers the case where the
// function name is itself the JSON key (e.g., {"func_name": {...}}) - TODO confirm.
// NOTE(review): both `args_key` and `effective_args_key` are passed; presumably the
// latter is the former after defaults were applied - verify in the implementation.
common_peg_parser build_json_tools_function_is_key(const nlohmann::json & tools,
const std::string & args_key,
const std::string & effective_args_key,
const std::string & call_id_key,
const std::string & gen_call_id_key);
// Layout mode "nested keys": judging by the name, this covers name/arguments keys that
// use dot notation for nested objects (e.g., "function.name") - TODO confirm.
// call_id_key / gen_call_id_key are the JSON keys for the string and generated call IDs.
common_peg_parser build_json_tools_nested_keys(const nlohmann::json & tools,
const std::string & effective_name_key,
const std::string & effective_args_key,
const std::string & call_id_key,
const std::string & gen_call_id_key);
// Layout mode "flat keys": judging by the name, this covers name/arguments keys that sit
// directly on the tool-call object - TODO confirm.
// Unlike the other two layout helpers, it also receives `parameters_order`, the order
// in which JSON fields should be parsed.
common_peg_parser build_json_tools_flat_keys(const nlohmann::json & tools,
const std::string & effective_name_key,
const std::string & effective_args_key,
const std::string & call_id_key,
const std::string & gen_call_id_key,
const std::vector<std::string> & parameters_order);
};
// AST mapper specialisation of common_chat_peg_mapper that keeps extra per-parse
// state (current tool call, argument counter, pending closing quote). The mapping
// logic itself lives in map(), which is defined out of line.
class common_chat_peg_constructed_mapper : public common_chat_peg_mapper {
    // Tool call currently being filled in. Initialised to nullptr so the pointer
    // never holds an indeterminate value before map() assigns it.
    common_chat_tool_call * current_tool = nullptr;
    int  arg_count           = 0;
    bool needs_closing_quote = false;

  public:
    // Forwards `msg` to the base mapper.
    common_chat_peg_constructed_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}

    // Handles a single AST node; overrides the base-class mapping.
    void map(const common_peg_ast_node & node) override;
};
// Creates a common_chat_peg_constructed_builder, lets `fn` produce the root parser
// from it, installs that parser as the root and returns the built arena.
// The closing brace of this function was missing, which made the following
// definition nest inside it.
inline common_peg_arena build_chat_peg_constructed_parser(const std::function<common_peg_parser(common_chat_peg_constructed_builder & builder)> & fn) {
    common_chat_peg_constructed_builder builder;
    builder.set_root(fn(builder));
    return builder.build();
}
// Creates a common_chat_peg_builder, asks `fn` for the root parser, registers it
// as the builder's root and returns the arena produced by build().
inline common_peg_arena build_chat_peg_parser(
    const std::function<common_peg_parser(common_chat_peg_builder & builder)> & fn) {
    common_chat_peg_builder peg_builder;
    const common_peg_parser root_rule = fn(peg_builder);
    peg_builder.set_root(root_rule);
    return peg_builder.build();
}
// Collects string-keyed string values ("tags") out of a PEG parse.
class tag_based_peg_mapper {
public:
// Tag values keyed by tag name; filled by from_ast() (presumably - the definition
// is not visible here).
std::map<std::string, std::string> tags;
// Reads the AST in `arena` for the given parse `result`. Defined out of line.
void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result);
};
// Bundles a raw PEG parse result with the tag map associated with it.
struct tagged_parse_result {
// The underlying parse result.
common_peg_parse_result result;
// String-keyed string values; presumably the tags extracted from the parse - TODO confirm.
std::map<std::string, std::string> tags;
};
// A PEG arena plus the parse flags to use with it. The flag setters return *this
// so calls can be chained.
struct tagged_peg_parser {
// The parser arena.
common_peg_arena arena;
// Parse flags; starts as COMMON_PEG_PARSE_FLAG_NONE.
common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_NONE;
// Sets COMMON_PEG_PARSE_FLAG_DEBUG in `flags` and returns this object.
tagged_peg_parser & withDebug() {
flags |= COMMON_PEG_PARSE_FLAG_DEBUG;
return *this;
}
// Clears COMMON_PEG_PARSE_FLAG_DEBUG from `flags` and returns this object.
tagged_peg_parser & withoutDebug() {
flags = flags & ~COMMON_PEG_PARSE_FLAG_DEBUG;
return *this;
}
// Parses `input` and returns the result together with its tags. `extra_flags`
// defaults to none; NOTE(review): presumably combined with `flags` - confirm in the definition.
tagged_parse_result parse_and_extract(const std::string & input, common_peg_parse_flags extra_flags = COMMON_PEG_PARSE_FLAG_NONE) const;
// NOTE(review): judging by the name, this variant matches at any position in `input` - confirm.
tagged_parse_result parse_anywhere_and_extract(const std::string & input) const;
};
// Builds a tagged_peg_parser from a callback that receives a common_peg_parser_builder
// and returns a parser (presumably used as the root - the definition is not visible here).
tagged_peg_parser build_tagged_peg_parser(
const std::function<common_peg_parser(common_peg_parser_builder & builder)> & fn);

File diff suppressed because it is too large Load Diff

View File

@ -3,17 +3,30 @@
#pragma once
#include "common.h"
#include "jinja/parser.h"
#include "nlohmann/json_fwd.hpp"
#include "peg-parser.h"
#include <functional>
#include "jinja/runtime.h"
#include "jinja/caps.h"
#include "nlohmann/json.hpp"
#include <chrono>
#include <functional>
#include <map>
#include <string>
#include <vector>
#include <map>
using chat_template_caps = jinja::caps;
using json = nlohmann::ordered_json;
#include <nlohmann/json_fwd.hpp>
struct common_chat_templates;
namespace autoparser {
struct templates_params;
} // namespace autoparser
struct common_chat_tool_call {
std::string name;
std::string arguments;
@ -38,21 +51,85 @@ struct common_chat_msg_content_part {
}
};
struct common_chat_template {
jinja::program prog;
std::string bos_tok;
std::string eos_tok;
std::string src;
chat_template_caps caps;
common_chat_template(const std::string & src, const std::string & bos_token, const std::string & eos_token) {
jinja::lexer lexer;
auto lexer_res = lexer.tokenize(src);
this->prog = jinja::parse_from_tokens(lexer_res);
this->src = lexer_res.source;
this->bos_tok = bos_token;
this->eos_tok = eos_token;
this->caps = jinja::caps_get(prog);
// LOG_INF("%s: caps:\n%s\n", __func__, this->caps.to_string().c_str());
}
const std::string & source() const { return src; }
const std::string & bos_token() const { return bos_tok; }
const std::string & eos_token() const { return eos_tok; }
// TODO: this is ugly, refactor it somehow
json add_system(const json & messages, const std::string & system_prompt) const {
GGML_ASSERT(messages.is_array());
auto msgs_copy = messages;
if (!caps.supports_system_role) {
if (msgs_copy.empty()) {
msgs_copy.insert(msgs_copy.begin(), json{
{"role", "user"},
{"content", system_prompt}
});
} else {
auto & first_msg = msgs_copy[0];
if (!first_msg.contains("content")) {
first_msg["content"] = "";
}
first_msg["content"] = system_prompt + "\n\n"
+ first_msg["content"].get<std::string>();
}
} else {
if (msgs_copy.empty() || msgs_copy[0].at("role") != "system") {
msgs_copy.insert(msgs_copy.begin(), json{
{"role", "system"},
{"content", system_prompt}
});
} else if (msgs_copy[0].at("role") == "system") {
msgs_copy[0]["content"] = system_prompt;
}
}
return msgs_copy;
}
chat_template_caps original_caps() const {
return caps;
}
};
struct common_chat_msg {
std::string role;
std::string content;
std::string role;
std::string content;
std::vector<common_chat_msg_content_part> content_parts;
std::vector<common_chat_tool_call> tool_calls;
std::string reasoning_content;
std::string tool_name;
std::string tool_call_id;
std::vector<common_chat_tool_call> tool_calls;
std::string reasoning_content;
std::string tool_name;
std::string tool_call_id;
nlohmann::ordered_json to_json_oaicompat(bool concat_typed_text = false) const;
bool empty() const {
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() &&
tool_name.empty() && tool_call_id.empty();
}
void set_tool_call_ids(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
void set_tool_call_ids(std::vector<std::string> & ids_cache,
const std::function<std::string()> & gen_tool_call_id) {
for (auto i = 0u; i < tool_calls.size(); i++) {
if (ids_cache.size() <= i) {
auto id = tool_calls[i].id;
@ -64,32 +141,28 @@ struct common_chat_msg {
tool_calls[i].id = ids_cache[i];
}
}
bool operator==(const common_chat_msg & other) const {
return role == other.role
&& content == other.content
&& content_parts == other.content_parts
&& tool_calls == other.tool_calls
&& reasoning_content == other.reasoning_content
&& tool_name == other.tool_name
&& tool_call_id == other.tool_call_id;
}
bool operator!=(const common_chat_msg & other) const {
return !(*this == other);
return role == other.role && content == other.content && content_parts == other.content_parts &&
tool_calls == other.tool_calls && reasoning_content == other.reasoning_content &&
tool_name == other.tool_name && tool_call_id == other.tool_call_id;
}
bool operator!=(const common_chat_msg & other) const { return !(*this == other); }
};
// Incremental difference between two chat messages: text deltas for reasoning and
// content, plus at most one tool-call delta identified by its index.
// Every member in this struct used to be declared twice and operator== carried two
// return statements; each is now declared exactly once.
struct common_chat_msg_diff {
    std::string           reasoning_content_delta;
    std::string           content_delta;
    // std::string::npos is the initial value; presumably it means "no tool call in
    // this diff" - TODO confirm against compute_diffs().
    size_t                tool_call_index = std::string::npos;
    common_chat_tool_call tool_call_delta;

    // Computes the diffs that lead from `msg_prv` to `msg_new`. Defined out of line.
    static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & msg_prv,
                                                           const common_chat_msg & msg_new);

    // NOTE(review): reasoning_content_delta does not take part in the comparison, so
    // two diffs that differ only in reasoning text compare equal. Left unchanged to
    // preserve existing behaviour - confirm whether this is intended.
    bool operator==(const common_chat_msg_diff & other) const {
        return content_delta == other.content_delta && tool_call_index == other.tool_call_index &&
               tool_call_delta == other.tool_call_delta;
    }
};
@ -107,64 +180,39 @@ enum common_chat_tool_choice {
// Identifies how a chat response is formatted and therefore how it is parsed.
// Enumerator values are implicit and sequential starting at 0, so the order below
// must not change. COMMON_CHAT_FORMAT_COUNT was listed twice (a duplicate enumerator
// is a compile error); it now appears once and keeps the same value.
enum common_chat_format {
    COMMON_CHAT_FORMAT_CONTENT_ONLY,
    COMMON_CHAT_FORMAT_GENERIC,
    COMMON_CHAT_FORMAT_MISTRAL_NEMO,
    COMMON_CHAT_FORMAT_MAGISTRAL,
    COMMON_CHAT_FORMAT_LLAMA_3_X,
    COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
    COMMON_CHAT_FORMAT_DEEPSEEK_R1,
    COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
    COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
    COMMON_CHAT_FORMAT_HERMES_2_PRO,
    COMMON_CHAT_FORMAT_COMMAND_R7B,
    COMMON_CHAT_FORMAT_GRANITE,
    COMMON_CHAT_FORMAT_GPT_OSS,
    COMMON_CHAT_FORMAT_SEED_OSS,
    COMMON_CHAT_FORMAT_NEMOTRON_V2,
    COMMON_CHAT_FORMAT_APERTUS,
    COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,
    COMMON_CHAT_FORMAT_GLM_4_5,
    COMMON_CHAT_FORMAT_MINIMAX_M2,
    COMMON_CHAT_FORMAT_KIMI_K2,
    COMMON_CHAT_FORMAT_APRIEL_1_5,
    COMMON_CHAT_FORMAT_XIAOMI_MIMO,
    COMMON_CHAT_FORMAT_SOLAR_OPEN,
    COMMON_CHAT_FORMAT_EXAONE_MOE,

    // These are intended to be parsed by the PEG parser
    COMMON_CHAT_FORMAT_PEG_SIMPLE,
    COMMON_CHAT_FORMAT_PEG_NATIVE,
    COMMON_CHAT_FORMAT_PEG_CONSTRUCTED,

    COMMON_CHAT_FORMAT_COUNT,  // Not a format, just the # formats
};
// Inputs for applying a chat template (see common_chat_templates_apply).
// Every field used to be declared twice, which is a redeclaration error; each field
// is now declared exactly once with its original default value.
struct common_chat_templates_inputs {
    std::vector<common_chat_msg> messages;
    std::string                  grammar;
    std::string                  json_schema;
    bool                         add_generation_prompt = true;
    bool                         use_jinja             = true;

    // Parameters below only supported when use_jinja is true
    std::vector<common_chat_tool>         tools;
    common_chat_tool_choice               tool_choice         = COMMON_CHAT_TOOL_CHOICE_AUTO;
    bool                                  parallel_tool_calls = false;
    common_reasoning_format               reasoning_format    = COMMON_REASONING_FORMAT_NONE;  // TODO: refactor this to "bool enable_thinking"
    bool                                  enable_thinking     = true;
    // Captured when the struct is default-initialised.
    std::chrono::system_clock::time_point now                 = std::chrono::system_clock::now();
    std::map<std::string, std::string>    chat_template_kwargs;
    bool                                  add_bos             = false;
    bool                                  add_eos             = false;
};
struct common_chat_params {
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
std::string prompt;
std::string grammar;
bool grammar_lazy = false;
bool grammar_lazy = false;
bool thinking_forced_open = false;
bool supports_thinking = false;
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
@ -174,13 +222,14 @@ struct common_chat_params {
// per-message parsing syntax
// should be derived from common_chat_params
struct common_chat_parser_params {
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
bool reasoning_in_content = false;
bool thinking_forced_open = false;
bool parse_tool_calls = true;
common_peg_arena parser = {};
bool reasoning_in_content = false;
bool thinking_forced_open = false;
bool parse_tool_calls = true;
bool debug = false; // Enable debug output for PEG parser
common_peg_arena parser = {};
common_chat_parser_params() = default;
common_chat_parser_params(const common_chat_params & chat_params) {
format = chat_params.format;
@ -193,45 +242,42 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
void common_chat_templates_free(struct common_chat_templates * tmpls);
// Deleter for smart pointers owning a common_chat_templates: releases the object
// through common_chat_templates_free(). The struct was defined twice (a redefinition
// error); a single definition remains.
struct common_chat_templates_deleter {
    void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); }
};
typedef std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter> common_chat_templates_ptr;
// Creates the chat templates object for `model` and returns it as an owning pointer.
//   chat_template_override: NOTE(review): presumably replaces the model's built-in
//                           template when non-empty - confirm in the definition
//   bos_token_override / eos_token_override: optional overrides, empty by default
// The prototype was written twice with default arguments on both copies; repeating a
// default argument in a redeclaration is ill-formed, so only one declaration is kept.
common_chat_templates_ptr common_chat_templates_init(const struct llama_model * model,
                                                     const std::string &        chat_template_override,
                                                     const std::string &        bos_token_override = "",
                                                     const std::string &        eos_token_override = "");
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
struct common_chat_params common_chat_templates_apply(
const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs);
struct common_chat_params common_chat_templates_apply(const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs);
// Format single message, while taking into account the position of that message in chat history
std::string common_chat_format_single(
const struct common_chat_templates * tmpls,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja);
std::string common_chat_format_single(const struct common_chat_templates * tmpls,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja);
// Returns an example of formatted chat
std::string common_chat_format_example(
const struct common_chat_templates * tmpls,
bool use_jinja,
const std::map<std::string, std::string> & chat_template_kwargs);
std::string common_chat_format_example(const struct common_chat_templates * tmpls,
bool use_jinja,
const std::map<std::string, std::string> & chat_template_kwargs);
const char* common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
const char * common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & params);
common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, const std::string & input, bool is_partial, const common_chat_parser_params & params);
// used by arg and server
const char * common_reasoning_format_name(common_reasoning_format format);
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
const char * common_reasoning_format_name(common_reasoning_format format);
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
@ -250,3 +296,10 @@ nlohmann::ordered_json common_chat_msg_diff_to_json_oaicompat(const common_chat_
// get template caps, useful for reporting to server /props endpoint
std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates);
std::string common_chat_template_direct_apply(
const common_chat_template & tmpl,
const autoparser::templates_params & inputs,
const std::optional<json> & messages_override = std::nullopt,
const std::optional<json> & tools_override = std::nullopt,
const std::optional<json> & additional_context = std::nullopt);

View File

@ -676,7 +676,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
size_t offset = 0;
while (offset < filename.size()) {
utf8_parse_result result = parse_utf8_codepoint(filename, offset);
utf8_parse_result result = common_parse_utf8_codepoint(filename, offset);
if (result.status != utf8_parse_result::SUCCESS) {
return false;

View File

@ -104,6 +104,7 @@ enum llama_example {
LLAMA_EXAMPLE_DIFFUSION,
LLAMA_EXAMPLE_FINETUNE,
LLAMA_EXAMPLE_FIT_PARAMS,
LLAMA_EXAMPLE_RESULTS,
LLAMA_EXAMPLE_COUNT,
};
@ -456,6 +457,8 @@ struct common_params {
bool kl_divergence = false; // compute KL divergence
bool check = false; // check rather than generate results for llama-results
bool usage = false; // print usage
bool completion = false; // print source-able completion script
bool use_color = false; // use color to distinguish generations and inputs

View File

@ -1,3 +1,4 @@
#include "log.h"
#include "value.h"
#include "runtime.h"
#include "caps.h"
@ -36,12 +37,16 @@ static void caps_try_execute(jinja::program & prog,
auto tools = ctx.get_val("tools");
bool success = false;
std::string result;
try {
jinja::runtime runtime(ctx);
runtime.execute(prog);
auto results = runtime.execute(prog);
auto parts = jinja::runtime::gather_string_parts(results);
result = parts->as_string().str();
success = true;
} catch (const std::exception & e) {
JJ_DEBUG("Exception during execution: %s", e.what());
result = "";
// ignore exceptions during capability analysis
}
@ -90,6 +95,8 @@ caps caps_get(jinja::program & prog) {
return v->stats.ops.find(op_name) != v->stats.ops.end();
};
JJ_DEBUG("%s\n", ">>> Running capability check: typed content");
// case: typed content support
caps_try_execute(
prog,
@ -120,6 +127,7 @@ caps caps_get(jinja::program & prog) {
}
);
JJ_DEBUG("%s\n", ">>> Running capability check: system prompt");
// case: system prompt support
caps_try_execute(
@ -150,7 +158,9 @@ caps caps_get(jinja::program & prog) {
}
);
// case: tools support
JJ_DEBUG("%s\n", ">>> Running capability check: single tool support");
// case: tools support: single call
caps_try_execute(
prog,
[&]() {
@ -162,10 +172,10 @@ caps caps_get(jinja::program & prog) {
},
{
{"role", "assistant"},
{"content", "Assistant message"},
{"content", ""}, // Some templates expect content to be empty with tool calls
{"tool_calls", json::array({
{
{"id", "call1"},
{"id", "call00001"},
{"type", "function"},
{"function", {
{"name", "tool1"},
@ -173,19 +183,18 @@ caps caps_get(jinja::program & prog) {
{"arg", "value"}
}}
}}
},
{
{"id", "call2"},
{"type", "function"},
{"function", {
{"name", "tool2"},
{"arguments", {
{"arg", "value"}
}}
}}
}
})}
},
{
{"role", "tool"},
{"content", "Tool response"},
{"tool_call_id", "call00001"}
},
{
{"role", "assistant"},
{"content", "The tool response was 'tool response'"}
},
{
{"role", "user"},
{"content", "User message"},
@ -199,7 +208,7 @@ caps caps_get(jinja::program & prog) {
{"name", "tool"},
{"type", "function"},
{"function", {
{"name", "tool"},
{"name", "tool1"},
{"description", "Tool description"},
{"parameters", {
{"type", "object"},
@ -224,6 +233,7 @@ caps caps_get(jinja::program & prog) {
auto & tool_name = tools->at(0)->at("function")->at("name");
caps_print_stats(tool_name, "tools[0].function.name");
caps_print_stats(tools, "tools");
if (!tool_name->stats.used) {
result.supports_tools = false;
}
@ -233,6 +243,93 @@ caps caps_get(jinja::program & prog) {
if (!tool_calls->stats.used) {
result.supports_tool_calls = false;
}
}
);
JJ_DEBUG("%s\n", ">>> Running capability check: parallel tool support");
// case: tools support: parallel calls
caps_try_execute(
prog,
[&]() {
// messages
return json::array({
{
{"role", "user"},
{"content", "User message"},
},
{
{"role", "assistant"},
{"content", ""}, // Some templates expect content to be empty with tool calls
{"tool_calls", json::array({
{
{"id", "call00001"},
{"type", "function"},
{"function", {
{"name", "tool1"},
{"arguments", {
{"arg", "value"}
}}
}}
},
{
{"id", "call00002"},
{"type", "function"},
{"function", {
{"name", "tool1"},
{"arguments", {
{"arg", "value"}
}}
}}
}
})}
},
{
{"role", "tool"},
{"content", "Tool response"},
{"tool_call_id", "call00001"}
},
{
{"role", "assistant"},
{"content", "The tool response was 'tool response'"}
},
{
{"role", "user"},
{"content", "User message"},
},
});
},
[&]() {
// tools
return json::array({
{
{"name", "tool"},
{"type", "function"},
{"function", {
{"name", "tool1"},
{"description", "Tool description"},
{"parameters", {
{"type", "object"},
{"properties", {
{"arg", {
{"type", "string"},
{"description", "Arg description"},
}},
}},
{"required", json::array({ "arg" })},
}},
}},
},
});
},
[&](bool success, value & messages, value & /*tools*/) {
if (!success) {
result.supports_parallel_tool_calls = false;
return;
}
auto & tool_calls = messages->at(1)->at("tool_calls");;
caps_print_stats(tool_calls, "messages[1].tool_calls");
// check for second tool call usage
auto & tool_call_1 = tool_calls->at(1)->at("function");
@ -243,6 +340,8 @@ caps caps_get(jinja::program & prog) {
}
);
JJ_DEBUG("%s\n", ">>> Running capability check: preserve reasoning");
// case: preserve reasoning content in chat history
caps_try_execute(
prog,

View File

@ -114,8 +114,10 @@ value binary_expression::execute_impl(context & ctx) {
// Logical operators
if (op.value == "and") {
JJ_DEBUG("Executing logical test: %s AND %s", left->type().c_str(), right->type().c_str());
return left_val->as_bool() ? right->execute(ctx) : std::move(left_val);
} else if (op.value == "or") {
JJ_DEBUG("Executing logical test: %s OR %s", left->type().c_str(), right->type().c_str());
return left_val->as_bool() ? std::move(left_val) : right->execute(ctx);
}
@ -838,7 +840,7 @@ value call_expression::execute_impl(context & ctx) {
for (auto & arg_stmt : this->args) {
auto arg_val = arg_stmt->execute(ctx);
JJ_DEBUG(" Argument type: %s", arg_val->type().c_str());
args.push_back(std::move(arg_val));
args.push_back(arg_val);
}
// execute callee
value callee_val = callee->execute(ctx);

View File

@ -12,8 +12,8 @@
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include <unordered_map>
namespace jinja {

View File

@ -27,11 +27,11 @@ static std::string build_repetition(const std::string & item_rule, int min_items
if (separator_rule.empty()) {
if (min_items == 1 && !has_max) {
return item_rule + "+";
} else if (min_items == 0 && !has_max) {
return item_rule + "*";
} else {
return item_rule + "{" + std::to_string(min_items) + "," + (has_max ? std::to_string(max_items) : "") + "}";
}
if (min_items == 0 && !has_max) {
return item_rule + "*";
}
return item_rule + "{" + std::to_string(min_items) + "," + (has_max ? std::to_string(max_items) : "") + "}";
}
auto result = item_rule + " " + build_repetition("(" + separator_rule + " " + item_rule + ")", min_items == 0 ? 0 : min_items - 1, has_max ? max_items - 1 : max_items);
@ -41,7 +41,7 @@ static std::string build_repetition(const std::string & item_rule, int min_items
return result;
}
static void _build_min_max_int(int64_t min_value, int64_t max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
static void build_min_max_int(int64_t min_value, int64_t max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
auto has_min = min_value != std::numeric_limits<int64_t>::min();
auto has_max = max_value != std::numeric_limits<int64_t>::max();
@ -128,14 +128,14 @@ static void _build_min_max_int(int64_t min_value, int64_t max_value, std::string
if (has_min && has_max) {
if (min_value < 0 && max_value < 0) {
out << "\"-\" (";
_build_min_max_int(-max_value, -min_value, out, decimals_left, /* top_level= */ true);
build_min_max_int(-max_value, -min_value, out, decimals_left, /* top_level= */ true);
out << ")";
return;
}
if (min_value < 0) {
out << "\"-\" (";
_build_min_max_int(0, -min_value, out, decimals_left, /* top_level= */ true);
build_min_max_int(0, -min_value, out, decimals_left, /* top_level= */ true);
out << ") | ";
min_value = 0;
}
@ -159,7 +159,7 @@ static void _build_min_max_int(int64_t min_value, int64_t max_value, std::string
if (has_min) {
if (min_value < 0) {
out << "\"-\" (";
_build_min_max_int(std::numeric_limits<int64_t>::min(), -min_value, out, decimals_left, /* top_level= */ false);
build_min_max_int(std::numeric_limits<int64_t>::min(), -min_value, out, decimals_left, /* top_level= */ false);
out << ") | [0] | [1-9] ";
more_digits(0, decimals_left - 1);
} else if (min_value == 0) {
@ -194,7 +194,7 @@ static void _build_min_max_int(int64_t min_value, int64_t max_value, std::string
}
digit_range(c, c);
out << " (";
_build_min_max_int(std::stoll(min_s.substr(1)), std::numeric_limits<int64_t>::max(), out, less_decimals, /* top_level= */ false);
build_min_max_int(std::stoll(min_s.substr(1)), std::numeric_limits<int64_t>::max(), out, less_decimals, /* top_level= */ false);
out << ")";
if (c < '9') {
out << " | ";
@ -213,10 +213,10 @@ static void _build_min_max_int(int64_t min_value, int64_t max_value, std::string
more_digits(0, less_decimals);
out << " | ";
}
_build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
} else {
out << "\"-\" (";
_build_min_max_int(-max_value, std::numeric_limits<int64_t>::max(), out, decimals_left, /* top_level= */ false);
build_min_max_int(-max_value, std::numeric_limits<int64_t>::max(), out, decimals_left, /* top_level= */ false);
out << ")";
}
return;
@ -232,7 +232,7 @@ struct BuiltinRule {
std::vector<std::string> deps;
};
std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
static std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
{"boolean", {"(\"true\" | \"false\") space", {}}},
{"decimal-part", {"[0-9]{1,16}", {}}},
{"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}},
@ -247,7 +247,7 @@ std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
{"null", {"\"null\" space", {}}},
};
std::unordered_map<std::string, BuiltinRule> STRING_FORMAT_RULES = {
static std::unordered_map<std::string, BuiltinRule> STRING_FORMAT_RULES = {
{"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}},
{"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}},
{"date-time", {"date \"T\" time", {"date", "time"}}},
@ -260,22 +260,26 @@ static bool is_reserved_name(const std::string & name) {
static const std::unordered_set<std::string> RESERVED_NAMES = [] {
std::unordered_set<std::string> s;
s.insert("root");
for (const auto & p : PRIMITIVE_RULES) s.insert(p.first);
for (const auto & p : STRING_FORMAT_RULES) s.insert(p.first);
for (const auto & p : PRIMITIVE_RULES) {
s.insert(p.first);
}
for (const auto & p : STRING_FORMAT_RULES) {
s.insert(p.first);
}
return s;
}();
return RESERVED_NAMES.find(name) != RESERVED_NAMES.end();
}
// File-local tables used when turning schema text into grammar rules.
// Each of these was defined twice (once without and once with `static`), and the
// escape map had two opening lines; they are now defined once, with internal linkage.

// Matches runs of characters other than ASCII letters, digits and '-'.
static std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
// Matches one of: CR, LF, double quote, backslash.
static std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]");
// Same set as GRAMMAR_LITERAL_ESCAPE_RE plus ']' and '-'.
static std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
// Escaped spelling for each character covered by the two regexes above.
static std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
    {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}, {'\\', "\\\\"}
};

static std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
static std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
std::smatch match;
@ -322,19 +326,19 @@ private:
if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) {
_rules[esc_name] = rule;
return esc_name;
} else {
int i = 0;
while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) {
i++;
}
std::string key = esc_name + std::to_string(i);
_rules[key] = rule;
return key;
}
int i = 0;
while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) {
i++;
}
std::string key = esc_name + std::to_string(i);
_rules[key] = rule;
return key;
}
std::string _generate_union_rule(const std::string & name, const std::vector<json> & alt_schemas) {
std::vector<std::string> rules;
rules.reserve(alt_schemas.size());
for (size_t i = 0; i < alt_schemas.size(); i++) {
rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
}
@ -398,6 +402,7 @@ private:
flush_literal();
std::vector<std::string> results;
results.reserve(ret.size());
for (const auto & item : ret) {
results.push_back(to_rule(item));
}
@ -551,7 +556,7 @@ private:
TrieNode() : is_end_of_string(false) {}
void insert(const std::string & string) {
auto node = this;
auto *node = this;
for (char c : string) {
node = &node->children[c];
}
@ -676,7 +681,7 @@ private:
if (ks.empty()) {
return res;
}
std::string k = ks[0];
const std::string& k = ks[0];
std::string kv_rule_name = prop_kv_rule_names[k];
std::string comma_ref = "( \",\" space " + kv_rule_name + " )";
if (first_is_optional) {
@ -779,7 +784,7 @@ public:
std::string pointer = ref.substr(ref.find('#') + 1);
std::vector<std::string> tokens = string_split(pointer, "/");
for (size_t i = 1; i < tokens.size(); ++i) {
std::string sel = tokens[i];
const std::string& sel = tokens[i];
if (target.is_object() && target.contains(sel)) {
target = target[sel];
} else if (target.is_array()) {
@ -802,7 +807,7 @@ public:
_refs[ref] = target;
}
} else {
for (auto & kv : n.items()) {
for (const auto & kv : n.items()) {
visit_refs(kv.value());
}
}
@ -812,7 +817,7 @@ public:
visit_refs(schema);
}
std::string _generate_constant_rule(const json & value) {
static std::string _generate_constant_rule(const json & value) {
return format_literal(value.dump());
}
@ -823,10 +828,12 @@ public:
if (schema.contains("$ref")) {
return _add_rule(rule_name, _resolve_ref(schema["$ref"]));
} else if (schema.contains("oneOf") || schema.contains("anyOf")) {
}
if (schema.contains("oneOf") || schema.contains("anyOf")) {
std::vector<json> alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get<std::vector<json>>() : schema["anyOf"].get<std::vector<json>>();
return _add_rule(rule_name, _generate_union_rule(name, alt_schemas));
} else if (schema_type.is_array()) {
}
if (schema_type.is_array()) {
std::vector<json> schema_types;
for (const auto & t : schema_type) {
json schema_copy(schema);
@ -834,15 +841,18 @@ public:
schema_types.push_back(schema_copy);
}
return _add_rule(rule_name, _generate_union_rule(name, schema_types));
} else if (schema.contains("const")) {
}
if (schema.contains("const")) {
return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space");
} else if (schema.contains("enum")) {
}
if (schema.contains("enum")) {
std::vector<std::string> enum_values;
for (const auto & v : schema["enum"]) {
enum_values.push_back(_generate_constant_rule(v));
}
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
} else if ((schema_type.is_null() || schema_type == "object")
}
if ((schema_type.is_null() || schema_type == "object")
&& (schema.contains("properties") ||
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
std::unordered_set<std::string> required;
@ -863,11 +873,12 @@ public:
_build_object_rule(
properties, required, name,
schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
} else if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) {
}
if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) {
std::unordered_set<std::string> required;
std::vector<std::pair<std::string, json>> properties;
std::map<std::string, size_t> enum_values;
std::string hybrid_name = name;
const std::string& hybrid_name = name;
std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
if (comp_schema.contains("$ref")) {
add_component(_refs[comp_schema["$ref"]], is_required);
@ -890,9 +901,9 @@ public:
// todo warning
}
};
for (auto & t : schema["allOf"]) {
for (const auto & t : schema["allOf"]) {
if (t.contains("anyOf")) {
for (auto & tt : t["anyOf"]) {
for (const auto & tt : t["anyOf"]) {
add_component(tt, false);
}
} else {
@ -911,7 +922,8 @@ public:
}
}
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
} else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
}
if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
if (items.is_array()) {
std::string rule = "\"[\" space ";
@ -923,27 +935,31 @@ public:
}
rule += " \"]\" space";
return _add_rule(rule_name, rule);
} else {
std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
int min_items = schema.contains("minItems") ? schema["minItems"].get<int>() : 0;
json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : std::numeric_limits<int>::max();
return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space");
}
} else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
int min_items = schema.contains("minItems") ? schema["minItems"].get<int>() : 0;
json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : std::numeric_limits<int>::max();
return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space");
}
if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
return _visit_pattern(schema["pattern"], rule_name);
} else if ((schema_type.is_null() || schema_type == "string") && std::regex_match(schema_format, std::regex("^uuid[1-5]?$"))) {
}
if ((schema_type.is_null() || schema_type == "string") && std::regex_match(schema_format, std::regex("^uuid[1-5]?$"))) {
return _add_primitive(rule_name == "root" ? "root" : schema_format, PRIMITIVE_RULES.at("uuid"));
} else if ((schema_type.is_null() || schema_type == "string") && STRING_FORMAT_RULES.find(schema_format + "-string") != STRING_FORMAT_RULES.end()) {
}
if ((schema_type.is_null() || schema_type == "string") && STRING_FORMAT_RULES.find(schema_format + "-string") != STRING_FORMAT_RULES.end()) {
auto prim_name = schema_format + "-string";
return _add_rule(rule_name, _add_primitive(prim_name, STRING_FORMAT_RULES.at(prim_name)));
} else if (schema_type == "string" && (schema.contains("minLength") || schema.contains("maxLength"))) {
}
if (schema_type == "string" && (schema.contains("minLength") || schema.contains("maxLength"))) {
std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
int min_len = schema.contains("minLength") ? schema["minLength"].get<int>() : 0;
int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
} else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
}
if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
int64_t min_value = std::numeric_limits<int64_t>::min();
int64_t max_value = std::numeric_limits<int64_t>::max();
if (schema.contains("minimum")) {
@ -958,19 +974,24 @@ public:
}
std::stringstream out;
out << "(";
_build_min_max_int(min_value, max_value, out);
build_min_max_int(min_value, max_value, out);
out << ") space";
return _add_rule(rule_name, out.str());
} else if (schema.empty() || schema_type == "object") {
return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object")));
} else {
if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get<std::string>()) == PRIMITIVE_RULES.end()) {
_errors.push_back("Unrecognized schema: " + schema.dump());
return "";
}
// TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
return _add_primitive(rule_name == "root" ? "root" : schema_type.get<std::string>(), PRIMITIVE_RULES.at(schema_type.get<std::string>()));
}
if (schema.empty() || schema_type == "object") {
return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object")));
}
if (schema_type.is_null() && schema.is_object()) {
// No type constraint and no recognized structural keywords (e.g. {"description": "..."}).
// Per JSON Schema semantics this is equivalent to {} and accepts any value.
return _add_rule(rule_name, _add_primitive("value", PRIMITIVE_RULES.at("value")));
}
if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get<std::string>()) == PRIMITIVE_RULES.end()) {
_errors.push_back("Unrecognized schema: " + schema.dump());
return "";
}
// TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
return _add_primitive(rule_name == "root" ? "root" : schema_type.get<std::string>(), PRIMITIVE_RULES.at(schema_type.get<std::string>()));
}
void check_errors() {
@ -985,7 +1006,7 @@ public:
std::string format_grammar() {
std::stringstream ss;
for (const auto & kv : _rules) {
ss << kv.first << " ::= " << kv.second << std::endl;
ss << kv.first << " ::= " << kv.second << '\n';
}
return ss.str();
}

View File

@ -1,14 +1,15 @@
#include "common.h"
#include "peg-parser.h"
#include "json-schema-to-grammar.h"
#include "unicode.h"
#include <nlohmann/json.hpp>
#include "common.h"
#include "json-schema-to-grammar.h"
#include "log.h"
#include "unicode.h"
#include <algorithm>
#include <initializer_list>
#include <map>
#include <memory>
#include <nlohmann/json.hpp>
#include <regex>
#include <stdexcept>
#include <unordered_set>
@ -34,8 +35,7 @@ static bool is_hex_digit(const char c) {
// This is used in common_peg_until_parser and to build a GBNF exclusion grammar
struct trie {
struct node {
size_t depth = 0;
std::map<unsigned char, size_t> children;
std::map<uint32_t, size_t> children; // Use uint32_t to store Unicode codepoints
bool is_word;
};
@ -55,15 +55,22 @@ struct trie {
size_t current = 0; // Start at root
size_t pos = start_pos;
// LOG_DBG("%s: checking at pos %zu, sv='%s'\n", __func__, start_pos, std::string(sv).c_str());
while (pos < sv.size()) {
auto it = nodes[current].children.find(sv[pos]);
auto result = common_parse_utf8_codepoint(sv, pos);
if (result.status != utf8_parse_result::SUCCESS) {
break;
}
auto it = nodes[current].children.find(result.codepoint);
if (it == nodes[current].children.end()) {
// Can't continue matching
return match_result{match_result::NO_MATCH};
}
current = it->second;
pos++;
pos += result.bytes_consumed;
// Check if we've matched a complete word
if (nodes[current].is_word) {
@ -82,22 +89,22 @@ struct trie {
}
struct prefix_and_next {
std::string prefix;
std::string next_chars;
std::vector<uint32_t> prefix;
std::vector<uint32_t> next_chars;
};
std::vector<prefix_and_next> collect_prefix_and_next() {
std::string prefix;
std::vector<uint32_t> prefix;
std::vector<prefix_and_next> result;
collect_prefix_and_next(0, prefix, result);
return result;
}
private:
void collect_prefix_and_next(size_t index, std::string & prefix, std::vector<prefix_and_next> & out) {
void collect_prefix_and_next(size_t index, std::vector<uint32_t> & prefix, std::vector<prefix_and_next> & out) {
if (!nodes[index].is_word) {
if (!nodes[index].children.empty()) {
std::string chars;
std::vector<uint32_t> chars;
chars.reserve(nodes[index].children.size());
for (const auto & p : nodes[index].children) {
chars.push_back(p.first);
@ -107,7 +114,7 @@ struct trie {
}
for (const auto & p : nodes[index].children) {
unsigned char ch = p.first;
uint32_t ch = p.first;
auto child = p.second;
prefix.push_back(ch);
collect_prefix_and_next(child, prefix, out);
@ -123,11 +130,19 @@ struct trie {
void insert(const std::string & word) {
size_t current = 0;
for (unsigned char ch : word) {
size_t pos = 0;
while (pos < word.length()) {
auto result = common_parse_utf8_codepoint(word, pos);
if (result.status != utf8_parse_result::SUCCESS) {
break;
}
uint32_t ch = result.codepoint;
pos += result.bytes_consumed;
auto it = nodes[current].children.find(ch);
if (it == nodes[current].children.end()) {
size_t child = create_node();
nodes[child].depth = nodes[current].depth + 1;
nodes[current].children[ch] = child;
current = child;
} else {
@ -286,6 +301,32 @@ struct parser_executor {
parser_executor(const common_peg_arena & arena, common_peg_parse_context & ctx, size_t start)
: arena(arena), ctx(ctx), start_pos(start) {}
// Indentation prefix for debug trace lines: two spaces per current parse depth.
std::string debug_indent() const {
    const auto width = ctx.parse_depth * 2;
    return std::string(width, ' ');
}
// Returns up to `len` characters of the input starting at `pos`, made printable for debug traces.
//   - "<EOF>" when `pos` is at or beyond the end of the input
//   - '\n', '\r' and '\t' are rendered as the two-character sequences \n, \r and \t
//   - "..." is appended when input remains after the snippet
std::string debug_input_snippet(size_t pos, size_t len = 60) const {
    const auto total = ctx.input.size();
    if (pos >= total) {
        return "<EOF>";
    }

    std::string escaped;
    for (char ch : ctx.input.substr(pos, len)) {
        switch (ch) {
            case '\n':
                escaped += "\\n";
                break;
            case '\r':
                escaped += "\\r";
                break;
            case '\t':
                escaped += "\\t";
                break;
            default:
                escaped += ch;
                break;
        }
    }

    // Mark truncation so the reader knows the input continues past the snippet
    const bool truncated = pos + len < total;
    if (truncated) {
        escaped += "...";
    }
    return escaped;
}
// Epsilon parser: reports SUCCESS at start_pos without looking at the input.
common_peg_parse_result operator()(const common_peg_epsilon_parser & /* p */) const {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos);
}
@ -308,7 +349,7 @@ struct parser_executor {
auto pos = start_pos;
for (auto i = 0u; i < p.literal.size(); ++i) {
if (pos >= ctx.input.size()) {
if (!ctx.is_partial) {
if (!ctx.is_lenient()) {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
@ -323,12 +364,32 @@ struct parser_executor {
}
common_peg_parse_result operator()(const common_peg_sequence_parser & p) {
if (ctx.is_debug()) {
LOG_DBG("%sSEQ start at %zu '%s' (%zu children)\n", debug_indent().c_str(), start_pos,
debug_input_snippet(start_pos).c_str(), p.children.size());
}
ctx.parse_depth++;
auto pos = start_pos;
std::vector<common_peg_ast_id> nodes;
for (const auto & child_id : p.children) {
for (size_t i = 0; i < p.children.size(); i++) {
const auto & child_id = p.children[i];
if (ctx.is_debug()) {
fprintf(stderr, "%sSEQ child %zu: %s\n", debug_indent().c_str(), i, arena.dump(child_id).c_str());
}
auto result = arena.parse(child_id, ctx, pos);
if (ctx.is_debug()) {
fprintf(stderr, "%sSEQ child %zu: %s at %zu->%zu\n", debug_indent().c_str(), i,
common_peg_parse_result_type_name(result.type), result.start, result.end);
}
if (result.fail()) {
ctx.parse_depth--;
if (ctx.is_debug()) {
fprintf(stderr, "%sSEQ -> FAIL\n", debug_indent().c_str());
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, result.end);
}
@ -337,28 +398,65 @@ struct parser_executor {
}
if (result.need_more_input()) {
ctx.parse_depth--;
if (ctx.is_debug()) {
fprintf(stderr, "%sSEQ -> NEED_MORE\n", debug_indent().c_str());
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end, std::move(nodes));
}
pos = result.end;
}
ctx.parse_depth--;
if (ctx.is_debug()) {
fprintf(stderr, "%sSEQ -> SUCCESS at %zu->%zu\n", debug_indent().c_str(), start_pos, pos);
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos, std::move(nodes));
}
common_peg_parse_result operator()(const common_peg_choice_parser & p) {
if (ctx.is_debug()) {
fprintf(stderr, "%sCHOICE start at %zu '%s' (%zu options)\n", debug_indent().c_str(), start_pos,
debug_input_snippet(start_pos).c_str(), p.children.size());
}
ctx.parse_depth++;
auto pos = start_pos;
for (const auto & child_id : p.children) {
for (size_t i = 0; i < p.children.size(); i++) {
const auto & child_id = p.children[i];
if (ctx.is_debug()) {
fprintf(stderr, "%sCHOICE option %zu: %s\n", debug_indent().c_str(), i, arena.dump(child_id).c_str());
}
auto result = arena.parse(child_id, ctx, pos);
if (ctx.is_debug()) {
fprintf(stderr, "%sCHOICE option %zu: %s\n", debug_indent().c_str(), i,
common_peg_parse_result_type_name(result.type));
}
if (!result.fail()) {
ctx.parse_depth--;
if (ctx.is_debug()) {
fprintf(stderr, "%sCHOICE -> %s (option %zu)\n", debug_indent().c_str(),
common_peg_parse_result_type_name(result.type), i);
}
return result;
}
}
ctx.parse_depth--;
if (ctx.is_debug()) {
fprintf(stderr, "%sCHOICE -> FAIL (no options matched)\n", debug_indent().c_str());
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
}
common_peg_parse_result operator()(const common_peg_repetition_parser & p) {
if (ctx.is_debug()) {
fprintf(stderr, "%sREPEAT start at %zu '%s' (min=%d, max=%d)\n", debug_indent().c_str(), start_pos,
debug_input_snippet(start_pos).c_str(), p.min_count, p.max_count);
}
ctx.parse_depth++;
auto pos = start_pos;
int match_count = 0;
std::vector<common_peg_ast_id> nodes;
@ -366,14 +464,26 @@ struct parser_executor {
// Try to match up to max_count times (or unlimited if max_count is -1)
while (p.max_count == -1 || match_count < p.max_count) {
if (pos >= ctx.input.size()) {
if (ctx.is_debug()) {
fprintf(stderr, "%sREPEAT: at end of input, count=%d\n", debug_indent().c_str(), match_count);
}
break;
}
auto result = arena.parse(p.child, ctx, pos);
if (ctx.is_debug()) {
fprintf(stderr, "%sREPEAT iter %d: %s at %zu->%zu, nodes=%zu\n", debug_indent().c_str(), match_count,
common_peg_parse_result_type_name(result.type), result.start, result.end, result.nodes.size());
fprintf(stderr, "%sREPEAT CHILD: %s\n", debug_indent().c_str(), arena.dump(p.child).c_str());
}
if (result.success()) {
// Prevent infinite loop on empty matches
if (result.end == pos) {
if (ctx.is_debug()) {
fprintf(stderr, "%s REPEAT: empty match, stopping\n", debug_indent().c_str());
}
break;
}
@ -391,21 +501,43 @@ struct parser_executor {
nodes.insert(nodes.end(), result.nodes.begin(), result.nodes.end());
}
ctx.parse_depth--;
if (ctx.is_debug()) {
fprintf(stderr, "%sREPEAT -> NEED_MORE (count=%d, nodes=%zu)\n", debug_indent().c_str(),
match_count, nodes.size());
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end, std::move(nodes));
}
// Child failed - stop trying
if (ctx.is_debug()) {
fprintf(stderr, "%sREPEAT: child failed, stopping\n", debug_indent().c_str());
}
break;
}
// Check if we got enough matches
if (p.min_count > 0 && match_count < p.min_count) {
if (pos >= ctx.input.size() && ctx.is_partial) {
ctx.parse_depth--;
if (pos >= ctx.input.size() && ctx.is_lenient()) {
if (ctx.is_debug()) {
fprintf(stderr, "%sREPEAT -> NEED_MORE (not enough matches: %d < %d)\n", debug_indent().c_str(),
match_count, p.min_count);
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos, std::move(nodes));
}
if (ctx.is_debug()) {
fprintf(stderr, "%sREPEAT -> FAIL (not enough matches: %d < %d)\n", debug_indent().c_str(), match_count,
p.min_count);
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos);
}
ctx.parse_depth--;
if (ctx.is_debug()) {
fprintf(stderr, "%sREPEAT -> SUCCESS (count=%d, nodes=%zu)\n", debug_indent().c_str(), match_count,
nodes.size());
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos, std::move(nodes));
}
@ -434,10 +566,10 @@ struct parser_executor {
common_peg_parse_result operator()(const common_peg_any_parser & /* p */) const {
// Parse a single UTF-8 codepoint (not just a single byte)
auto result = parse_utf8_codepoint(ctx.input, start_pos);
auto result = common_parse_utf8_codepoint(ctx.input, start_pos);
if (result.status == utf8_parse_result::INCOMPLETE) {
if (!ctx.is_partial) {
if (!ctx.is_lenient()) {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos);
@ -468,7 +600,7 @@ struct parser_executor {
// Try to match up to max_count times (or unlimited if max_count is -1)
while (p.max_count == -1 || match_count < p.max_count) {
auto result = parse_utf8_codepoint(ctx.input, pos);
auto result = common_parse_utf8_codepoint(ctx.input, pos);
if (result.status == utf8_parse_result::INCOMPLETE) {
if (match_count >= p.min_count) {
@ -476,7 +608,7 @@ struct parser_executor {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos);
}
// Not enough matches yet
if (!ctx.is_partial) {
if (!ctx.is_lenient()) {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
@ -517,7 +649,7 @@ struct parser_executor {
// Check if we got enough matches
if (match_count < p.min_count) {
if (pos >= ctx.input.size() && ctx.is_partial) {
if (pos >= ctx.input.size() && ctx.is_lenient()) {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos);
@ -529,7 +661,7 @@ struct parser_executor {
static common_peg_parse_result handle_escape_sequence(common_peg_parse_context & ctx, size_t start, size_t & pos) {
++pos; // consume '\'
if (pos >= ctx.input.size()) {
if (!ctx.is_partial) {
if (!ctx.is_lenient()) {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start);
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start, pos);
@ -537,6 +669,7 @@ struct parser_executor {
switch (ctx.input[pos]) {
case '"':
case '\'':
case '\\':
case '/':
case 'b':
@ -558,7 +691,7 @@ struct parser_executor {
++pos; // consume 'u'
for (int i = 0; i < 4; ++i) {
if (pos >= ctx.input.size()) {
if (!ctx.is_partial) {
if (!ctx.is_lenient()) {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start);
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start, pos);
@ -589,10 +722,10 @@ struct parser_executor {
return result;
}
} else {
auto utf8_result = parse_utf8_codepoint(ctx.input, pos);
auto utf8_result = common_parse_utf8_codepoint(ctx.input, pos);
if (utf8_result.status == utf8_parse_result::INCOMPLETE) {
if (!ctx.is_partial) {
if (!ctx.is_lenient()) {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
@ -607,7 +740,49 @@ struct parser_executor {
}
// Reached end without finding closing quote
if (!ctx.is_partial) {
if (!ctx.is_lenient()) {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos);
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
}
common_peg_parse_result operator()(const common_peg_python_dict_string_parser & /* p */) {
auto pos = start_pos;
// Parse string content (without quotes)
while (pos < ctx.input.size()) {
char c = ctx.input[pos];
if (c == '\'') {
// Found closing quote - success (don't consume it)
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos);
}
if (c == '\\') {
auto result = handle_escape_sequence(ctx, start_pos, pos);
if (!result.success()) {
return result;
}
} else {
auto utf8_result = common_parse_utf8_codepoint(ctx.input, pos);
if (utf8_result.status == utf8_parse_result::INCOMPLETE) {
if (!ctx.is_lenient()) {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
}
if (utf8_result.status == utf8_parse_result::INVALID) {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
}
pos += utf8_result.bytes_consumed;
}
}
// Reached end without finding closing quote
if (!ctx.is_lenient()) {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos);
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
@ -621,11 +796,11 @@ struct parser_executor {
size_t last_valid_pos = start_pos;
while (pos < ctx.input.size()) {
auto utf8_result = parse_utf8_codepoint(ctx.input, pos);
auto utf8_result = common_parse_utf8_codepoint(ctx.input, pos);
if (utf8_result.status == utf8_parse_result::INCOMPLETE) {
// Incomplete UTF-8 sequence
if (!ctx.is_partial) {
if (!ctx.is_lenient()) {
// Input is complete but UTF-8 is incomplete = malformed
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
}
@ -655,7 +830,7 @@ struct parser_executor {
last_valid_pos = pos;
}
if (last_valid_pos == ctx.input.size() && ctx.is_partial) {
if (last_valid_pos == ctx.input.size() && ctx.is_lenient()) {
// Reached the end of a partial stream, there might still be more input that we need to consume.
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, last_valid_pos);
}
@ -694,6 +869,9 @@ struct parser_executor {
common_peg_parse_result operator()(const common_peg_tag_parser & p) {
// Parse the child
if (ctx.is_debug()) {
fprintf(stderr, "%sTAG: %s\n", debug_indent().c_str(), p.tag.c_str());
}
auto result = arena.parse(p.child, ctx, start_pos);
if (!result.fail()) {
@ -755,6 +933,31 @@ common_peg_parser_id common_peg_arena::resolve_ref(common_peg_parser_id id) {
return id;
}
// Writes a one-line description of `node` to `oss`, then recurses into its children with one
// more level of indentation. Despite the name, the traversal is depth-first (pre-order).
// Line format: NODE <id> [(rule <rule>)] [(tag <tag>)] ['<text>']
static void bfs_node(common_peg_ast_arena &arena, std::ostringstream & oss, const common_peg_ast_node & node, int indent) {
    // One space per nesting level
    for (int level = indent; level > 0; --level) {
        oss << " ";
    }

    oss << "NODE " << node.id;

    // Rule and tag are only printed when present
    const bool has_rule = !node.rule.empty();
    if (has_rule) {
        oss << " (rule " << node.rule << ")";
    }
    const bool has_tag = !node.tag.empty();
    if (has_tag) {
        oss << " (tag " << node.tag << ")";
    }

    oss << " ['" << node.text << "']\n";

    const int child_indent = indent + 1;
    for (const auto child_id : node.children) {
        bfs_node(arena, oss, arena.get(child_id), child_indent);
    }
}
std::string common_peg_ast_arena::dump() {
std::ostringstream oss;
for (auto & node : nodes_) {
bfs_node(*this, oss, node, 0);
}
return oss.str();
}
void common_peg_arena::resolve_refs() {
// Walk through all parsers and replace refs with their corresponding rule IDs
for (auto & parser : parsers_) {
@ -786,6 +989,7 @@ void common_peg_arena::resolve_refs() {
std::is_same_v<T, common_peg_until_parser> ||
std::is_same_v<T, common_peg_literal_parser> ||
std::is_same_v<T, common_peg_json_string_parser> ||
std::is_same_v<T, common_peg_python_dict_string_parser> ||
std::is_same_v<T, common_peg_chars_parser> ||
std::is_same_v<T, common_peg_any_parser> ||
std::is_same_v<T, common_peg_space_parser>) {
@ -803,9 +1007,21 @@ void common_peg_arena::resolve_refs() {
}
// Public entry point for dumping a parser: starts the recursive dump with an empty set of
// already-visited parser ids, which dump_impl uses to detect cycles.
std::string common_peg_arena::dump(common_peg_parser_id id) const {
    std::unordered_set<common_peg_parser_id> seen;
    return dump_impl(id, seen);
}
std::string common_peg_arena::dump_impl(common_peg_parser_id id,
std::unordered_set<common_peg_parser_id> & visited) const {
// Check for cycles
if (visited.count(id)) {
return "[cycle]";
}
visited.insert(id);
const auto & parser = parsers_.at(id);
return std::visit([this](const auto & p) -> std::string {
return std::visit([this, &visited](const auto & p) -> std::string {
using T = std::decay_t<decltype(p)>;
if constexpr (std::is_same_v<T, common_peg_epsilon_parser>) {
@ -819,24 +1035,27 @@ std::string common_peg_arena::dump(common_peg_parser_id id) const {
} else if constexpr (std::is_same_v<T, common_peg_sequence_parser>) {
std::vector<std::string> parts;
for (const auto & child : p.children) {
parts.push_back(dump(child));
parts.push_back(dump_impl(child, visited));
}
return "Sequence(" + string_join(parts, ", ") + ")";
} else if constexpr (std::is_same_v<T, common_peg_choice_parser>) {
std::vector<std::string> parts;
for (const auto & child : p.children) {
parts.push_back(dump(child));
parts.push_back(dump_impl(child, visited));
}
return "Choice(" + string_join(parts, ", ") + ")";
} else if constexpr (std::is_same_v<T, common_peg_repetition_parser>) {
if (p.max_count == -1) {
return "Repetition(" + dump(p.child) + ", " + std::to_string(p.min_count) + ", unbounded)";
return "Repetition(" + dump_impl(p.child, visited) + ", " + std::to_string(p.min_count) +
", unbounded)";
}
return "Repetition(" + dump(p.child) + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")";
return "Repetition(" + dump_impl(p.child, visited) + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")";
} else if constexpr (std::is_same_v<T, common_peg_and_parser>) {
return "And(" + dump(p.child) + ")";
return "And(" + dump_impl(p.child, visited) + ")";
} else if constexpr (std::is_same_v<T, common_peg_not_parser>) {
return "Not(" + dump(p.child) + ")";
return "Not(" + dump_impl(p.child, visited) + ")";
} else if constexpr (std::is_same_v<T, common_peg_atomic_parser>) {
return "Atomic(" + dump_impl(p.child, visited) + ")";
} else if constexpr (std::is_same_v<T, common_peg_any_parser>) {
return "Any";
} else if constexpr (std::is_same_v<T, common_peg_space_parser>) {
@ -848,14 +1067,20 @@ std::string common_peg_arena::dump(common_peg_parser_id id) const {
return "CharRepeat(" + p.pattern + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")";
} else if constexpr (std::is_same_v<T, common_peg_json_string_parser>) {
return "JsonString()";
} else if constexpr (std::is_same_v<T, common_peg_python_dict_string_parser>) {
return "PythonDictString()";
} else if constexpr (std::is_same_v<T, common_peg_until_parser>) {
return "Until(" + string_join(p.delimiters, " | ") + ")";
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
return "Schema(" + dump(p.child) + ", " + (p.schema ? p.schema->dump() : "null") + ")";
return "Schema(" + dump_impl(p.child, visited) + ", " + (p.schema ? p.schema->dump() : "null") + ")";
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
return "Rule(" + p.name + ", " + dump(p.child) + ")";
return "Rule(" + p.name + ", " + dump_impl(p.child, visited) + ")";
} else if constexpr (std::is_same_v<T, common_peg_ref_parser>) {
return "Ref(" + p.name + ")";
} else if constexpr (std::is_same_v<T, common_peg_tag_parser>) {
return "Tag(" + p.tag + ", " + dump(p.child) + ")";
} else if constexpr (std::is_same_v<T, common_peg_atomic_parser>) {
return "Atomic(" + dump(p.child) + ")";
} else {
return "Unknown";
}
@ -1054,7 +1279,54 @@ common_peg_arena common_peg_parser_builder::build() {
return std::move(arena_);
}
// String primitives
// Adds a JSON string-content parser node to the arena and returns it wrapped as a parser handle.
common_peg_parser common_peg_parser_builder::json_string_content() {
    auto node_id = arena_.add_parser(common_peg_json_string_parser{});
    return wrap(node_id);
}
// Adds a python-dict (single-quoted) string-content parser node to the arena and returns it
// wrapped as a parser handle.
common_peg_parser common_peg_parser_builder::single_quoted_string_content() {
    auto node_id = arena_.add_parser(common_peg_python_dict_string_parser{});
    return wrap(node_id);
}
// Rule "dq-string": opening '"', JSON string content, closing '"', then space().
common_peg_parser common_peg_parser_builder::double_quoted_string() {
    const auto build = [this]() {
        // Parts are created in the same left-to-right order in which the sequence lists them
        auto open_quote  = literal("\"");
        auto body        = json_string_content();
        auto close_quote = literal("\"");
        auto trailing_ws = space();
        return sequence({ open_quote, body, close_quote, trailing_ws });
    };
    return rule("dq-string", build);
}
// Rule "sq-string": opening "'", single-quoted string content, closing "'", then space().
common_peg_parser common_peg_parser_builder::single_quoted_string() {
    const auto build = [this]() {
        // Parts are created in the same left-to-right order in which the sequence lists them
        auto open_quote  = literal("'");
        auto body        = single_quoted_string_content();
        auto close_quote = literal("'");
        auto trailing_ws = space();
        return sequence({ open_quote, body, close_quote, trailing_ws });
    };
    return rule("sq-string", build);
}
// Rule "flexible-string": a double-quoted string, tried first, or a single-quoted string.
common_peg_parser common_peg_parser_builder::flexible_string() {
    const auto build = [this]() {
        auto dq = double_quoted_string();
        auto sq = single_quoted_string();
        return choice({ dq, sq });
    };
    return rule("flexible-string", build);
}
// Generic helpers for object/array structure
// Builds a rule called `name` that matches a brace-delimited object:
//   "{" ws ( "}" | member ( ws "," ws member )* ws "}" )
// where member is `string_parser ws ":" ws value_parser`.
// `string_parser` and `value_parser` are captured by value in the builder lambda passed to rule().
common_peg_parser common_peg_parser_builder::generic_object(const std::string & name,
const common_peg_parser & string_parser,
const common_peg_parser & value_parser) {
return rule(name, [this, string_parser, value_parser]() {
auto ws = space();
// A single key/value pair
auto member = sequence({ string_parser, ws, literal(":"), ws, value_parser });
// One or more comma-separated members
auto members = sequence({ member, zero_or_more(sequence({ ws, literal(","), ws, member })) });
// Either an empty object or a non-empty member list; nothing is consumed after the closing "}"
return sequence({ literal("{"), ws, choice({ literal("}"), sequence({ members, ws, literal("}") }) }) });
});
}
// Builds a rule called `name` that matches a bracket-delimited array:
//   "[" ws ( "]" | value ( "," ws value )* ws "]" )
// `value_parser` is captured by value in the builder lambda passed to rule().
common_peg_parser common_peg_parser_builder::generic_array(const std::string & name,
const common_peg_parser & value_parser) {
return rule(name, [this, value_parser]() {
auto ws = space();
// One or more comma-separated elements; unlike generic_object there is no explicit ws before ","
auto elements = sequence({ value_parser, zero_or_more(sequence({ literal(","), ws, value_parser })) });
// Either an empty array or a non-empty element list; nothing is consumed after the closing "]"
return sequence({ literal("["), ws, choice({ literal("]"), sequence({ elements, ws, literal("]") }) }) });
});
}
// JSON parsers
common_peg_parser common_peg_parser_builder::json_number() {
return rule("json-number", [this]() {
auto digit1_9 = chars("[1-9]", 1, 1);
@ -1062,7 +1334,11 @@ common_peg_parser common_peg_parser_builder::json_number() {
auto int_part = choice({literal("0"), sequence({digit1_9, chars("[0-9]", 0, -1)})});
auto frac = sequence({literal("."), digits});
auto exp = sequence({choice({literal("e"), literal("E")}), optional(chars("[+-]", 1, 1)), digits});
return sequence({optional(literal("-")), int_part, optional(frac), optional(exp), space()});
// Negative lookahead: only commit the number when the next character can't extend it.
// At EOF in partial mode, chars returns NEED_MORE → negate propagates NEED_MORE → number not committed.
// This prevents premature commits of partial numbers (e.g. "3" when "3.14" is incoming).
auto not_number_continuation = negate(chars("[0-9.eE+-]", 1, 1));
return sequence({ optional(literal("-")), int_part, optional(frac), optional(exp), not_number_continuation, space() });
});
}
@ -1085,36 +1361,11 @@ common_peg_parser common_peg_parser_builder::json_null() {
}
common_peg_parser common_peg_parser_builder::json_object() {
return rule("json-object", [this]() {
auto ws = space();
auto member = sequence({json_string(), ws, literal(":"), ws, json()});
auto members = sequence({member, zero_or_more(sequence({ws, literal(","), ws, member}))});
return sequence({
literal("{"),
ws,
choice({
literal("}"),
sequence({members, ws, literal("}")})
}),
ws
});
});
return generic_object("json-object", json_string(), json());
}
common_peg_parser common_peg_parser_builder::json_array() {
return rule("json-array", [this]() {
auto ws = space();
auto elements = sequence({json(), zero_or_more(sequence({literal(","), ws, json()}))});
return sequence({
literal("["),
ws,
choice({
literal("]"),
sequence({elements, ws, literal("]")})
}),
ws
});
});
return generic_array("json-array", json());
}
common_peg_parser common_peg_parser_builder::json() {
@ -1130,8 +1381,40 @@ common_peg_parser common_peg_parser_builder::json() {
});
}
common_peg_parser common_peg_parser_builder::json_string_content() {
return wrap(arena_.add_parser(common_peg_json_string_parser{}));
// Rule "python-string": a double-quoted string, tried first, or a single-quoted string.
common_peg_parser common_peg_parser_builder::python_string() {
    const auto build = [this]() {
        auto dq = double_quoted_string();
        auto sq = single_quoted_string();
        return choice({ dq, sq });
    };
    return rule("python-string", build);
}
// Python numbers reuse the JSON number parser unchanged.
common_peg_parser common_peg_parser_builder::python_number() {
return json_number();
}
// Rule "python-bool": the literal True or False, followed by space().
common_peg_parser common_peg_parser_builder::python_bool() {
    const auto build = [this]() {
        // Parts are created in the same order as in the nested expression form
        auto truthy      = literal("True");
        auto falsy       = literal("False");
        auto keyword     = choice({ truthy, falsy });
        auto trailing_ws = space();
        return sequence({ keyword, trailing_ws });
    };
    return rule("python-bool", build);
}
// Rule "python-none": the literal None, followed by space().
common_peg_parser common_peg_parser_builder::python_null() {
    const auto build = [this]() {
        auto keyword     = literal("None");
        auto trailing_ws = space();
        return sequence({ keyword, trailing_ws });
    };
    return rule("python-none", build);
}
// Rule "python-dict": a generic_object whose keys are python strings and whose values are
// python values.
// NOTE(review): the two argument expressions are evaluated in an unspecified order, so the
// order in which their parsers are created may vary between compilers.
common_peg_parser common_peg_parser_builder::python_dict() {
return generic_object("python-dict", python_string(), python_value());
}
// Rule "python-array": a generic_array whose elements are python values.
common_peg_parser common_peg_parser_builder::python_array() {
    auto element = python_value();
    return generic_array("python-array", element);
}
// Rule "python-value": any supported Python literal. Alternatives are tried in this order:
// dict, array, string, number, bool, None.
common_peg_parser common_peg_parser_builder::python_value() {
    const auto build = [this]() {
        // Built one by one in the same order in which the choice lists them
        auto dict_alt   = python_dict();
        auto array_alt  = python_array();
        auto string_alt = python_string();
        auto number_alt = python_number();
        auto bool_alt   = python_bool();
        auto none_alt   = python_null();
        return choice({ dict_alt, array_alt, string_alt, number_alt, bool_alt, none_alt });
    };
    return rule("python-value", build);
}
// Matches a bracketed marker: either "<" ... ">" or "[" ... "]".
// The text between the brackets is matched by until() with the closing bracket as delimiter.
common_peg_parser common_peg_parser_builder::marker() {
// "<" + everything up to ">" + ">"
auto sharp_bracket_parser = literal("<") + until(">") + literal(">");
// "[" + everything up to "]" + "]"
auto square_bracket_parser = literal("[") + until("]") + literal("]");
// The angle-bracket form is tried first
return choice({ sharp_bracket_parser, square_bracket_parser });
}
common_peg_parser common_peg_parser_builder::json_member(const std::string & key, const common_peg_parser & p) {
@ -1145,17 +1428,54 @@ common_peg_parser common_peg_parser_builder::json_member(const std::string & key
});
}
static std::string gbnf_escape_char_class(char c) {
switch (c) {
case '\n': return "\\n";
case '\t': return "\\t";
case '\r': return "\\r";
case '\\': return "\\\\";
case ']': return "\\]";
case '[': return "\\[";
default: return std::string(1, c);
// Renders one code point so it can be placed inside a GBNF character class.
//
//   - '-', ']', '[' and '\' are prefixed with a backslash
//   - newline, tab and carriage return become \n, \t and \r
//   - other printable ASCII (0x20..0x7E) is emitted unchanged
//   - everything else becomes an uppercase hex escape: \xHH (<= 0xFF),
//     \uHHHH (<= 0xFFFF) or \UHHHHHHHH (anything larger)
static std::string gbnf_escape_char_class(uint32_t c) {
    switch (c) {
        case '-':
        case ']':
        case '[':
        case '\\':
            return std::string("\\") + static_cast<char>(c);
        case '\n':
            return "\\n";
        case '\t':
            return "\\t";
        case '\r':
            return "\\r";
        default:
            break;
    }

    const bool printable_ascii = c >= 0x20 && c <= 0x7E;
    if (printable_ascii) {
        return std::string(1, static_cast<char>(c));
    }

    // Pick the shortest escape form that can hold the value.
    char introducer = 'U';
    int  n_digits   = 8;
    if (c <= 0xFF) {
        introducer = 'x';
        n_digits   = 2;
    } else if (c <= 0xFFFF) {
        introducer = 'u';
        n_digits   = 4;
    }

    const char * const digits = "0123456789ABCDEF";

    std::string escaped = "\\";
    escaped += introducer;
    // Emit nibbles from most significant to least significant.
    for (int shift = (n_digits - 1) * 4; shift >= 0; shift -= 4) {
        escaped += digits[(c >> shift) & 0xF];
    }
    return escaped;
}
static std::string gbnf_excluding_pattern(const std::vector<std::string> & strings) {
@ -1173,12 +1493,12 @@ static std::string gbnf_excluding_pattern(const std::vector<std::string> & strin
std::string cls;
cls.reserve(chars.size());
for (const auto & ch : chars) {
for (uint32_t ch : chars) {
cls += gbnf_escape_char_class(ch);
}
if (!pre.empty()) {
pattern += gbnf_format_literal(pre) + " [^" + cls + "]";
pattern += gbnf_format_literal(common_unicode_cpts_to_utf8(pre)) + " [^" + cls + "]";
} else {
pattern += "[^" + cls + "]";
}
@ -1208,7 +1528,8 @@ static std::unordered_set<std::string> collect_reachable_rules(
std::is_same_v<T, common_peg_chars_parser> ||
std::is_same_v<T, common_peg_space_parser> ||
std::is_same_v<T, common_peg_any_parser> ||
std::is_same_v<T, common_peg_json_string_parser>) {
std::is_same_v<T, common_peg_json_string_parser> ||
std::is_same_v<T, common_peg_python_dict_string_parser>) {
// These parsers do not have any children
} else if constexpr (std::is_same_v<T, common_peg_sequence_parser>) {
for (auto child : p.children) {
@ -1346,6 +1667,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
return result + "{" + std::to_string(p.min_count) + "," + std::to_string(p.max_count) + "}";
} else if constexpr (std::is_same_v<T, common_peg_json_string_parser>) {
return R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)";
} else if constexpr (std::is_same_v<T, common_peg_python_dict_string_parser>) {
return R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)";
} else if constexpr (std::is_same_v<T, common_peg_until_parser>) {
if (p.delimiters.empty()) {
return ".*";
@ -1477,6 +1800,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
};
} else if constexpr (std::is_same_v<T, common_peg_json_string_parser>) {
return json{{"type", "json_string"}};
} else if constexpr (std::is_same_v<T, common_peg_python_dict_string_parser>) {
return json{{ "type", "python_dict_string" }};
} else if constexpr (std::is_same_v<T, common_peg_until_parser>) {
return json{{"type", "until"}, {"delimiters", p.delimiters}};
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
@ -1606,6 +1931,9 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
if (type == "json_string") {
return common_peg_json_string_parser{};
}
if (type == "python_dict_string") {
return common_peg_python_dict_string_parser{};
}
if (type == "until") {
if (!j.contains("delimiters") || !j["delimiters"].is_array()) {
throw std::runtime_error("until parser missing or invalid 'delimiters' field");

View File

@ -4,6 +4,7 @@
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <string>
#include <string_view>
#include <functional>
@ -111,6 +112,8 @@ class common_peg_ast_arena {
void visit(common_peg_ast_id id, const common_peg_ast_visitor & visitor) const;
void visit(const common_peg_parse_result & result, const common_peg_ast_visitor & visitor) const;
std::string dump();
};
struct common_peg_parse_result {
@ -136,21 +139,43 @@ struct common_peg_parse_result {
bool success() const { return type == COMMON_PEG_PARSE_RESULT_SUCCESS; }
};
// Bit flags passed to a parse; individual flags are combined with the operators below.
//
// The underlying type is fixed to int on purpose. Without a fixed underlying type the
// enum can only hold the values 0..3, and converting anything outside that range to it
// (which is exactly what operator~ does, e.g. ~1 == -2) is undefined behavior.
enum common_peg_parse_flags : int {
    COMMON_PEG_PARSE_FLAG_NONE    = 0,
    COMMON_PEG_PARSE_FLAG_LENIENT = 1 << 0,
    COMMON_PEG_PARSE_FLAG_DEBUG   = 1 << 1,
};

// Union of two flag sets.
inline constexpr common_peg_parse_flags operator|(common_peg_parse_flags a, common_peg_parse_flags b) {
    return static_cast<common_peg_parse_flags>(static_cast<int>(a) | static_cast<int>(b));
}

// Adds the flags in b to a and returns a.
inline constexpr common_peg_parse_flags & operator|=(common_peg_parse_flags & a, common_peg_parse_flags b) {
    a = a | b;
    return a;
}

// Intersection of two flag sets.
inline constexpr common_peg_parse_flags operator&(common_peg_parse_flags a, common_peg_parse_flags b) {
    return static_cast<common_peg_parse_flags>(static_cast<int>(a) & static_cast<int>(b));
}

// Bitwise complement; intended for clearing flags, as in `flags & ~COMMON_PEG_PARSE_FLAG_DEBUG`.
// The result has bits set that do not correspond to any named flag.
inline constexpr common_peg_parse_flags operator~(common_peg_parse_flags a) {
    return static_cast<common_peg_parse_flags>(~static_cast<int>(a));
}
struct common_peg_parse_context {
std::string input;
bool is_partial;
common_peg_parse_flags flags;
common_peg_ast_arena ast;
int parse_depth;
common_peg_parse_context()
: is_partial(false), parse_depth(0) {}
common_peg_parse_context(common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_NONE)
: flags(flags), parse_depth(0) {}
common_peg_parse_context(const std::string & input)
: input(input), is_partial(false), parse_depth(0) {}
common_peg_parse_context(const std::string & input, common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_NONE)
: input(input), flags(flags), parse_depth(0) {}
common_peg_parse_context(const std::string & input, bool is_partial)
: input(input), is_partial(is_partial), parse_depth(0) {}
bool is_lenient() const { return flags & COMMON_PEG_PARSE_FLAG_LENIENT; }
bool is_debug() const { return flags & COMMON_PEG_PARSE_FLAG_DEBUG; }
};
class common_peg_arena;
@ -207,6 +232,7 @@ struct common_peg_chars_parser {
};
struct common_peg_json_string_parser {};
struct common_peg_python_dict_string_parser {};
struct common_peg_until_parser {
std::vector<std::string> delimiters;
@ -255,6 +281,7 @@ using common_peg_parser_variant = std::variant<
common_peg_space_parser,
common_peg_chars_parser,
common_peg_json_string_parser,
common_peg_python_dict_string_parser,
common_peg_until_parser,
common_peg_schema_parser,
common_peg_rule_parser,
@ -299,6 +326,8 @@ class common_peg_arena {
friend class common_peg_parser_builder;
private:
std::string dump_impl(common_peg_parser_id id, std::unordered_set<common_peg_parser_id> & visited) const;
common_peg_parser_id add_parser(common_peg_parser_variant parser);
void add_rule(const std::string & name, common_peg_parser_id id);
@ -311,6 +340,10 @@ class common_peg_parser_builder {
common_peg_parser wrap(common_peg_parser_id id) { return common_peg_parser(id, *this); }
common_peg_parser add(const common_peg_parser_variant & p) { return wrap(arena_.add_parser(p)); }
// Generic helpers for building object/array structures with configurable string/value parsers.
common_peg_parser generic_object(const std::string & name, const common_peg_parser & string_parser, const common_peg_parser & value_parser);
common_peg_parser generic_array(const std::string & name, const common_peg_parser & value_parser);
public:
common_peg_parser_builder();
@ -404,6 +437,21 @@ class common_peg_parser_builder {
// S -> A{n}
common_peg_parser repeat(const common_peg_parser & p, int n) { return repeat(p, n, n); }
// Matches a double-quoted string: '"' content '"' space
common_peg_parser double_quoted_string();
// Matches a single-quoted string: "'" content "'" space
common_peg_parser single_quoted_string();
// Matches a string that accepts both double-quoted and single-quoted styles.
common_peg_parser flexible_string();
// Matches double-quoted string content without the surrounding quotes.
common_peg_parser json_string_content();
// Matches single-quoted string content without the surrounding quotes.
common_peg_parser single_quoted_string_content();
// Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null.
// value -> object | array | string | number | true | false | null
common_peg_parser json();
@ -414,14 +462,24 @@ class common_peg_parser_builder {
common_peg_parser json_bool();
common_peg_parser json_null();
// Matches JSON string content without the surrounding quotes.
// Useful for extracting content within a JSON string.
common_peg_parser json_string_content();
// Matches a JSON object member with a key and associated parser as the
// value.
common_peg_parser json_member(const std::string & key, const common_peg_parser & p);
// Creates a complete Python format parser supporting dicts, arrays, strings, numbers, booleans, and None.
// Differs from JSON: uses True/False/None, accepts both single and double-quoted strings.
// value -> dict | array | string | number | True | False | None
common_peg_parser python_value();
common_peg_parser python_dict();
common_peg_parser python_string();
common_peg_parser python_array();
common_peg_parser python_number();
common_peg_parser python_bool();
common_peg_parser python_null();
// A marker, i.e. text delimited by a pair of <> or []
common_peg_parser marker();
// Wraps a parser with JSON schema metadata for grammar generation.
// Used internally to convert JSON schemas to GBNF grammar rules.
common_peg_parser schema(const common_peg_parser & p, const std::string & name, const nlohmann::ordered_json & schema, bool raw = false);

View File

@ -1,14 +1,18 @@
#include "unicode.h"
#include <cassert>
#include <stdexcept>
#include <vector>
#include <string>
// implementation adopted from src/unicode.cpp
size_t utf8_sequence_length(unsigned char first_byte) {
// Returns the UTF-8 sequence length implied by the top four bits of first_byte.
// The table never yields 0: bytes 0x80..0xBF (high nibble 8..B) map to 1 and
// every byte with high nibble F maps to 4.
size_t common_utf8_sequence_length(unsigned char first_byte) {
    static const size_t length_by_high_nibble[16] = {
        1, 1, 1, 1, 1, 1, 1, 1,  // 0x0_ .. 0x7_
        1, 1, 1, 1,              // 0x8_ .. 0xB_
        2, 2,                    // 0xC_ .. 0xD_
        3,                       // 0xE_
        4,                       // 0xF_
    };
    return length_by_high_nibble[first_byte >> 4];
}
utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset) {
utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t offset) {
if (offset >= input.size()) {
return utf8_parse_result(utf8_parse_result::INCOMPLETE);
}
@ -62,3 +66,43 @@ utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset) {
// Invalid first byte
return utf8_parse_result(utf8_parse_result::INVALID);
}
// Encodes a sequence of code points as one UTF-8 string by encoding each code point
// with common_unicode_cpt_to_utf8() and concatenating the results in order.
std::string common_unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
    std::string utf8;
    for (const uint32_t cpt : cps) {
        utf8 += common_unicode_cpt_to_utf8(cpt);
    }
    return utf8;
}
// Encodes a single code point as UTF-8.
//
//   cpt <= 0x7f      -> 1 byte
//   cpt <= 0x7ff     -> 2 bytes
//   cpt <= 0xffff    -> 3 bytes
//   cpt <= 0x10ffff  -> 4 bytes
//
// Throws std::invalid_argument("invalid codepoint") for anything above 0x10ffff.
std::string common_unicode_cpt_to_utf8(uint32_t cpt) {
    if (cpt > 0x10ffff) {
        throw std::invalid_argument("invalid codepoint");
    }

    // Low six bits of (cpt >> shift), tagged as a continuation byte (10xxxxxx).
    const auto continuation = [cpt](int shift) {
        return static_cast<char>(0x80 | ((cpt >> shift) & 0x3f));
    };

    std::string utf8;
    if (cpt < 0x80) {
        utf8 += static_cast<char>(cpt);
    } else if (cpt < 0x800) {
        utf8 += static_cast<char>(0xc0 | ((cpt >> 6) & 0x1f));
        utf8 += continuation(0);
    } else if (cpt < 0x10000) {
        utf8 += static_cast<char>(0xe0 | ((cpt >> 12) & 0x0f));
        utf8 += continuation(6);
        utf8 += continuation(0);
    } else {
        utf8 += static_cast<char>(0xf0 | ((cpt >> 18) & 0x07));
        utf8 += continuation(12);
        utf8 += continuation(6);
        utf8 += continuation(0);
    }
    return utf8;
}

View File

@ -2,6 +2,8 @@
#include <cstdint>
#include <string_view>
#include <vector>
#include <string>
// UTF-8 parsing utilities for streaming-aware unicode support
@ -16,7 +18,10 @@ struct utf8_parse_result {
// Determine the expected length of a UTF-8 sequence from its first byte
// Never returns 0: continuation bytes (0x80-0xBF) yield 1 and bytes 0xF0-0xFF yield 4
size_t utf8_sequence_length(unsigned char first_byte);
size_t common_utf8_sequence_length(unsigned char first_byte);
// Parse a single UTF-8 codepoint from input
utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset);
utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t offset);
std::string common_unicode_cpts_to_utf8(const std::vector<uint32_t> & cps);
std::string common_unicode_cpt_to_utf8(uint32_t cpt);

525
docs/autoparser.md Normal file
View File

@ -0,0 +1,525 @@
# Auto-Parser Architecture
The auto-parser automatically analyzes chat templates to determine how to parse model outputs, including content, reasoning, and tool calls.
## Overview
The unified auto-parser uses a pure differential, compositional approach (inspired by the `git diff` algorithm) to analyze chat templates:
**Core Philosophy**:
- **Minimize Hardcoded Patterns**: All markers extracted through template comparison (the only heuristic is JSON detection to distinguish `JSON_NATIVE` from tag-based formats)
- **Compositional Architecture**: Separate analyzer structs for reasoning, content, and tools — each responsible for its own analysis and parser construction
**Analysis + Parser Building in Two Steps**:
1. `autoparser::autoparser tmpl_analysis(tmpl)` — runs all differential comparisons and populates the analysis structs
2. `autoparser::peg_generator::generate_parser(tmpl, params, tmpl_analysis)` — uses the analysis to build a PEG parser and optional GBNF grammar
## Data Structures
All structs are defined in [common/chat-auto-parser.h](common/chat-auto-parser.h).
### Top-Level: `autoparser` (main analyzer and generator)
[common/chat-auto-parser.h:367-388](common/chat-auto-parser.h#L367-L388) — top-level analysis result aggregating `jinja_caps`, `reasoning`, `content`, and `tools` sub-analyses, plus `preserved_tokens` (union of all non-empty markers).
### `analyze_reasoning`
[common/chat-auto-parser.h:254-274](common/chat-auto-parser.h#L254-L274) — reasoning analysis result: `mode` enum, `start` marker (e.g. `<think>`), and `end` marker (e.g. `</think>`).
### `analyze_content`
[common/chat-auto-parser.h:280-295](common/chat-auto-parser.h#L280-L295) — content analysis result: `mode` enum, `start`/`end` markers, and `requires_nonnull_content` flag.
### `analyze_tools` and its sub-structs
- [common/chat-auto-parser.h:176-194](common/chat-auto-parser.h#L176-L194) — `tool_format_analysis`: `mode` enum, `section_start/end`, `per_call_start/end`, JSON field names (`function_field`, `name_field`, `args_field`, `id_field`, `gen_id_field`), and format flags (`fun_name_is_key`, `tools_array_wrapped`, `uses_python_dicts`)
- [common/chat-auto-parser.h:196-200](common/chat-auto-parser.h#L196-L200) — `tool_function_analysis`: `name_prefix`, `name_suffix`, `close` markers around function names
- [common/chat-auto-parser.h:202-210](common/chat-auto-parser.h#L202-L210) — `tool_arguments_analysis`: `start/end` container markers, `name_prefix/suffix`, `value_prefix/suffix`, `separator`
- [common/chat-auto-parser.h:212-217](common/chat-auto-parser.h#L212-L217) — `tool_id_analysis`: `pos` enum, `prefix`/`suffix` markers around call ID values
- [common/chat-auto-parser.h:301-361](common/chat-auto-parser.h#L301-L361) — `analyze_tools`: aggregates the four sub-structs above
### Enums
**`reasoning_mode`**: How the template handles reasoning/thinking blocks.
| Value | Description |
|-----------------|-----------------------------------------------------------------------------------|
| `NONE` | No reasoning markers detected |
| `TAG_BASED` | Standard tag-based: `<think>...</think>` |
| `DELIMITER` | Delimiter-based: reasoning ends at a delimiter (e.g., `[BEGIN FINAL RESPONSE]`) |
| `FORCED_OPEN` | Template ends with open reasoning tag when `enable_thinking=true` |
| `FORCED_CLOSED` | `enable_thinking=false` emits both tags; `enable_thinking=true` emits only start |
| `TOOLS_ONLY` | Reasoning only appears in tool call responses, not plain content |
**`content_mode`**: How the template wraps assistant content.
| Value | Description |
|--------------------------|----------------------------------------------------------------|
| `PLAIN` | No content markers |
| `ALWAYS_WRAPPED` | Content always wrapped: `<response>...</response>` |
| `WRAPPED_WITH_REASONING` | Content wrapped only when reasoning is present |
**`tool_format`**: Classification of tool call structure.
| Value | Description |
|------------------|------------------------------------------------------------------|
| `NONE` | No tool support detected |
| `JSON_NATIVE` | Pure JSON: `{"name": "X", "arguments": {...}}` |
| `TAG_WITH_JSON` | Tag-based with JSON args: `<function=X>{...}</function>` |
| `TAG_WITH_TAGGED`| Tag-based with tagged args: `<param=key>value</param>` |
**`call_id_position`**: Where call IDs appear in tag-based formats.
| Value | Description |
|--------------------------|----------------------------------------------|
| `NONE` | No call ID support detected |
| `PRE_FUNC_NAME` | Before function name |
| `BETWEEN_FUNC_AND_ARGS` | Between function name and arguments |
| `POST_ARGS` | After arguments |
## Tool Calling Formats
### JSON_NATIVE
**Structure**: The entire tool call (function name, arguments, values) is in JSON format. Optional enclosing tags around the section.
**Detection**: Function name appears inside a JSON structure (quotes preceded by `{` or `:`).
**Examples**:
Standard OpenAI-style:
```json
<tool_call>
{"name": "get_weather", "arguments": {"location": "Paris", "unit": "celsius"}}
</tool_call>
```
Mistral Nemo with array wrapper:
```json
[TOOL_CALLS]
[{"name": "calculate", "arguments": {"expr": "2+2"}}]
```
Function name as JSON key (Apertus style):
```json
{"get_weather": {"location": "Paris"}}
```
---
### TAG_WITH_JSON
**Structure**: Function name is outside JSON, in tag attributes or XML-style tags. Arguments are a JSON object.
**Detection**: Function name not in JSON, but argument names appear in JSON context.
**Examples**:
Functionary v3.1:
```xml
<function=get_weather>{"location": "Paris", "unit": "celsius"}</function>
```
MiniMax:
```xml
<minimax:tool_call>
<tool_name>calculate</tool_name>
<arguments>{"expr": "2+2"}</arguments>
</minimax:tool_call>
```
---
### TAG_WITH_TAGGED
**Structure**: Both function name and argument names are in XML-style tags. String values are unquoted; non-string values are JSON-formatted.
**Detection**: Neither function name nor argument names appear in a JSON context.
**Examples**:
Qwen/Hermes XML format:
```xml
<function=get_weather>
<param=location>Paris</param>
<param=unit>celsius</param>
</function>
```
Mixed types:
```xml
<function=calculate>
<param=expr>2+2</param>
<param=precision>2</param>
<param=options>{"round": true}</param>
</function>
```
String values (`Paris`, `celsius`, `2+2`) are unquoted; `options` (object type) is JSON-formatted.
---
## Analysis Flow
```text
autoparser::autoparser(tmpl)
|
|-- Phase 1: analyze_reasoning(tmpl, jinja_caps.supports_tool_calls)
| |-- R1: compare_reasoning_presence() — with/without reasoning_content field
| |-- R2: compare_thinking_enabled() — enable_thinking=false vs true
| '-- R3: compare_reasoning_scope() — reasoning+content vs reasoning+tools
| (only if supports_tool_calls)
|
|-- Phase 2: analyze_content(tmpl, reasoning)
| '-- C1: compares content-only vs tools output and content-only vs reasoning output
|
|-- Phase 3: analyze_tools(tmpl, jinja_caps, reasoning)
| (skipped entirely if !jinja_caps.supports_tool_calls)
| |
| |-- T1: analyze_tool_calls() — no tools vs with tools; classifies format
| | |-- JSON path → analyze_tool_call_format_json_native()
| | '-- tag path → analyze_tool_call_format_non_json()
| |
| (if format != NONE and format != JSON_NATIVE:)
| |
| |-- T2: check_per_call_markers() — 1 call vs 2 calls; moves section→per-call if needed
| | (only if supports_parallel_tool_calls)
| |
| |-- T3: extract_function_markers() — func_alpha vs func_beta; extracts name prefix/suffix/close
| |
| |-- T4: analyze_arguments() — (TAG_WITH_TAGGED only)
| | |-- A1: extract_argument_name_markers() — arg_name_A vs arg_name_B
| | '-- A2: extract_argument_value_markers() — value "XXXX" vs "YYYY"
| |
| |-- T5: extract_argument_separator() — 1 arg vs 2 args; finds separator between args
| |
| |-- T6: extract_args_markers() — 0 args vs 1 arg; finds args container markers
| |
| '-- T7: extract_call_id_markers() — call_id "call00001" vs "call99999"
|
'-- collect_preserved_tokens() — union of all non-empty markers
|
'-- apply workarounds() — post-hoc patches for edge-case templates
|
v
autoparser (analysis result)
|
v
autoparser::peg_generator::generate_parser(tmpl, inputs, analysis)
|-- analysis.build_parser(inputs) — builds PEG parser arena
| |-- reasoning.build_parser(ctx) — reasoning parser (mode-dependent)
| |-- content.build_parser(ctx) — content parser (mode-dependent)
| '-- tools.build_parser(ctx) — tool parser (dispatches by tool_format)
| |-- build_tool_parser_json_native()
| |-- build_tool_parser_tag_json()
| '-- build_tool_parser_tag_tagged()
|
|-- Build GBNF grammar (if tools present and trigger_marker non-empty)
'-- Set grammar_triggers from section_start or per_call_start
|
v
common_chat_params (prompt, parser, grammar, triggers, preserved_tokens)
```
## Entry Point
The auto-parser is invoked in [common/chat.cpp:1280-1310](common/chat.cpp#L1280-L1310) in `common_chat_templates_apply_jinja`. A few specialized templates are handled first (Ministral/Magistral Large 3, GPT-OSS with `<|channel|>`, Functionary v3.2 with `>>>all`), then the auto-parser handles everything else via `autoparser::autoparser` + `peg_generator::generate_parser`.
## Algorithm Details
### Core Mechanism: Differential Comparison
All analysis phases use the same factorized comparison function declared in [common/chat-auto-parser-helpers.h:68](common/chat-auto-parser-helpers.h#L68):
```cpp
compare_variants(tmpl, params_A, params_modifier)
```
This creates variant B by applying a modifier lambda to a copy of `params_A`, renders both through the template, and computes a `diff_split` ([common/chat-auto-parser.h:28-37](common/chat-auto-parser.h#L28-L37)):
- `prefix` — common prefix between A and B
- `suffix` — common suffix between A and B
- `left` — unique to variant A
- `right` — unique to variant B
The diff is computed via `calculate_diff_split()`, which finds the longest-common-prefix and longest-common-suffix, then iteratively moves incomplete `<...>` or `[...]` markers from the prefix/suffix into left/right until stable (tag boundary fixing).
Text is segmentized into markers and non-marker fragments using `segmentize_markers()`, which splits on `<...>` and `[...]` boundaries.
### Phase 1: Reasoning Analysis
**R1 — `compare_reasoning_presence()`**: Compares assistant message with vs without a `reasoning_content` field.
- Searches `diff.right` (output with reasoning) for the reasoning content needle
- Uses PEG parsers to find surrounding markers:
- If both pre/post markers found in `diff.right` → `TAG_BASED` (both tags visible in diff = no forced close)
- If both found but post marker only in the full output B → `FORCED_CLOSED`
- If only post marker found → `DELIMITER`
- Sets `reasoning.start` and `reasoning.end`
**R2 — `compare_thinking_enabled()`**: Compares `enable_thinking=false` vs `true` with a generation prompt.
- Detects `FORCED_OPEN`: `enable_thinking=true` adds a non-empty marker at the end of the prompt (where model will start generating) — sets `reasoning.start`, mode = `FORCED_OPEN`
- Detects `FORCED_CLOSED`: `enable_thinking=false` produces both start+end markers; `enable_thinking=true` produces only start marker
- Handles the reverse case: if both start and end are still empty, looks for a single-segment diff on each side to extract both markers
**R3 — `compare_reasoning_scope()`**: Compares assistant message with reasoning+text-content vs reasoning+tool-calls.
- Only runs if `jinja_caps.supports_tool_calls`
- Detects `TOOLS_ONLY`: reasoning content present in B (with tools) but not in A (with text content)
- Extracts reasoning markers from the tool call output using PEG parsers
### Phase 2: Content Analysis
**C1**: Two comparisons in the `analyze_content` constructor:
- Comparison 1: content-only output vs tool-call output → `diff_tools`
- Comparison 2: content-only output vs reasoning+empty-content output → `diff_reasoning`
Classification logic:
- `PLAIN`: `diff_tools.left` equals the response string (content is the entire diff, no wrapper)
- `ALWAYS_WRAPPED`: markers found surrounding the content text in `pure_content` → extracts `start`/`end`
### Phase 3: Tool Call Analysis
**T1 — `analyze_tool_calls()`**: Compares no-tools vs with-tools output.
- Extracts the tool call section as `diff.right`
- Calls `analyze_tool_call_format()` which first strips reasoning markers from the haystack, then:
- Calls `in_json_haystack()` for both function name and argument name needles
- `in_json_haystack()` uses a PEG parser to check whether the needle appears in a JSON context (preceded by `{` or `:` with surrounding quotes)
- If function name is in JSON → `JSON_NATIVE` → `analyze_tool_call_format_json_native()`
- If function name not in JSON, arg name is in JSON → `TAG_WITH_JSON`
- If neither in JSON → `TAG_WITH_TAGGED`
- `analyze_tool_call_format_json_native()`: parses the JSON object, matches field values to needles to populate `name_field`, `args_field`, `id_field`, `gen_id_field`; detects `tools_array_wrapped`; extracts `section_start`/`section_end`
- `analyze_tool_call_format_non_json()`: uses PEG parsers on the haystack to find up to two opening markers (section + per-call) then up to two closing markers
**T2 — `check_per_call_markers()`**: Compares 1 call vs 2 calls.
- Computes a secondary diff of the second call portion vs the common suffix
- If the second call content starts with `section_start` → the section marker is actually per-call → moves `section_start/end` to `per_call_start/end` and clears the section markers
**T3 — `extract_function_markers()`**: Compares function name `FUN_FIRST` vs `FUN_SECOND` (two different named functions).
- Finds where the function name appears in `diff.left`
- Extracts `function.name_prefix` from the common prefix up to the function marker, and `function.name_suffix` from after the name up to the next marker
- Extends `name_suffix` into `diff.suffix` (to the first marker for TAG_WITH_TAGGED; to the first `{` or `[` for TAG_WITH_JSON)
- Extracts `function.close` from after the last argument value up to the per-call/section end marker
**T4 — `analyze_arguments()`** (TAG_WITH_TAGGED only):
- **A1 `extract_argument_name_markers()`**: Compares `arg_name_A` vs `arg_name_B` (two different argument names).
- Finds shared surrounding structure → `arguments.name_prefix`, `arguments.name_suffix`
- **A2 `extract_argument_value_markers()`**: Compares argument value `"XXXX"` vs `"YYYY"` (same arg, different value).
- Finds markers surrounding the value → `arguments.value_prefix`, `arguments.value_suffix`
**T5 — `extract_argument_separator()`**: Compares 1 argument vs 2 arguments (same function).
- Uses `until_common_prefix(diff.right, ARG_FIRST, ARG_SECOND)` to find what separates the two argument blocks
**T6 — `extract_args_markers()`**: Compares 0 arguments vs 1 argument.
- Uses `until_common_prefix()` and `after_common_suffix()` with the empty and single-arg JSON strings as anchors to find container markers (`arguments.start`, `arguments.end`)
**T7 — `extract_call_id_markers()`**: Compares call IDs `"call00001"` vs `"call99999"`.
- Determines whether function name appears in `diff.prefix` or `diff.suffix` to classify position:
- Function name in prefix only → `BETWEEN_FUNC_AND_ARGS` or `POST_ARGS` (further distinguished by where `{` appears)
- Function name in suffix only → `PRE_FUNC_NAME`
- Extracts `call_id.prefix` and `call_id.suffix` markers around the call ID value
- Clears `per_call_end` if it incorrectly incorporated the call ID suffix
### Workarounds
A workaround array in `common/chat-diff-analyzer.cpp` applies post-hoc patches after analysis. Each workaround is a lambda that inspects the template source and overrides analysis results. Current workarounds:
1. **Old Qwen/DeepSeek thinking templates** — source contains `content.split('</think>')`: sets `reasoning.mode = FORCED_OPEN` with `<think>`/`</think>` markers if no reasoning was detected
2. **Granite 3.3** — source contains specific "Write your thoughts" text: forces `TAG_BASED` reasoning with `<think>`/`</think>` and `WRAPPED_WITH_REASONING` content with `<response>`/`</response>`
3. **Cohere Command R+** — source contains `<|CHATBOT_TOKEN|>`: sets `ALWAYS_WRAPPED` content mode if no content start is already set
4. **Functionary 3.1** — source contains `set has_code_interpreter`: forces `PLAIN` content, specific `per_call_start/end`, clears preserved tokens to only keep Functionary-specific markers
5. **DeepSeek-R1-Distill-Qwen** — source contains `tool▁calls▁begin` markers: overrides tool section/per-call markers with the correct Unicode block characters
### Parser Building
Each analyzer struct (`analyze_reasoning`, `analyze_content`, `analyze_tools`) implements `build_parser(parser_build_context&)`. They share a `parser_build_context` that carries the PEG builder, inference inputs, the pre-built reasoning parser, and a pointer to the content analyzer.
#### Reasoning Parser (`analyze_reasoning::build_parser`)
| Mode | Parser |
|-----------------------------------|---------------------------------------------------------------------|
| Not extracting reasoning | `eps()` |
| `FORCED_OPEN` or `FORCED_CLOSED` | `reasoning(until(end)) + end` — opening tag was in the prompt |
| `TAG_BASED` or `TOOLS_ONLY` | `optional(start + reasoning(until(end)) + end)` |
| `DELIMITER` | `optional(reasoning(until(end)) + end)` — no start marker |
#### Content Parser (`analyze_content::build_parser`)
| Condition | Parser |
|----------------------------------------|---------------------------------------------------------------------------------|
| `json_schema` present | `reasoning + space() + content(schema(json(), "response-format", ...)) + end()` |
| Tools present | Dispatches to `analyze_tools::build_parser()` |
| `ALWAYS_WRAPPED` with reasoning | `reasoning + start + content(until(end)) + end + end()` |
| `ALWAYS_WRAPPED` without reasoning | `content(until(start)) + start + content(until(end)) + end + end()` |
| Default (PLAIN) | `reasoning + content(rest()) + end()` |
#### Tool Parsers (`analyze_tools::build_parser`)
Dispatches by `format.mode`:
**`build_tool_parser_json_native()`**: Calls `p.standard_json_tools()` which internally dispatches to:
- `build_json_tools_function_is_key()` — function name is the JSON key: `{"get_weather": {...}}`
- `build_json_tools_nested_keys()` — nested: `{"function": {"name": "X", "arguments": {...}}}`
- `build_json_tools_flat_keys()` — flat: `{"name": "X", "arguments": {...}}`
Handles content wrappers, array wrapping (`tools_array_wrapped`), parallel calls, and `parameter_order`.
**`build_tool_parser_tag_json()`**: For each tool function:
```text
tool_open(name_prefix + tool_name(literal(name)) + name_suffix) +
call_id_section +
tool_args(schema(json(), tool_schema))
[+ function.close if non-empty]
```
Wrapped in per-call markers (with optional parallel call repetition) then optionally in section markers.
**`build_tool_parser_tag_tagged()`**: For each tool function, builds one parser per argument:
- String types: `tool_arg_string_value(schema(until(value_suffix), ...))`
- JSON types: `tool_arg_json_value(schema(json(), ...))`
- Required args are plain; optional args wrapped in `optional()`
- Arguments joined with `space()` between consecutive parsers
For closing: uses `function.close` if present; otherwise uses `peek(per_call_end)` to avoid premature close during partial streaming; falls back to `tool_close(space())` to trigger mapper callbacks.
All three tool parsers return:
```text
reasoning + optional(content(until(trigger_marker))) + tool_calls + end()
```
### Python Dict Format
When `format.uses_python_dicts` is true (detected when single-quoted strings appear in JSON argument context), `build_parser()` pre-registers a `json-string` rule that accepts both single-quoted and double-quoted strings. This is done before any `p.json()` call so all JSON parsing inherits the flexible rule.
## Mapper
`common_chat_peg_mapper` maps PEG parse results (AST nodes) into `common_chat_msg` structures. Key design:
- **Buffered arguments**: Before `tool_name` is known, argument text goes to `args_buffer`; once the name is set, the buffer is flushed to `current_tool->arguments`
- **`args_target()`**: Returns a reference to whichever destination is currently active (buffer or tool args), eliminating branching
- **`closing_quote_pending`**: Tracks whether a closing `"` needs to be appended when a string argument value is finalized (for schema-declared string types in tagged format)
- **Quote normalization**: Python-style quotes (`'key': 'value'`) are converted to JSON (`"key": "value"`)
- **Brace auto-closing**: At tool close, unclosed `{` braces are closed automatically
## Files
| File | Purpose |
|-------------------------------------------|----------------------------------------------------------------------|
| `common/chat-auto-parser.h` | All analysis structs, enums, `autoparser`, `peg_generator`, `templates_params` |
| `common/chat-auto-parser-generator.cpp` | Parser generator: `generate_parser()` and `build_parser()` methods |
| `common/chat-diff-analyzer.cpp` | Differential analysis implementation and workarounds |
| `common/chat-auto-parser-helpers.h/cpp` | `calculate_diff_split()`, `segmentize_markers()`, |
| | `compare_variants()`, string helpers |
| `common/chat-peg-parser.h/cpp` | `common_chat_peg_builder`, `common_chat_peg_mapper`, and helpers |
| `common/chat.cpp` | Entry point: `common_chat_templates_apply_jinja()` |
| `tools/parser/debug-template-parser.cpp` | Debug tool for template analysis |
| `tools/parser/template-analysis.cpp` | Template analysis tool |
## Testing & Debugging
### Debug Tools
**Template Debugger**: `tools/parser/debug-template-parser.cpp`
- Usage: `./bin/llama-debug-template-parser path/to/template.jinja`
- Shows detected format, markers, generated parser, and GBNF grammar
**Template Analysis**: `tools/parser/template-analysis.cpp`
- Usage: `./bin/llama-template-analysis path/to/template.jinja`
**Debug Logging**: Enable with `LLAMA_LOG_VERBOSITY=2`
- Shows detailed analysis steps, pattern extraction results, and generated parser structure
**PEG Test Builder**: Fluent API for creating test cases — see [tests/test-chat.cpp:947-1043](tests/test-chat.cpp#L947-L1043). Example usage:
```cpp
auto tst = peg_tester("models/templates/Template.jinja");
tst.test("input text")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({tool_json})
.parallel_tool_calls(true)
.enable_thinking(true)
.expect(expected_message)
.run();
```
### Tested Templates
The following templates have active tests in `tests/test-chat.cpp`:
| Template | Format | Notes |
| -------- | ------ | ----- |
| Ministral-3-14B-Reasoning | Reasoning | `[THINK]...[/THINK]` tags (specialized handler) |
| NVIDIA-Nemotron-3-Nano-30B | TAG_WITH_TAGGED | Reasoning + tools |
| CohereForAI Command-R7B | JSON_NATIVE | `<\|START_THINKING\|>`/`<\|START_RESPONSE\|>` markers |
| Google Gemma 2 2B | Content only | No tool support |
| Qwen-QwQ-32B | Reasoning | Forced-open thinking |
| NousResearch Hermes 2 Pro | JSON_NATIVE | `<tool_call>` wrapper |
| IBM Granite 3.3 | JSON_NATIVE | `<think></think>` + `<response></response>` |
| ByteDance Seed-OSS | TAG_WITH_TAGGED | Custom `<seed:think>` and `<seed:tool_call>` tags |
| Qwen3-Coder | TAG_WITH_TAGGED | XML-style tool format |
| DeepSeek V3.1 | JSON_NATIVE | Forced thinking mode |
| GLM-4.6 | TAG_WITH_TAGGED | `<tool_call>name\n<arg_key>...<arg_value>...` format |
| GLM-4.7-Flash | TAG_WITH_TAGGED | Updated GLM format |
| Kimi-K2-Thinking | JSON_NATIVE | Reasoning + JSON tools |
| Apertus-8B-Instruct | JSON_NATIVE | Function name as JSON key |
| MiniMax-M2 | TAG_WITH_JSON | XML invoke with JSON args |
| NVIDIA-Nemotron-Nano-v2 | JSON_NATIVE | `<TOOLCALL>` wrapper (nested) |
| CohereForAI Command-R Plus | JSON_NATIVE | Markdown code block format |
| Mistral-Nemo-Instruct-2407 | JSON_NATIVE | `[TOOL_CALLS]` wrapper with ID field |
| Functionary v3.1 | TAG_WITH_JSON | `<function=X>` format |
| Functionary v3.2 | Specialized | `>>>` recipient delimiter (dedicated handler) |
| Fireworks Firefunction v2 | TAG_WITH_JSON | Fireworks tool format |
| DeepSeek R1 Distill (Llama/Qwen) | Reasoning | Forced-open thinking |
| llama-cpp-deepseek-r1 | Reasoning | Forced-open thinking |
| Kimi-K2 / Kimi-K2-Instruct | JSON_NATIVE | JSON tools with special markers |
| Llama 3.1/3.2/3.3 | JSON_NATIVE | Standard Llama tool format |
| OpenAI GPT-OSS | Specialized | Channel-based (dedicated handler) |
| Apriel 1.5 | JSON_NATIVE | `<tool_calls>` wrapper with JSON array |
| Apriel 1.6 Thinker | Reasoning | Implicit reasoning start |
| Mistral Small 3.2 | JSON_NATIVE | `[TOOL_CALLS]func[ARGS]{...}` with call ID |
| Devstral | JSON_NATIVE | `[TOOL_CALLS]func[ARGS]{...}` without call ID |
| StepFun 3.5 Flash | TAG_WITH_TAGGED | `<function=X><parameter=Y>` format |
## Adding Support for New Templates
To support a new template format:
1. **If it follows standard patterns** — The auto-parser should detect it automatically. Run `llama-debug-template-parser` to verify markers are correctly extracted.
2. **If differential analysis extracts incorrect markers** — Add a workaround lambda to the `workarounds` vector in `common/chat-diff-analyzer.cpp`. Inspect the template source for a unique identifying substring.
3. **If it needs fundamentally different handling** — Add a dedicated handler function in `chat.cpp` before the auto-parser block (as done for GPT-OSS, Functionary v3.2, and Ministral).
## Edge Cases and Quirks
1. **Forced Thinking**: When `enable_thinking=true` and the model prompt ends with an open reasoning tag (e.g., `<think>`), the parser enters forced thinking mode and immediately expects reasoning content without waiting for a start marker.
2. **Per-Call vs Per-Section Markers**: Some templates wrap each tool call individually (`per_call_start/end`); others wrap the entire section (`section_start/end`). T2 (`check_per_call_markers()`) disambiguates by checking if the second call in a two-call output starts with the section marker.
3. **Python Dict Format**: The Seed template family uses single-quoted JSON (`'key': 'value'`). The `uses_python_dicts` flag causes the PEG builder to register a flexible `json-string` rule accepting both quote styles before any JSON rules are built.
4. **Tag Boundary Fixing**: `calculate_diff_split()` iteratively adjusts prefix/suffix boundaries to avoid splitting `<tag>` or `[marker]` tokens, ensuring clean extraction.
5. **Call ID Side Effects**: When a call ID is detected, `per_call_end` may have been incorrectly set to include the call ID suffix. T7 clears `per_call_end` in this case.
6. **Tool Analysis Gating**: `analyze_tools` is only constructed (and all tool analysis phases run) when `jinja_caps.supports_tool_calls` is true. Within tool analysis, `check_per_call_markers()` (T2) only runs if `jinja_caps.supports_parallel_tool_calls`.
7. **`analyze_arguments()` Gating**: Within tool analysis, A1 and A2 (argument name/value marker extraction) only run for `TAG_WITH_TAGGED` format. `extract_argument_separator()` and `extract_args_markers()` run for all non-`JSON_NATIVE` formats.

View File

@ -9,6 +9,7 @@
- [Linux](#linux)
- [Windows](#windows)
- [Environment Variable](#environment-variable)
- [Design Rule](#design-rule)
- [Known Issues](#known-issues)
- [Q&A](#qa)
- [TODO](#todo)
@ -41,6 +42,9 @@ The following releases are verified and recommended:
## News
- 2026.03
- Support Flash-Attention: less memory usage, performance impact depends on LLM.
- 2026.02
- Remove support for Nvidia & AMD GPUs, because the oneAPI plugins for Nvidia & AMD GPUs are no longer available: their download/installation channels are out of service, so users cannot build the software for Nvidia & AMD GPUs.
@ -685,18 +689,45 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| Name | Value | Function |
|-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
| GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
| GGML_SYCL_ENABLE_FLASH_ATTN | 1 (default) or 0| Enable Flash-Attention. It can reduce memory usage. The performance impact depends on the LLM.|
| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimization features for Intel GPUs. (Recommended to set to 1 for Intel devices older than Gen 10) |
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through the SYCL Graphs feature. Disabled by default because SYCL Graph is still under development and does not yet improve performance. |
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
| UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Support malloc device memory more than 4GB.|
## Design Rule
- Open to all contributors.
- All code changes should be useful to users:
- Fix bugs.
- Add new functionality.
- Improve performance/usability.
- Make the code easier to maintain.
- ...
- Code in the following cases will not be accepted:
- It breaks legacy functionality.
- It reduces the performance of legacy cases by default.
- It is incomplete work, or its functionality cannot be demonstrated.
- Using environment variables to turn features on/off is encouraged.
- Users can evaluate a feature without rebuilding the code.
- Recommend the best features to users by enabling them by default.
- Design the code based on the published official releases of oneAPI packages: compiler, library, driver, OS kernel.
- Developers need to maintain the code they submit.
## Known Issues
- `Split-mode:[row]` is not supported.
- AOT (Ahead-of-Time) compilation is not used when building.
- Good: builds quickly, smaller binary size.
- Bad: startup is slow the first time (JIT), but subsequent performance is unaffected.
## Q&A
- Error: `error while loading shared libraries: libsycl.so: cannot open shared object file: No such file or directory`.

View File

@ -22,7 +22,7 @@ Below is a contrived example demonstrating how to use the PEG parser to parse
output from a model that emits arguments as JSON.
```cpp
auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
// Build a choice of all available tools
auto tool_choice = p.choice();
for (const auto & tool : tools) {
@ -212,7 +212,7 @@ mapper.from_ast(ctx.ast, result);
### Native
The `common_chat_peg_native_builder` builds a `native` parser suitable for
The `common_chat_peg_builder` builds a `native` parser suitable for
models that emit tool arguments as a direct JSON object.
- **`reasoning(p)`** - Tag node for `reasoning_content`
@ -225,7 +225,7 @@ models that emit tool arguments as a direct JSON object.
- **`tool_args(p)`** - Tag the tool arguments
```cpp
build_chat_peg_native_parser([&](common_chat_peg_native_parser & p) {
build_chat_peg_parser([&](common_chat_peg_builder & p) {
auto get_weather_tool = p.tool(p.sequence({
p.tool_open(p.literal("{")),
p.json_member("name", "\"" + p.tool_name(p.literal("get_weather")) + "\""),
@ -246,7 +246,7 @@ build_chat_peg_native_parser([&](common_chat_peg_native_parser & p) {
### Constructed
The `common_chat_peg_constructed_builder` builds a `constructed` parser
The `common_chat_peg_builder` builds a `constructed` parser
suitable for models that emit tool arguments as separate entities, such as XML
tags.
@ -264,7 +264,7 @@ tags.
- **`tool_arg_json_value(p)`** - Tag JSON value for the argument
```cpp
build_chat_peg_constructed_parser([&](common_chat_peg_constructed_builder & p) {
build_chat_peg_parser([&](common_chat_peg_builder & p) {
auto location_arg = p.tool_arg(
p.tool_arg_open("<parameter name=\"" + p.tool_arg_name(p.literal("location")) + "\">"),
p.tool_arg_string_value(p.until("</parameter>")),

View File

@ -37,16 +37,17 @@ Legend:
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| DIAG | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ❌ | ❌ | ❌ |
| DIAG | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ❌ | ❌ | ❌ |
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | | ✅ | ❌ | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | | 🟡 | 🟡 | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
@ -54,7 +55,7 @@ Legend:
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | | 🟡 | ❌ | ❌ |
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
@ -90,9 +91,9 @@ Legend:
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | | ❌ | ❌ | ❌ |
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | | ❌ | ❌ | ❌ |
| SET_ROWS | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | | ✅ | ❌ | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
@ -100,7 +101,7 @@ Legend:
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | | ❌ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
@ -116,5 +117,5 @@ Legend:
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | | ✅ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | | ❌ | ❌ | ❌ |
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | | ✅ | ❌ | ❌ |

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -689,6 +689,11 @@ class SchemaConverter:
elif (schema_type == 'object') or (len(schema) == 0):
return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))
elif schema_type is None and isinstance(schema, dict):
# No type constraint and no recognized structural keywords (e.g. {"description": "..."}).
# Per JSON Schema semantics this is equivalent to {} and accepts any value.
return self._add_rule(rule_name, self._add_primitive('value', PRIMITIVE_RULES['value']))
else:
assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
# TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero

View File

@ -556,6 +556,7 @@ extern "C" {
GGML_OP_GATED_LINEAR_ATTN,
GGML_OP_RWKV_WKV7,
GGML_OP_SOLVE_TRI,
GGML_OP_GATED_DELTA_NET,
GGML_OP_UNARY,
@ -2463,6 +2464,15 @@ extern "C" {
bool lower,
bool uni);
GGML_API struct ggml_tensor * ggml_gated_delta_net(
struct ggml_context * ctx,
struct ggml_tensor * q,
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * g,
struct ggml_tensor * beta,
struct ggml_tensor * state);
// custom operators
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);

View File

@ -2021,6 +2021,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_solve_tri(params, tensor);
} break;
case GGML_OP_GATED_DELTA_NET:
{
ggml_compute_forward_gated_delta_net(params, tensor);
} break;
case GGML_OP_MAP_CUSTOM1:
{
ggml_compute_forward_map_custom1(params, tensor);
@ -2200,6 +2204,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
} break;
case GGML_OP_COUNT_EQUAL:
case GGML_OP_SOLVE_TRI:
case GGML_OP_GATED_DELTA_NET:
{
n_tasks = n_threads;
} break;
@ -2905,6 +2910,11 @@ struct ggml_cplan ggml_graph_plan(
{
cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
} break;
case GGML_OP_GATED_DELTA_NET:
{
const int64_t S_v = node->src[2]->ne[0];
cur = S_v * sizeof(float) * n_tasks;
} break;
case GGML_OP_COUNT:
{
GGML_ABORT("fatal error");

View File

@ -10380,6 +10380,190 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s
}
}
// ggml_compute_forward_gated_delta_net
// Runs the gated delta-net recurrence on the CPU for rows [ir0, ir1) of the
// (head, sequence) grid; row `ir` maps to head `ir % H` and sequence `ir / H`.
//
// Inputs are read from dst->src[0..5] = q, k, v, g, beta, state. All sizes are
// taken from v, whose ne is read as [S_v, H, n_tokens, n_seqs]. g must have
// ne[0] == 1 (one gate value per head/token) or ne[0] == S_v (one gate value per
// S_v-float slice of the state; the `kda` path). beta must have ne[0] == 1.
//
// dst->data receives the per-token outputs followed by the updated states (see
// the "output layout" comment below). params->wdata provides S_v floats of
// scratch per thread.
static void ggml_compute_forward_gated_delta_net_one_chunk(
const ggml_compute_params * params,
ggml_tensor * dst,
int64_t ir0,
int64_t ir1) {
ggml_tensor * src_q = dst->src[0];
ggml_tensor * src_k = dst->src[1];
ggml_tensor * src_v = dst->src[2];
ggml_tensor * src_g = dst->src[3];
ggml_tensor * src_beta = dst->src[4];
ggml_tensor * src_state = dst->src[5];
// problem sizes come from v
const int64_t S_v = src_v->ne[0];
const int64_t H = src_v->ne[1];
const int64_t n_tokens = src_v->ne[2];
const int64_t n_seqs = src_v->ne[3];
// q/k/v are addressed through their byte strides, so only the rows need to be
// contiguous; g, beta and state must be fully contiguous
GGML_ASSERT(ggml_is_contiguous_rows(src_q));
GGML_ASSERT(ggml_is_contiguous_rows(src_k));
GGML_ASSERT(ggml_is_contiguous_rows(src_v));
GGML_ASSERT(ggml_is_contiguous(src_g));
GGML_ASSERT(ggml_is_contiguous(src_beta));
GGML_ASSERT(ggml_is_contiguous(src_state));
GGML_ASSERT(src_g->ne[0] == 1 || src_g->ne[0] == S_v);
GGML_ASSERT(src_beta->ne[0] == 1);
// the macro declares <prefix>0..<prefix>3 locals (e.g. neq1, nbq2) from each tensor's ne[] / nb[]
GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
GGML_TENSOR_LOCALS(size_t, nbk, src_k, nb);
GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
GGML_TENSOR_LOCALS(int64_t, neg, src_g, ne);
GGML_TENSOR_LOCALS(size_t, nbg, src_g, nb);
GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
// true when g carries S_v gate values per head/token instead of a single one
const bool kda = (neg0 == S_v);
// scratch layout per thread: [delta(S_v)]
const int64_t scratch_per_thread = S_v;
const int ith = params->ith;
// thread `ith` uses S_v floats starting at ith * S_v + CACHE_LINE_SIZE_F32.
// NOTE(review): the last thread's slice therefore ends CACHE_LINE_SIZE_F32 floats
// past S_v * nth — confirm the work buffer is padded by at least that much.
float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;
// output layout: [attn_scores | new_states]
// attn_scores: S_v * H * n_tokens * n_seqs floats
// new_states: S_v * S_v * H * n_seqs floats
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
float * attn_out_base = (float *)dst->data;
float * state_out_base = (float *)dst->data + attn_score_elems;
const float * state_in_base = (const float *)src_state->data;
// broadcast ratios: how many v heads (dim 1) / v sequences (dim 3) map onto one
// q (resp. k) head / sequence
const int64_t rq1 = nev1 / neq1;
const int64_t rk1 = nev1 / nek1;
const int64_t rq3 = nev3 / neq3;
const int64_t rk3 = nev3 / nek3;
// every output vector is multiplied by 1/sqrt(S_v)
const float scale = 1.0f / sqrtf((float) S_v);
for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t iv1 = ir % H; // head_index
const int64_t iv3 = ir / H; // sequence
// matching q/k head and sequence indices after broadcasting
const int64_t iq1 = iv1 / rq1;
const int64_t ik1 = iv1 / rk1;
const int64_t iq3 = iv3 / rq3;
const int64_t ik3 = iv3 / rk3;
// S_v x S_v state block of this (head, seq) inside the output buffer
float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v;
// copy input state into output buffer and operate in-place
const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v;
memcpy(s_out, s_in, S_v * S_v * sizeof(float));
// attn output pointer for first token of this (head, seq)
float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v;
// tokens are processed strictly in order: each step reads and updates the state in s_out
for (int64_t t = 0; t < n_tokens; t++) {
// per-token input pointers, computed from byte strides
const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1);
const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1);
const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);
const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
// gating: scale the state by exp(g) before the update — per S_v-float slice
// (s_out[i * S_v ...] by exp(g[i])) on the kda path, the whole block by exp(g[0]) otherwise
if (kda) {
for (int64_t i = 0; i < S_v; ++i) {
ggml_vec_scale_f32(S_v, &s_out[i * S_v], expf(g_d[i]));
}
} else {
ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0]));
}
// delta[j] = sum_i s_out[i * S_v + j] * k[i]
memset(delta, 0, S_v * sizeof(float));
for (int64_t i = 0; i < S_v; ++i) {
ggml_vec_mad_f32(S_v, delta, &s_out[i * S_v], k_d[i]);
}
// delta[j] = (v[j] - delta[j]) * beta: the correction applied to the state below
for (int64_t j = 0; j < S_v; ++j) {
delta[j] = (v_d[j] - delta[j]) * beta_val;
}
// outer product: s_out[i * S_v + j] += k[i] * delta[j]
for (int64_t i = 0; i < S_v; ++i) {
ggml_vec_mad_f32(S_v, &s_out[i * S_v], delta, k_d[i]);
}
// attn_out[j] = scale * sum_i s_out[i * S_v + j] * q[i], using the updated state
memset(attn_data, 0, S_v * sizeof(float));
for (int64_t i = 0; i < S_v; ++i) {
ggml_vec_mad_f32(S_v, attn_data, &s_out[i * S_v], q_d[i]);
}
ggml_vec_scale_f32(S_v, attn_data, scale);
attn_data += S_v * H; // advance to next token
}
}
}
static void ggml_compute_forward_gated_delta_net_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
ggml_tensor * V = dst->src[2];
int64_t nr = V->ne[1] * V->ne[3];
// disable for NUMA
const bool disable_chunking = ggml_is_numa();
int nth = params->nth;
int ith = params->ith;
// 4x chunks per thread
int nth_scaled = nth * 4;
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
if (nth == 1 || nchunk < nth || disable_chunking) {
nchunk = nth;
}
if (ith == 0) {
ggml_threadpool_chunk_set(params->threadpool, nth);
}
ggml_barrier(params->threadpool);
const int64_t dr = (nr + nchunk - 1) / nchunk;
int current_chunk = ith;
while (current_chunk < nchunk) {
const int64_t ir0 = dr * current_chunk;
const int64_t ir1 = MIN(ir0 + dr, nr);
ggml_compute_forward_gated_delta_net_one_chunk(params, dst, ir0, ir1);
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
}
}
void ggml_compute_forward_gated_delta_net(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_gated_delta_net_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_rwkv_wkv7
static void ggml_compute_forward_rwkv_wkv7_f32(

View File

@ -102,6 +102,7 @@ void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, s
void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_gated_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);

View File

@ -0,0 +1,223 @@
#include "gated_delta_net.cuh"
#include "ggml-cuda/common.cuh"
// CUDA kernel for the gated delta-net recurrence.
//
// Launch geometry as read from the indexing below: blockIdx.x selects the head,
// blockIdx.y selects the sequence, and threadIdx.x selects one column `col` of the
// S_v x S_v state of that (head, sequence). The kernel indexes col in [0, S_v), so
// it expects S_v threads per block.
//
// Template parameters:
//   S_v - state dimension (compile-time so the per-thread loops can be unrolled)
//   KDA - true when g holds S_v gate values per head/token, false for a single one
//
// Strides (sq*, sv*, sb*) are in floats, not bytes. The q strides and the q
// broadcast ratios (rq1, rq3) are also used to address k. The beta strides are
// also used to address g (multiplied by S_v when KDA).
//
// dst holds S_v * H * n_tokens * n_seqs output floats followed by the updated states.
template <int S_v, bool KDA>
__global__ void gated_delta_net_cuda(const float * q,
const float * k,
const float * v,
const float * g,
const float * beta,
const float * curr_state,
float * dst,
int64_t H,
int64_t n_tokens,
int64_t n_seqs,
int64_t sq1,
int64_t sq2,
int64_t sq3,
int64_t sv1,
int64_t sv2,
int64_t sv3,
int64_t sb1,
int64_t sb2,
int64_t sb3,
int64_t rq1,
int64_t rq3,
float scale) {
const int64_t h_idx = blockIdx.x;
const int64_t sequence = blockIdx.y;
const int col = threadIdx.x; // each thread owns one column
// q (and k) head / sequence index after broadcasting
const int64_t iq1 = h_idx / rq1;
const int64_t iq3 = sequence / rq3;
// dst = [attn outputs | new states]; the states start after attn_score_elems floats
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
float * attn_data = dst;
float * state = dst + attn_score_elems;
// move all three pointers to this block's (head, sequence)
const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
state += state_offset;
curr_state += state_offset;
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
// Load state column into registers
// (s is a per-thread local array: s[i] = curr_state[i * S_v + col])
float s[S_v];
#pragma unroll
for (int i = 0; i < S_v; i++) {
s[i] = curr_state[i * S_v + col];
}
// tokens are processed in order; s carries the state from one token to the next
for (int t = 0; t < n_tokens; t++) {
const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
// k is addressed with q's strides and indices
const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1;
const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1;
// beta and g share one offset; with KDA, g has S_v values per beta value
const int64_t gb_offset = sequence * sb3 + t * sb2 + h_idx * sb1;
const float * beta_t = beta + gb_offset;
const float * g_t = g + gb_offset * (KDA ? S_v : 1);
const float beta_val = *beta_t;
if constexpr (!KDA) {
// single gate for the whole state of this head/token
const float g_val = expf(*g_t);
// kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
float kv_col = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
kv_col += s[i] * k_t[i];
}
// delta[col] = (v[col] - g * kv[col]) * beta
// (the gate is applied to kv here because s has not been scaled yet)
float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
// fused: S[i][col] = g * S[i][col] + k[i] * delta[col]
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
float attn_col = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
s[i] = g_val * s[i] + k_t[i] * delta_col;
attn_col += s[i] * q_t[i];
}
attn_data[col] = attn_col * scale;
} else {
// one gate value per state row i: exp(g[i])
// kv[col] = sum_i g[i] * S[i][col] * k[i]
float kv_col = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
kv_col += expf(g_t[i]) * s[i] * k_t[i];
}
// delta[col] = (v[col] - kv[col]) * beta
float delta_col = (v_t[col] - kv_col) * beta_val;
// fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col]
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
float attn_col = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
s[i] = expf(g_t[i]) * s[i] + k_t[i] * delta_col;
attn_col += s[i] * q_t[i];
}
attn_data[col] = attn_col * scale;
}
// advance to the output slot of the next token for the same head
attn_data += S_v * H;
}
// Write state back to global memory
#pragma unroll
for (int i = 0; i < S_v; i++) {
state[i * S_v + col] = s[i];
}
}
template <bool KDA>
static void launch_gated_delta_net(
const float * q_d, const float * k_d, const float * v_d,
const float * g_d, const float * b_d, const float * s_d,
float * dst_d,
int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
int64_t sq1, int64_t sq2, int64_t sq3,
int64_t sv1, int64_t sv2, int64_t sv3,
int64_t sb1, int64_t sb2, int64_t sb3,
int64_t rq1, int64_t rq3,
float scale, cudaStream_t stream) {
dim3 grid_dims(H, n_seqs, 1);
dim3 block_dims(S_v, 1, 1);
switch (S_v) {
case 32:
gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
break;
case 64:
gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
break;
case 128:
gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
break;
default:
GGML_ABORT("fatal error");
break;
}
}
// CUDA implementation of GGML_OP_GATED_DELTA_NET.
//
// Reads q, k, v, g, beta and state from dst->src[0..5], validates the layout
// assumptions the kernel relies on, converts byte strides to float strides and
// launches the kernel on the context's stream. Problem sizes are taken from v
// (ne = [S_v, H, n_tokens, n_seqs]). The KDA kernel variant is selected when
// g->ne[0] == S_v (one gate value per state row instead of a single one).
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_tensor * src_q     = dst->src[0];
    ggml_tensor * src_k     = dst->src[1];
    ggml_tensor * src_v     = dst->src[2];
    ggml_tensor * src_g     = dst->src[3];
    ggml_tensor * src_beta  = dst->src[4];
    ggml_tensor * src_state = dst->src[5];

    GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
    GGML_TENSOR_LOCALS(size_t,  nbq, src_q, nb);
    GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
    GGML_TENSOR_LOCALS(size_t,  nbv, src_v, nb);
    GGML_TENSOR_LOCALS(size_t,  nbb, src_beta, nb);

    const int64_t S_v      = nev0;
    const int64_t H        = nev1;
    const int64_t n_tokens = nev2;
    const int64_t n_seqs   = nev3;

    const bool kda = (src_g->ne[0] == S_v);

    // Validate before any of the values below are derived from the tensors.

    // every buffer is handed to the kernel as float *
    GGML_ASSERT(src_q->type     == GGML_TYPE_F32);
    GGML_ASSERT(src_k->type     == GGML_TYPE_F32);
    GGML_ASSERT(src_v->type     == GGML_TYPE_F32);
    GGML_ASSERT(src_g->type     == GGML_TYPE_F32);
    GGML_ASSERT(src_beta->type  == GGML_TYPE_F32);
    GGML_ASSERT(src_state->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type       == GGML_TYPE_F32);

    GGML_ASSERT(ggml_is_contiguous_rows(src_q));
    GGML_ASSERT(ggml_is_contiguous_rows(src_k));
    GGML_ASSERT(ggml_is_contiguous_rows(src_v));

    // the kernel addresses k with q's strides AND q's broadcast ratios, so q and k
    // must agree in shape as well as in stride
    GGML_ASSERT(ggml_are_same_stride(src_q, src_k));
    GGML_ASSERT(ggml_are_same_shape(src_q, src_k));

    GGML_ASSERT(src_g->ne[0] == 1 || kda);
    GGML_ASSERT(ggml_is_contiguous(src_g));
    GGML_ASSERT(ggml_is_contiguous(src_beta));
    GGML_ASSERT(ggml_is_contiguous(src_state));

    // the kernel derives g's offset from beta's strides (times S_v when kda); that is
    // only valid if beta has a single value per head/token and g matches it in the
    // remaining dimensions
    GGML_ASSERT(src_beta->ne[0] == 1);
    GGML_ASSERT(src_g->ne[1] == src_beta->ne[1]);
    GGML_ASSERT(src_g->ne[2] == src_beta->ne[2]);
    GGML_ASSERT(src_g->ne[3] == src_beta->ne[3]);

    // broadcast ratios: number of v heads / sequences per q head / sequence
    GGML_ASSERT(neq1 > 0 && neq3 > 0);
    const int64_t rq1 = nev1 / neq1;
    const int64_t rq3 = nev3 / neq3;

    const float * q_d = (const float *) src_q->data;
    const float * k_d = (const float *) src_k->data;
    const float * v_d = (const float *) src_v->data;
    const float * g_d = (const float *) src_g->data;
    const float * b_d = (const float *) src_beta->data;
    const float * s_d = (const float *) src_state->data;

    float * dst_d = (float *) dst->data;

    // strides in floats (beta strides used for both g and beta offset computation)
    const int64_t sq1 = nbq1 / sizeof(float);
    const int64_t sq2 = nbq2 / sizeof(float);
    const int64_t sq3 = nbq3 / sizeof(float);
    const int64_t sv1 = nbv1 / sizeof(float);
    const int64_t sv2 = nbv2 / sizeof(float);
    const int64_t sv3 = nbv3 / sizeof(float);
    const int64_t sb1 = nbb1 / sizeof(float);
    const int64_t sb2 = nbb2 / sizeof(float);
    const int64_t sb3 = nbb3 / sizeof(float);

    // every output vector is multiplied by 1/sqrt(S_v)
    const float scale = 1.0f / sqrtf((float) S_v);

    cudaStream_t stream = ctx.stream();

    if (kda) {
        launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
            S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
            sb1, sb2, sb3, rq1, rq3, scale, stream);
    } else {
        launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
            S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
            sb1, sb2, sb3, rq1, rq3, scale, stream);
    }
}

View File

@ -0,0 +1,4 @@
#include "common.cuh"
#include "ggml.h"
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -53,6 +53,7 @@
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/wkv.cuh"
#include "ggml-cuda/gla.cuh"
#include "ggml-cuda/gated_delta_net.cuh"
#include "ggml-cuda/set.cuh"
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/pad_reflect_1d.cuh"
@ -204,7 +205,14 @@ static ggml_cuda_device_info ggml_cuda_init() {
GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
int64_t total_vram = 0;
GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
for (int id = 0; id < info.device_count; ++id) {
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
total_vram += prop.totalGlobalMem;
}
GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices (Total VRAM: %zu MiB):\n",
__func__, info.device_count, (size_t)(total_vram / (1024 * 1024)));
total_vram = 0;
std::vector<std::pair<int, std::string>> turing_devices_without_mma;
for (int id = 0; id < info.device_count; ++id) {
@ -242,6 +250,12 @@ static ggml_cuda_device_info ggml_cuda_init() {
#else
info.devices[id].supports_cooperative_launch = false;
#endif // !(GGML_USE_MUSA)
// cudaMemGetInfo returns info for the current device
size_t free_mem;
CUDA_CHECK(cudaSetDevice(id));
CUDA_CHECK(cudaMemGetInfo(&free_mem, NULL));
#if defined(GGML_USE_HIP)
info.devices[id].smpbo = prop.sharedMemPerBlock;
@ -256,22 +270,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.devices[id].cc += prop.minor * 0x10;
}
}
GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n",
GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d, VRAM: %zu MiB (%zu MiB free)\n",
id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
device_vmm ? "yes" : "no", prop.warpSize);
device_vmm ? "yes" : "no", prop.warpSize,
(size_t)(prop.totalGlobalMem / (1024 * 1024)), free_mem / (1024 * 1024));
#elif defined(GGML_USE_MUSA)
// FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
info.devices[id].warp_size = 32;
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;
info.devices[id].cc += prop.minor * 0x10;
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB (%zu MiB free)\n",
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
(size_t)(prop.totalGlobalMem / (1024 * 1024)), free_mem / (1024 * 1024));
#else
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
info.devices[id].cc = 100*prop.major + 10*prop.minor;
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB (%zu MiB free)\n",
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
(size_t)(prop.totalGlobalMem / (1024 * 1024)), free_mem / (1024 * 1024));
std::string device_name(prop.name);
if (device_name == "NVIDIA GeForce MX450") {
turing_devices_without_mma.push_back({ id, device_name });
@ -2733,6 +2750,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_GATED_LINEAR_ATTN:
ggml_cuda_op_gated_linear_attn(ctx, dst);
break;
case GGML_OP_GATED_DELTA_NET:
ggml_cuda_op_gated_delta_net(ctx, dst);
break;
case GGML_OP_RWKV_WKV7:
ggml_cuda_op_rwkv_wkv7(ctx, dst);
break;
@ -4974,6 +4994,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_RWKV_WKV7:
return true;
case GGML_OP_GATED_DELTA_NET:
//TODO: enable once MUSA compiler is solved https://github.com/ggml-org/llama.cpp/pull/19504#issuecomment-4018634327
#ifdef GGML_USE_MUSA
return false;
#else
return true;
#endif // GGML_USE_MUSA
case GGML_OP_FLASH_ATTN_EXT:
return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
case GGML_OP_CROSS_ENTROPY_LOSS:

View File

@ -2152,6 +2152,44 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
return true;
}
// Reports whether an SSM_CONV node can be offloaded to the Hexagon backend.
//
// The node is accepted only when all three tensors are F32 and contiguous and
// their shapes line up as:
//   src0: [d_conv - 1 + n_t, d_inner, n_s, 1]
//   src1: [d_conv,           d_inner, 1,   1]
//   dst : [d_inner,          n_t,     n_s, 1]
//
// `sess` is not consulted by this check.
static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
    const struct ggml_tensor * conv_in  = op->src[0];
    const struct ggml_tensor * conv_w   = op->src[1];
    const struct ggml_tensor * conv_out = op;

    // FP32 is the only element type handled for now
    const bool all_f32 = conv_in->type  == GGML_TYPE_F32 &&
                         conv_w->type   == GGML_TYPE_F32 &&
                         conv_out->type == GGML_TYPE_F32;
    if (!all_f32) {
        return false;
    }

    // src0/dst must be effectively 3D and src1 effectively 2D
    const bool trailing_dims_ok = conv_in->ne[3] == 1 && conv_w->ne[2] == 1 && conv_w->ne[3] == 1 && conv_out->ne[3] == 1;
    if (!trailing_dims_ok) {
        return false;
    }

    // Sizes are held as int and then compared back against the full-width ne[] values
    const int kernel_width = conv_w->ne[0];
    const int n_channels   = conv_in->ne[1];
    const int n_tokens     = conv_out->ne[1];
    const int n_seqs       = conv_out->ne[2];

    const bool in_shape_ok  = conv_in->ne[0] == kernel_width - 1 + n_tokens &&
                              conv_in->ne[1] == n_channels &&
                              conv_in->ne[2] == n_seqs;
    const bool w_shape_ok   = conv_w->ne[0] == kernel_width && conv_w->ne[1] == n_channels;
    const bool out_shape_ok = conv_out->ne[0] == n_channels &&
                              conv_out->ne[1] == n_tokens &&
                              conv_out->ne[2] == n_seqs;
    if (!(in_shape_ok && w_shape_ok && out_shape_ok)) {
        return false;
    }

    // TODO: add support for non-contiguous tensors
    return ggml_is_contiguous(conv_in) && ggml_is_contiguous(conv_w) && ggml_is_contiguous(conv_out);
}
enum dspqbuf_type {
DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0,
DSPQBUF_TYPE_CPU_WRITE_DSP_READ,
@ -2468,6 +2506,17 @@ static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buf
return n_bufs;
}
// Fills an HTP request for SSM_CONV and registers its buffers in packet order:
// src0, src1, dst. Returns the number of dspqueue buffers that were initialized.
static inline size_t init_ssm_conv_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
    req->op = HTP_OP_SSM_CONV;

    // Each call reports how many entries it consumed; the running total is the next free slot
    size_t used = htp_req_buff_init(&req->src0, &bufs[0], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
    used += htp_req_buff_init(&req->src1, &bufs[used], t->src[1], DSPQBUF_TYPE_CONSTANT);
    used += htp_req_buff_init(&req->dst, &bufs[used], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);

    return used;
}
static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
auto sess = static_cast<ggml_hexagon_session *>(backend->context);
return sess->name.c_str();
@ -2606,6 +2655,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
break;
case GGML_OP_SSM_CONV:
ggml_hexagon_dispatch_op<init_ssm_conv_req>(sess, node, flags);
break;
default:
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
}
@ -3024,6 +3077,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
supp = ggml_hexagon_supported_argsort(sess, op);
break;
case GGML_OP_SSM_CONV:
supp = ggml_hexagon_supported_ssm_conv(sess, op);
break;
default:
break;
}

View File

@ -31,6 +31,7 @@ add_library(${HTP_LIB} SHARED
get-rows-ops.c
cpy-ops.c
argsort-ops.c
ssm-conv.c
)
target_compile_definitions(${HTP_LIB} PRIVATE

View File

@ -68,6 +68,7 @@ enum htp_op {
HTP_OP_SQR,
HTP_OP_SQRT,
HTP_OP_SUM_ROWS,
HTP_OP_SSM_CONV,
INVALID
};

View File

@ -41,9 +41,6 @@ struct htp_ops_context {
worker_pool_context_t * wpool; // worker pool
uint32_t n_threads; // num threads
uint32_t src0_nrows_per_thread;
uint32_t src1_nrows_per_thread;
uint32_t flags;
};
@ -61,5 +58,6 @@ int op_set_rows(struct htp_ops_context * octx);
int op_get_rows(struct htp_ops_context * octx);
int op_cpy(struct htp_ops_context * octx);
int op_argsort(struct htp_ops_context * octx);
int op_ssm_conv(struct htp_ops_context * octx);
#endif /* HTP_OPS_H */

View File

@ -15,4 +15,12 @@
#include "hvx-div.h"
#include "hvx-base.h"
// GATHER_TYPE(_a): casts an address expression to the type used for the base-address
// argument of the HVX gather intrinsic calls (see Q6_vgather_ARMVw in ssm-conv.c):
// intptr_t when building for __hexagon__, HVX_Vector * otherwise.
// The argument is parenthesized so that a compound expression such as `base + byte_offset`
// is evaluated as a whole before the cast is applied.
#ifndef GATHER_TYPE
#    if defined(__hexagon__)
#        define GATHER_TYPE(_a) (intptr_t) (_a)
#    else
#        define GATHER_TYPE(_a) (HVX_Vector *) (_a)
#    endif
#endif
#endif /* HVX_UTILS_H */

View File

@ -757,6 +757,47 @@ static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req *
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
// Services one SSM_CONV request: builds the op context from the request and its
// three packet buffers (src0, src1, dst), runs the op while holding VTCM, and
// sends a response that carries the dst buffer together with profiling data.
static void proc_ssm_conv_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
    struct dspqueue_buffer * src0_buf = &bufs[0];
    struct dspqueue_buffer * src1_buf = &bufs[1];
    struct dspqueue_buffer * dst_buf  = &bufs[2];

    // The output buffer was written on this side, so the response must flush it here
    // and invalidate it on the receiving side
    struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
    rsp_bufs[0].fd     = dst_buf->fd;
    rsp_bufs[0].ptr    = dst_buf->ptr;
    rsp_bufs[0].offset = dst_buf->offset;
    rsp_bufs[0].size   = dst_buf->size;
    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP
                          DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU

    // Op context: tensor descriptors are copied from the request ...
    struct htp_ops_context octx = { 0 };
    octx.ctx       = ctx;
    octx.op        = req->op;
    octx.flags     = req->flags;
    octx.n_threads = ctx->n_threads;
    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));

    octx.src0 = req->src0;
    octx.src1 = req->src1;
    octx.dst  = req->dst;

    // ... and their data pointers are replaced with the addresses of the packet buffers
    octx.src0.data = (uint32_t) src0_buf->ptr;
    octx.src1.data = (uint32_t) src1_buf->ptr;
    octx.dst.data  = (uint32_t) dst_buf->ptr;

    struct profile_data prof;
    profile_start(&prof);

    // Stays an internal error unless VTCM could be acquired and the op reported its own status
    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
        rsp_status = op_ssm_conv(&octx);
        vtcm_release(ctx);
    }

    profile_stop(&prof);
    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_activations_req(struct htp_context * ctx,
struct htp_general_req * req,
struct dspqueue_buffer * bufs,
@ -1142,6 +1183,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
proc_argsort_req(ctx, &req, bufs);
break;
case HTP_OP_SSM_CONV:
if (n_bufs != 3) {
FARF(ERROR, "Bad ssm-conv-req buffer list");
continue;
}
proc_ssm_conv_req(ctx, &req, bufs);
break;
default:
FARF(ERROR, "Unknown Op %u", req.op);
break;

View File

@ -0,0 +1,339 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#include <HAP_farf.h>
#include <HAP_mem.h>
#include <HAP_perf.h>
#include <HAP_ps.h>
#include <hexagon_protos.h>
#include <hexagon_types.h>
#include <math.h>
#include <qurt_thread.h>
#include <string.h>
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "hex-dma.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
// Declares the locals shared by the SSM_CONV entry point and worker functions.
// Requires a `struct htp_ops_context * octx` in scope and introduces:
//   src0 / src1 / dst                 - pointers to the op's tensors
//   src0_spad / src1_spad / dst_spad  - pointers to the op's scratchpad descriptors
//   ne0X / ne1X / neX                 - ne[] of src0 / src1 / dst
//   nb0X / nb1X / nbX                 - nb[] of src0 / src1 / dst
// Not every user reads every name; the unused-variable diagnostics are disabled at the top of this file.
#define htp_ssm_conv_tensors_preamble \
struct htp_tensor * restrict src0 = &octx->src0; \
struct htp_tensor * restrict src1 = &octx->src1; \
struct htp_tensor * restrict dst = &octx->dst; \
struct htp_spad * restrict src0_spad = &octx->src0_spad; \
struct htp_spad * restrict src1_spad = &octx->src1_spad; \
struct htp_spad * restrict dst_spad = &octx->dst_spad; \
\
const uint32_t ne00 = src0->ne[0]; \
const uint32_t ne01 = src0->ne[1]; \
const uint32_t ne02 = src0->ne[2]; \
const uint32_t ne03 = src0->ne[3]; \
\
const uint32_t ne10 = src1->ne[0]; \
const uint32_t ne11 = src1->ne[1]; \
const uint32_t ne12 = src1->ne[2]; \
const uint32_t ne13 = src1->ne[3]; \
\
const uint32_t ne0 = dst->ne[0]; \
const uint32_t ne1 = dst->ne[1]; \
const uint32_t ne2 = dst->ne[2]; \
const uint32_t ne3 = dst->ne[3]; \
\
const uint32_t nb00 = src0->nb[0]; \
const uint32_t nb01 = src0->nb[1]; \
const uint32_t nb02 = src0->nb[2]; \
const uint32_t nb03 = src0->nb[3]; \
\
const uint32_t nb10 = src1->nb[0]; \
const uint32_t nb11 = src1->nb[1]; \
const uint32_t nb12 = src1->nb[2]; \
const uint32_t nb13 = src1->nb[3]; \
\
const uint32_t nb0 = dst->nb[0]; \
const uint32_t nb1 = dst->nb[1]; \
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3];
// State handed to every SSM_CONV worker thread through the `data` argument.
struct htp_ssm_conv_context {
struct htp_ops_context * octx; // op context (tensors, scratchpads, flags) of the request being served
uint32_t nrows_per_thread; // number of src0->ne[1] rows assigned to each worker; worker `ith` starts at nrows_per_thread * ith
uint64_t t_start; // NOTE(review): not read or written by the SSM_CONV code in this file section
};
// Worker-function prologue. Requires `void * data` (a struct htp_ssm_conv_context *) and the
// worker index `ith` in scope. Introduces `scctx`, `octx`, everything declared by
// htp_ssm_conv_tensors_preamble, and `dma_queue` (the entry of octx->ctx->dma[] for this worker).
#define htp_ssm_conv_preamble \
struct htp_ssm_conv_context * scctx = (struct htp_ssm_conv_context *) data; \
struct htp_ops_context * octx = scctx->octx; \
htp_ssm_conv_tensors_preamble; \
dma_queue * dma_queue = octx->ctx->dma[ith];
// Scalar FP32 SSM_CONV worker.
//
// Worker `ith` handles the src0->ne[1] rows in [rows_per_worker * ith, rows_per_worker * (ith + 1)),
// clamped to the row count. For every (sequence, token, row) it writes the dot product of
// src1->ne[0] consecutive src0 values, starting at the token index, with that row's src1 weights.
static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
    htp_ssm_conv_preamble;

    const uint64_t t_begin = HAP_perf_get_qtimer_count();

    const uint32_t kernel_w = src1->ne[0];
    const uint32_t n_rows   = src0->ne[1];
    const uint32_t n_tokens = dst->ne[1];
    const uint32_t n_seqs   = dst->ne[2];

    // All strides below are expressed in floats, not bytes
    const uint32_t in_row_stride  = src0->nb[1] / sizeof(float);
    const uint32_t in_seq_stride  = src0->nb[2] / sizeof(float);
    const uint32_t w_row_stride   = src1->nb[1] / sizeof(float);
    const uint32_t out_tok_stride = dst->nb[1] / sizeof(float);
    const uint32_t out_seq_stride = dst->nb[2] / sizeof(float);

    const float * in_f32  = (const float *) src0->data;
    const float * w_f32   = (const float *) src1->data;
    float *       out_f32 = (float *) dst->data;

    // Slice of rows owned by this worker
    const uint32_t rows_per_worker = scctx->nrows_per_thread;
    const uint32_t row_begin       = rows_per_worker * ith;
    const uint32_t row_end         = MIN(row_begin + rows_per_worker, n_rows);

    if (row_begin >= row_end) {
        // Nothing assigned to this worker
        return;
    }

    for (uint32_t seq = 0; seq < n_seqs; ++seq) {
        for (uint32_t tok = 0; tok < n_tokens; ++tok) {
            for (uint32_t row = row_begin; row < row_end; ++row) {
                const uint32_t in_base = row * in_row_stride + seq * in_seq_stride;
                const uint32_t w_base  = row * w_row_stride;

                float acc = 0.0f;
                for (uint32_t k = 0; k < kernel_w; ++k) {
                    const uint32_t in_idx = (tok + k) + in_base;
                    const uint32_t w_idx  = k + w_base;
                    acc += in_f32[in_idx] * w_f32[w_idx];
                }

                const uint32_t out_idx = row + tok * out_tok_stride + seq * out_seq_stride;
                out_f32[out_idx] = acc;
            }
        }
    }

    const uint64_t t_end = HAP_perf_get_qtimer_count();

    FARF(HIGH, "ssm-conv-f32 %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n",
         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], row_begin, row_end,
         src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1],
         dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t_end - t_begin));
}
// HVX FP32 SSM_CONV implementation - vectorizes across d_inner dimension
//
// Worker `ith` handles rows [ir0, ir1) of src0->ne[1]. Its src1 rows are copied to the
// per-thread src1 scratchpad once; its src0 rows are copied to the per-thread src0
// scratchpad once per sequence. For each token, VLEN_FP32 rows are processed together:
// for every tap i0, one value per row is gathered from each scratchpad, multiplied and
// accumulated, and the accumulated vector is stored to dst.
//
// Scratchpad placement (octx->src0_spad / octx->src1_spad and the gather slots at
// octx->ctx->vtcm_base) is set up by op_ssm_conv_f32 before the workers are started.
static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) {
htp_ssm_conv_preamble;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
const int nc = src1->ne[0]; // d_conv
const int ncs = src0->ne[0]; // d_conv - 1 + n_t
const uint32_t d_conv = src1->ne[0];
const uint32_t d_inner = src0->ne[1];
const uint32_t n_t = dst->ne[1];
const uint32_t n_s = dst->ne[2];
const float * src0_data = (const float *) src0->data;
const float * src1_data = (const float *) src1->data;
float * dst_data = (float *) dst->data;
// Calculate row range for this thread
const int dr = scctx->nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = MIN(ir0 + dr, d_inner);
const int ir = ir1 - ir0; // number of rows owned by this thread (only used after the ir0 >= ir1 check below)
if (ir0 >= ir1) {
return; // No work for this thread
}
// src0 and src1 gather offsets
// Entry i is the byte offset of row i relative to the gather base: rows are ncs floats apart
// in the src0 scratchpad and d_conv floats apart in the src1 scratchpad.
uint32_t __attribute__((aligned(VLEN))) src0_offsets[VLEN_FP32] = { 0 };
uint32_t __attribute__((aligned(VLEN))) src1_offsets[VLEN_FP32] = { 0 };
for (uint32_t i = 0; i < VLEN_FP32; ++i) {
src0_offsets[i] = i * (ncs) * sizeof(float);
src1_offsets[i] = i * (d_conv) * sizeof(float);
}
// Length argument passed to the gather calls below
// NOTE(review): presumably the size of the region the gather may address - confirm against the intrinsic docs
const uint32_t src0_gather_len = VLEN * ncs;
const uint32_t src1_gather_len = VLEN * d_conv;
// gather scratchpads
// Two VLEN-byte slots per thread at the start of VTCM: the first receives gathered src0 values, the second src1 values
HVX_Vector * src0_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + 0);
HVX_Vector * src1_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + VLEN);
// First src0 / src1 row owned by this thread
float * data_src0 = (float *) ((char *) src0->data + ir0 * src0->nb[1]);
float * data_src1 = (float *) ((char *) src1->data + ir0 * src1->nb[1]);
// Per-thread slices of the src0 / src1 scratchpads
uint8_t * spad_src0 = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread;
uint8_t * spad_src1 = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread;
// copy src1 workload to VTCM
// Queued here and completed by the dma_queue_flush call in the first iteration of the loop below
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src1, data_src1), nb11, nb11, ir);
// FARF(HIGH, "ssm-conv-src1-fetch %d: ir0 %u size %u\n", ith, ir0, nb11 * ir);
for (uint32_t i3 = 0; i3 < n_s; ++i3) {
float * src0_data_ptr = (float *) ((char *) data_src0 + i3 * (src0->nb[2]));
// copy src0 workload to VTCM
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0, src0_data_ptr), nb01, nb01, ir);
// FARF(HIGH, "ssm-conv-src0-fetch %d: ir0 %u i3 %u size %u\n", ith, ir0, i3, nb01 * ir);
// NOTE(review): presumably blocks until the queued transfers have landed in VTCM - confirm in hex-dma.h
dma_queue_flush(dma_queue);
for (uint32_t i2 = 0; i2 < n_t; ++i2) {
// Output position of row ir0 for token i2 of sequence i3
float * dst_ptr = (float *) ((char *) dst->data + ir0 * (dst->nb[0]) + i2 * (dst->nb[1]) + i3 * (dst->nb[2]));
const uint32_t nvec = ir / VLEN_FP32; // number of full groups of VLEN_FP32 rows
const uint32_t nloe = ir % VLEN_FP32; // leftover rows after the full groups
uint32_t i1 = 0; // first row (relative to ir0) of the group being processed
for (uint32_t vi1 = 0; vi1 < nvec; vi1++) {
HVX_Vector acc_vec = Q6_V_vsplat_R(0);
for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
// Gather base for src0: tap i0 of row i1, shifted by token i2; the per-row offsets select rows i1, i1+1, ...
Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])),
src0_gather_len, (*(const HVX_Vector *) src0_offsets));
// Gather base for src1: weight i0 of row i1
Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)),
src1_gather_len, (*(const HVX_Vector *) src1_offsets));
HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
}
// Unaligned full-vector store of the converted accumulator
*(HVX_UVector *) (dst_ptr + i1) = Q6_Vsf_equals_Vqf32(acc_vec);
i1 += VLEN_FP32;
}
if (nloe) {
// Leftover rows: same computation on a full vector, followed by a store limited to (ir - i1) * 4 bytes
HVX_Vector acc_vec = Q6_V_vsplat_R(0);
for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])),
src0_gather_len, (*(const HVX_Vector *) src0_offsets));
Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)),
src1_gather_len, (*(const HVX_Vector *) src1_offsets));
HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
}
hvx_vec_store_u(dst_ptr + i1, (ir - i1) * 4, Q6_Vsf_equals_Vqf32(acc_vec));
}
}
}
t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n",
ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1,
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1],
dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
int op_ssm_conv_f32(struct htp_ops_context * octx) {
htp_ssm_conv_tensors_preamble;
if (src0->type != HTP_TYPE_F32 || src1->type != HTP_TYPE_F32 || dst->type != HTP_TYPE_F32) {
FARF(ERROR, "ssm_conv: only (F32 x F32 -> F32) OPs supported");
return HTP_STATUS_NO_SUPPORT;
}
struct htp_ssm_conv_context scctx = { 0 };
scctx.octx = octx;
const uint32_t d_conv = src1->ne[0];
const uint32_t d_inner = src0->ne[1];
const uint32_t n_t = dst->ne[1]; // tokens per sequence
const uint32_t n_s = dst->ne[2]; // number of sequences in the batch
const uint32_t n_threads = MIN(octx->n_threads, d_inner);
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
uint32_t use_hvx = 0;
if (d_inner >= VLEN_FP32 && d_inner % VLEN_FP32 == 0) {
int is_aligned = hex_is_aligned((void *) src0->data, VLEN) &&
hex_is_aligned((void *) src1->data, VLEN) &&
hex_is_aligned((void *) dst->data, VLEN);
if (is_aligned) {
use_hvx = 1;
}
}
if (use_hvx) {
scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads; // d_inner chunks per thread
scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); // round up to even
octx->src0_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb01, 256);
octx->src1_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb11, 256);
octx->dst_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * sizeof(float), 256);
octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads;
octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads;
octx->dst_spad.size = octx->dst_spad.size_per_thread * n_threads;
// Compute gather scratchpad size for src0 and src1
const size_t gather_spad_size = n_threads * VLEN * 2;
octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size;
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
FARF(HIGH, "ssm_conv-f32: gather-spad:%zu spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n",
gather_spad_size, octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread,
octx->dst_spad.size_per_thread, octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size,
octx->src0_spad.data, octx->src1_spad.data, octx->dst_spad.data);
const size_t total_spad_size =
gather_spad_size + octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
if (total_spad_size > octx->ctx->vtcm_size) {
FARF(HIGH, "ssm_conv-f32: HVX scratchpad size %zu exceeds VTCM size %zu", total_spad_size,
octx->ctx->vtcm_size);
use_hvx = 0;
}
}
FARF(HIGH, "ssm-conv-f32: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : use_hvx %d\n", src0->ne[0],
src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0],
dst->ne[1], dst->ne[2], dst->ne[3], use_hvx);
if (use_hvx) {
worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_thread_f32_f32_hvx, &scctx, n_threads);
} else {
worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_thread_f32_f32, &scctx, n_threads);
}
}
return HTP_STATUS_OK;
}
// SSM_CONV entry point: dispatches on the destination tensor type.
// Only F32 is implemented; any other type yields HTP_STATUS_NO_SUPPORT.
int op_ssm_conv(struct htp_ops_context * octx) {
    if (octx->dst.type == HTP_TYPE_F32) {
        return op_ssm_conv_f32(octx);
    }

    return HTP_STATUS_NO_SUPPORT;
}

View File

@ -1717,12 +1717,29 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_met
char base[256];
char name[256];
snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);
const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS);
if (mode == GGML_SCALE_MODE_BILINEAR) {
snprintf(base, 256, "kernel_upscale_bilinear_%s", ggml_type_name(op->src[0]->type));
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
snprintf(base, 256, "kernel_upscale_bicubic_%s", ggml_type_name(op->src[0]->type));
} else {
snprintf(base, 256, "kernel_upscale_nearest_%s", ggml_type_name(op->src[0]->type));
}
snprintf(name, 256, "%s_aa=%d", base, antialias);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_bool(cv, antialias, FC_UPSCALE + 0);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
}
return res;

View File

@ -1108,7 +1108,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
op->type == GGML_TYPE_F32 &&
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
case GGML_OP_UPSCALE:
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_POOL_1D:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_POOL_2D:

View File

@ -83,6 +83,7 @@
#define FC_UNARY 1200
#define FC_BIN 1300
#define FC_SUM_ROWS 1400
#define FC_UPSCALE 1500
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPSG 8
@ -890,6 +891,7 @@ typedef struct {
float sf1;
float sf2;
float sf3;
float poffs;
} ggml_metal_kargs_upscale;
typedef struct {

View File

@ -1963,6 +1963,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
(
op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
op->src[0]->type == GGML_TYPE_F16 ||
op->src[0]->type == GGML_TYPE_BF16 ||
op->src[0]->type == GGML_TYPE_Q4_0 ||
op->src[0]->type == GGML_TYPE_Q4_1 ||
op->src[0]->type == GGML_TYPE_Q5_0 ||
@ -1977,6 +1978,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
op->src[0]->type == GGML_TYPE_Q4_K ||
op->src[0]->type == GGML_TYPE_Q5_K ||
op->src[0]->type == GGML_TYPE_Q6_K ||
op->src[0]->type == GGML_TYPE_Q2_K ||
op->src[0]->type == GGML_TYPE_Q3_K ||
false) && (ne11 >= 4 && ne11 <= 8)
)
)
@ -3729,32 +3732,43 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const float sf0 = (float)ne0/op->src[0]->ne[0];
const float sf1 = (float)ne1/op->src[0]->ne[1];
const float sf2 = (float)ne2/op->src[0]->ne[2];
const float sf3 = (float)ne3/op->src[0]->ne[3];
float sf0 = (float)ne0/op->src[0]->ne[0];
float sf1 = (float)ne1/op->src[0]->ne[1];
float sf2 = (float)ne2/op->src[0]->ne[2];
float sf3 = (float)ne3/op->src[0]->ne[3];
const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
float poffs = 0.5f;
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
poffs = 0.0f;
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
}
ggml_metal_kargs_upscale args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.sf0 =*/ sf0,
/*.sf1 =*/ sf1,
/*.sf2 =*/ sf2,
/*.sf3 =*/ sf3
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.sf0 =*/ sf0,
/*.sf1 =*/ sf1,
/*.sf2 =*/ sf2,
/*.sf3 =*/ sf3,
/*.poffs =*/ poffs,
};
auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);

View File

@ -3481,6 +3481,13 @@ template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, bfloat4, 4, dequantize_bf16_t4>;
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, bfloat4, 4, dequantize_bf16_t4>;
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, bfloat4, 4, dequantize_bf16_t4>;
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>;
#endif
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
@ -3531,6 +3538,16 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q2_K, 256, dequantize_q2_K>;
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q2_K, 256, dequantize_q2_K>;
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q2_K, 256, dequantize_q2_K>;
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q2_K, 256, dequantize_q2_K>;
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q3_K, 256, dequantize_q3_K>;
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q3_K, 256, dequantize_q3_K>;
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q3_K, 256, dequantize_q3_K>;
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q3_K, 256, dequantize_q3_K>;
template<typename T0, typename T1, short NR0, typename args_t>
void kernel_mul_mv_t_t_impl(
args_t args,
@ -4530,7 +4547,9 @@ kernel void kernel_conv_transpose_2d<half>(
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
kernel void kernel_upscale_f32(
constant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]];
kernel void kernel_upscale_nearest_f32(
constant ggml_metal_kargs_upscale & args,
device const char * src0,
device char * dst,
@ -4556,6 +4575,156 @@ kernel void kernel_upscale_f32(
}
}
// Triangle (tent) weight: 1 - |x|, clamped below at 0.
static inline float bilinear_tri(float x) {
    const float w = 1.0f - fabs(x);
    return MAX(0.0f, w);
}
// Bilinear resize of an F32 tensor.
//
// One threadgroup handles one output row (i1, i2, i3) taken from the threadgroup position;
// its threads stride over the row's i0 elements. Source coordinates are computed as
// (i + poffs) / sf - poffs for dims 0 and 1, and as i / sf for dims 2 and 3.
//
// The function constant FC_upscale_aa selects between two paths:
//   - antialias: weighted average over a window whose half-width is max(1, 1/sf) per dim,
//     using triangle weights and normalizing by the weight sum
//   - plain: standard 4-tap bilinear interpolation with indices clamped to the source bounds
kernel void kernel_upscale_bilinear_f32(
constant ggml_metal_kargs_upscale & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
// Dims 2 and 3 are not interpolated: the source index is the output index divided by the scale factor
const int64_t i03 = i3 / args.sf3;
const int64_t i02 = i2 / args.sf2;
// Source row coordinate, its two neighbouring rows (clamped) and the fractional part
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
const int64_t i01 = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01)));
const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 + 1));
const float fd1 = MAX(0.0f, MIN(1.0f, f01 - (float)i01));
// From here on src0 points at the selected (i02, i03) source plane
src0 += i03*args.nb03 + i02*args.nb02;
device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
if (FC_upscale_aa) {
// Filter half-width per dim: widens to 1/sf when sf < 1, otherwise stays at 1
const float support0 = MAX(1.0f, 1.0f / args.sf0);
const float invscale0 = 1.0f / support0;
const float support1 = MAX(1.0f, 1.0f / args.sf1);
const float invscale1 = 1.0f / support1;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
// Source window [x_min, x_max) x [y_min, y_max), clamped to the source extent
int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs));
int64_t x_max = MIN(args.ne00, (int64_t)ceil (f00 + support0 + args.poffs));
int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs));
int64_t y_max = MIN(args.ne01, (int64_t)ceil (f01 + support1 + args.poffs));
float sum = 0.0f;
float wsum = 0.0f;
for (int64_t sy = y_min; sy < y_max; ++sy) {
// Triangle weight of the sample's distance to the source coordinate, scaled by the filter width
const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1);
for (int64_t sx = x_min; sx < x_max; ++sx) {
const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
const float w = wx * wy;
const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
sum += (*src_ptr) * w;
wsum += w;
}
}
// Normalize by the accumulated weight; writes 0 when no sample contributed
const float v = (wsum > 0.0f) ? (sum / wsum) : 0.0f;
dst_ptr[i0] = v;
}
} else {
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
// Source column coordinate, its two neighbouring columns (clamped) and the fractional part
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
const int64_t i00 = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00)));
const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1));
const float fd0 = MAX(0.0f, MIN(1.0f, f00 - (float)i00));
// The four neighbours; in srcXY, X selects the column (0 = i00, 1 = i00p) and Y the row (0 = i01, 1 = i01p)
device const float * src00 = (device const float *)(src0 + i01*args.nb01 + i00*args.nb00);
device const float * src10 = (device const float *)(src0 + i01*args.nb01 + i00p*args.nb00);
device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00);
device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00);
const float v =
(*src00) * (1.0f - fd0) * (1.0f - fd1) +
(*src10) * fd0 * (1.0f - fd1) +
(*src01) * (1.0f - fd0) * fd1 +
(*src11) * fd0 * fd1;
dst_ptr[i0] = v;
}
}
}
// Cubic weight polynomial (a + 2)*x^3 - (a + 3)*x^2 + 1 with a = -0.75, evaluated in nested form.
// The bicubic kernel calls it for the two nearest taps (arguments fd and 1 - fd).
static inline float bicubic_weight1(float x) {
    const float a = -0.75f;

    const float c3 = a + 2;
    const float c2 = a + 3;

    return (c3 * x - c2) * x * x + 1;
}
// Cubic weight polynomial a*x^3 - 5a*x^2 + 8a*x - 4a with a = -0.75, evaluated in nested form.
// The bicubic kernel calls it for the two outer taps (arguments fd + 1 and 2 - fd).
static inline float bicubic_weight2(float x) {
    const float a = -0.75f;

    float acc = a * x - 5 * a;
    acc = acc * x + 8 * a;

    return acc * x - 4 * a;
}
// Bicubic resize of an F32 tensor.
//
// One threadgroup handles one output row (i1, i2, i3) taken from the threadgroup position;
// its threads stride over the row's i0 elements. Each output value is a weighted sum over a
// 4x4 source neighbourhood (offsets -1..2 around the floored source coordinate in dims 0 and 1),
// with source indices clamped to the valid range. Dims 2 and 3 use i / sf as the source index.
kernel void kernel_upscale_bicubic_f32(
constant ggml_metal_kargs_upscale & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3 / args.sf3;
const int64_t i02 = i2 / args.sf2;
// Source row coordinate, split into an integer part (not clamped here) and a fraction
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
const int64_t i01 = (int64_t)floor(f01);
const float fd1 = f01 - (float)i01;
// Row weights for offsets -1, 0, +1, +2; they depend only on i1, so they are computed once per threadgroup row
const float w_y0 = bicubic_weight2(fd1 + 1.0f);
const float w_y1 = bicubic_weight1(fd1);
const float w_y2 = bicubic_weight1(1.0f - fd1);
const float w_y3 = bicubic_weight2(2.0f - fd1);
const device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02;
device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1);
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
const int64_t i00 = (int64_t)floor(f00);
const float fd0 = f00 - (float)i00;
// Column weights for offsets -1, 0, +1, +2
const float w_x0 = bicubic_weight2(fd0 + 1.0f);
const float w_x1 = bicubic_weight1(fd0);
const float w_x2 = bicubic_weight1(1.0f - fd0);
const float w_x3 = bicubic_weight2(2.0f - fd0);
float sum = 0.0f;
for (int dy = -1; dy <= 2; ++dy) {
// Clamp to the source extent so out-of-range taps reuse the edge row / column
const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy));
const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3;
for (int dx = -1; dx <= 2; ++dx) {
const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
sum += (*src_ptr) * wx * wy;
}
}
dst_ptr[i0] = sum;
}
}
kernel void kernel_pad_f32(
constant ggml_metal_kargs_pad & args,
device const char * src0,

View File

@ -116,6 +116,7 @@ set(GGML_OPENCL_KERNELS
neg
norm
relu
l2_norm
rms_norm
rope
scale

View File

@ -497,6 +497,7 @@ struct ggml_backend_opencl_context {
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
cl_kernel kernel_norm, kernel_norm_mul_add;
cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
cl_kernel kernel_l2_norm_f32;
cl_kernel kernel_group_norm, kernel_group_norm_mul_add;
cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
cl_kernel kernel_diag_f32;
@ -1585,6 +1586,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// l2_norm
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "l2_norm.cl.h"
};
#else
const std::string kernel_src = read_file("l2_norm.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_l2_norm_f32 = clCreateKernel(prog, "kernel_l2_norm_f32", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// rope
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@ -3689,6 +3707,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return true;
case GGML_OP_RMS_NORM:
return op->ne[0] % 4 == 0 && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM:
return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_REPEAT:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
case GGML_OP_PAD:
@ -7554,6 +7574,64 @@ static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0,
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
// Enqueue the OpenCL L2-normalisation kernel (kernel_l2_norm_f32) for src0,
// writing the result into dst.
//
// src1 is only present to match the common op-function signature and is not
// used. eps is read from the first float of dst->op_params. One work-group is
// launched per row of src0 (ne01 x ne02 x ne03 groups).
static void ggml_cl_l2_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_ASSERT(src0);
    GGML_ASSERT(src0->extra);
    GGML_ASSERT(dst);
    GGML_ASSERT(dst->extra);

    UNUSED(src1);

    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
    cl_kernel kernel = backend_ctx->kernel_l2_norm_f32;

    // Device buffers and byte offsets (buffer offset + view offset) of both tensors.
    ggml_tensor_extra_cl * src_extra = (ggml_tensor_extra_cl *)src0->extra;
    ggml_tensor_extra_cl * dst_extra = (ggml_tensor_extra_cl *)dst->extra;

    cl_ulong src_offset = src_extra->offset + src0->view_offs;
    cl_ulong dst_offset = dst_extra->offset + dst->view_offs;

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    GGML_TENSOR_LOCALS(int,      ne0, src0, ne);
    GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);

    // Subgroup size the kernel is compiled for on each supported GPU family.
    size_t subgroup_size;
    if (backend_ctx->gpu_family == ADRENO) {
        subgroup_size = 64;
    } else if (backend_ctx->gpu_family == INTEL) {
        subgroup_size = 32;
    } else {
        GGML_ASSERT(false && "Unsupported GPU");
    }

    // Start with one subgroup and keep doubling the work-group size until it
    // covers a row or reaches the kernel's work-group size limit.
    int n_threads = subgroup_size;
    for (; n_threads < ne00 && n_threads < (int)backend_ctx->get_kernel_workgroup_size(kernel); n_threads *= 2) {
    }

    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &src_extra->data_device));
    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &src_offset));
    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &dst_extra->data_device));
    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &dst_offset));
    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne00));
    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),      &ne01));
    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne02));
    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne03));
    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb01));
    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb02));
    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float),    &eps));
    // Local scratch: one float per subgroup in the work-group.
    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*n_threads/subgroup_size, NULL));

    size_t global_work_size[] = {(size_t)ne01*n_threads, (size_t)ne02, (size_t)ne03};
    size_t local_work_size[]  = {(size_t)n_threads, 1, 1};

    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
static void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
@ -12184,6 +12262,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_rms_norm;
break;
case GGML_OP_L2_NORM:
if (!any_on_device) {
return false;
}
func = ggml_cl_l2_norm;
break;
case GGML_OP_GROUP_NORM:
if (!any_on_device) {
return false;

View File

@ -0,0 +1,71 @@
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_32
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
// L2-normalise each row of src0 into dst: y = x / sqrt(max(sum(x^2), eps)).
//
// One work-group handles one row (group ids 0/1/2 select i01/i02/i03).
//   src0, offset0 : source buffer and byte offset into it
//   dst,  offsetd : destination buffer and byte offset into it
//   ne00..ne03    : source dimensions (ne00 is the row length)
//   nb01..nb03    : source strides in bytes for dims 1..3
//   eps           : lower bound applied to the sum of squares
//   sum           : local scratch, one float per subgroup of the work-group
// The destination row is addressed from ne00..ne02 only, i.e. dst is written
// as a densely packed tensor.
kernel void kernel_l2_norm_f32(
        global void * src0,
        ulong offset0,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        float eps,
        local float * sum
) {
    src0 = (global void*)((global char*)src0 + offset0);
    dst = (global float*)((global char*)dst + offsetd);

    int i03 = get_group_id(2);
    int i02 = get_group_id(1);
    int i01 = get_group_id(0);

    global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
    global float * y = (global float *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

    float sumf = 0;

    // parallel sum: each work-item accumulates a strided slice of the row
    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
        sumf += x[i00] * x[i00];
    }
    sumf = sub_group_reduce_add(sumf);
    if (get_sub_group_local_id() == 0) {
        sum[get_sub_group_id()] = sumf;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Tree-reduce the per-subgroup partial sums into sum[0].
    // Each step reads slots that other work-items wrote in the previous step,
    // so the steps must be separated by a work-group barrier; relying on
    // lock-step execution is not guaranteed by OpenCL. The trip count is the
    // same for every work-item, so the barrier is reached uniformly.
    for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
        if (get_local_id(0) < i) {
            sum[get_local_id(0)] += sum[get_local_id(0) + i];
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // sum[0] is visible to all work-items here: either through the barrier
    // that ends the last reduction step, or (single subgroup, loop skipped)
    // through the barrier that follows the per-subgroup store above.
    const float scale = 1.0f/sqrt(max(sum[0], eps));

    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
        y[i00] = x[i00] * scale;
    }
}

View File

@ -3104,6 +3104,11 @@ static void quantize_row_iq2_xxs_impl(const float * GGML_RESTRICT x, void * GGML
}
float scale = make_qp_quants(32, kMaxQ+1, xval, (uint8_t*)L, weight);
float eff_max = scale*kMaxQ;
if (eff_max <= 0) {
scales[ib] = 0;
memset(L, 0, 32);
continue;
}
float best = 0;
for (int is = -6; is <= 6; ++is) {
float id = (2*kMaxQ-1+is*0.1f)/eff_max;
@ -3273,9 +3278,9 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_
}
float max = xval[0];
for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]);
memset(L, 0, 16);
if (max < GROUP_MAX_EPS) {
scales[ib] = 0;
memset(L, 0, 16);
continue;
}
float best = 0;
@ -3714,9 +3719,9 @@ static void quantize_row_iq3_xxs_impl(int grid_size, const float * GGML_RESTRICT
}
float max = xval[0];
for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
memset(L, 0, 32);
if (max < GROUP_MAX_EPS_IQ3_XXS) {
scales[ib] = 0;
memset(L, 0, 32);
continue;
}
float best = 0;
@ -3922,6 +3927,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * GGML_RESTRICT
}
float max = xval[0];
for (int i = 1; i < block_size; ++i) max = MAX(max, xval[i]);
memset(L, 0, block_size);
if (!max) {
scales[ib] = 0;
continue;
@ -4245,6 +4251,7 @@ static void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_R
for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));
if (max < GROUP_MAX_EPS_IQ1_S) {
scales[ib] = 0;
shifts[ib] = 1;
memset(L, 1, block_size);
continue;
}
@ -4285,7 +4292,12 @@ static void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_R
}
}
}
GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_shift != 0);
if (besti1 < 0 || besti2 < 0 || best_shift == 0) {
scales[ib] = 0;
shifts[ib] = 1;
memset(L, 1, block_size);
continue;
}
for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0;
for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2;
@ -4429,6 +4441,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));
if (max < GROUP_MAX_EPS_IQ1_M) {
scales[ib] = 0;
shifts[ib] = 0;
memset(L, 1, block_size);
continue;
}
@ -4527,7 +4540,12 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
}
}
}
GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_k >= 0);
if (besti1 < 0 || besti2 < 0 || best_k < 0) {
scales[ib] = 0;
shifts[ib] = 0;
memset(L, 1, block_size);
continue;
}
for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0;
for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2;
@ -4874,6 +4892,7 @@ static void quantize_row_iq2_s_impl(const float * GGML_RESTRICT x, void * GGML_R
}
float max = xval[0];
for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]);
memset(L, 0, 16);
if (max < GROUP_MAX_EPS_IQ2_S) {
scales[ib] = 0;
continue;

View File

@ -25,6 +25,11 @@ ggml_add_backend_library(ggml-sycl
file(GLOB GGML_HEADERS_SYCL "*.hpp")
file(GLOB GGML_SOURCES_SYCL "*.cpp")
file(GLOB SRCS "template-instances/fattn-tile*.cpp")
list(APPEND GGML_SOURCES_SYCL ${SRCS})
file(GLOB SRCS "template-instances/fattn-vec*.cpp")
list(APPEND GGML_SOURCES_SYCL ${SRCS})
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
if (WIN32)
@ -145,6 +150,7 @@ else()
endif()
if (GGML_SYCL_GRAPH)
message(STATUS "find GGML_SYCL_GRAPH")
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH)
endif()

View File

@ -23,6 +23,7 @@
#include "dequantize.hpp"
#include "dmmv.hpp"
#include "element_wise.hpp"
#include "fattn.hpp"
#include "gla.hpp"
#include "im2col.hpp"
#include "mmq.hpp"

View File

@ -19,10 +19,13 @@
#include <string>
#include "dpct/helper.hpp"
#include "ggml.h"
#include "ggml-impl.h"
#include "ggml-sycl.h"
#include "presets.hpp"
#include "sycl_hw.hpp"
namespace syclexp = sycl::ext::oneapi::experimental;
#if GGML_SYCL_DNNL
#include "dnnl.hpp"
@ -31,6 +34,9 @@
#define GGML_COMMON_DECL_SYCL
#define GGML_COMMON_IMPL_SYCL
#define SYCL_FLASH_ATTN //remove it to disable FLASH_ATTENTION in building.
#define SYCL_FAST_FP16 //don't change. remove it will break fattn-tile.hpp building
/* suppress warning spam */
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wnested-anon-types"
@ -45,6 +51,8 @@ void ggml_sycl_host_free(void* ptr);
extern int g_ggml_sycl_debug;
extern int g_ggml_sycl_disable_optimize;
extern int g_ggml_sycl_prioritize_dmmv;
extern int g_ggml_sycl_enable_flash_attention;
#if defined(__clang__) && __has_builtin(__builtin_expect)
// Hint the optimizer to pipeline the more likely following instruction in branches
@ -170,6 +178,10 @@ static size_t g_scratch_offset = 0;
int get_current_device_id();
// Returns the id reported by get_current_device_id().
inline int ggml_sycl_get_device() {
    const int device_id = get_current_device_id();
    return device_id;
}
inline dpct::err0 ggml_sycl_set_device(const int device) try {
int current_device_id;
SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id()));
@ -194,11 +206,14 @@ struct optimize_feature {
};
struct sycl_device_info {
int cc; // compute capability
int cc; // compute capability
int nsm; // number of streaming multiprocessors (CUDA) maps to the maximum
// number of compute units on a SYCL device.
// size_t smpb; // max. shared memory per block
size_t smpbo; // max. shared memory per block (with opt-in)
int warp_size; // max sub_group_size of SYCL
int max_wg_per_cu; // max work groups per compute unit - refer to
// cudaOccupancyMaxActiveBlocksPerMultiprocessor
bool vmm; // virtual memory support
size_t total_vram;
//sycl_hw_info hw_info; \\ device id and aarch, currently not used
@ -435,13 +450,15 @@ warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) {
return a;
}
template <int width = WARP_SIZE>
/* use WARP_SIZE or WARP_32_SIZE*/
template <int width>
static __dpct_inline__ int warp_reduce_sum(int x) {
    // `width` is not consulted by this overload: the sum is taken over the
    // calling work-item's whole sub-group.
    const auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
    return sycl::reduce_over_group(sg, x, sycl::plus<>());
}
template <int width = WARP_SIZE>
/* use WARP_SIZE or WARP_32_SIZE*/
template <int width>
static __dpct_inline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int offset = width / 2; offset > 0; offset >>= 1) {
@ -451,7 +468,19 @@ static __dpct_inline__ float warp_reduce_sum(float x) {
return x;
}
template <int width = WARP_SIZE>
/* use WARP_SIZE or WARP_32_SIZE*/
template <int width>
static __dpct_inline__ float warp_reduce_sum(float x, const sycl::nd_item<3>& item_ct1) {
    // XOR-shuffle reduction over `width` lanes of the sub-group of item_ct1:
    // the exchange distance is halved each step, starting at width / 2.
    const auto sg = item_ct1.get_sub_group();
#pragma unroll
    for (int stride = width / 2; stride > 0; stride >>= 1) {
        x += dpct::permute_sub_group_by_xor(sg, x, stride);
    }
    return x;
}
/* use WARP_SIZE or WARP_32_SIZE*/
template <int width>
static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) {
#pragma unroll
for (int offset = width / 2; offset > 0; offset >>= 1) {
@ -465,7 +494,8 @@ static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) {
return a;
}
template <int width = WARP_SIZE>
/* use WARP_SIZE or WARP_32_SIZE*/
template <int width>
static __dpct_inline__ sycl::half2 warp_reduce_sum(sycl::half2 a) {
#pragma unroll
for (int offset = width / 2; offset > 0; offset >>= 1) {
@ -481,7 +511,52 @@ static constexpr int ggml_sycl_get_physical_warp_size() {
return WARP_SIZE;
}
template <int width = WARP_SIZE>
/* use WARP_SIZE or WARP_32_SIZE*/
template <int width>
static __dpct_inline__ int warp_reduce_all(int x) {
    // Logical AND of `x` across `width` lanes. Returns non-zero only when
    // every participating lane passed a non-zero value.
    const auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();

    if (width != ggml_sycl_get_physical_warp_size()) {
        // Narrower logical width: XOR-shuffle reduction inside each
        // `width`-lane segment.
#pragma unroll
        for (int stride = width / 2; stride > 0; stride >>= 1) {
            const int peer = dpct::permute_sub_group_by_xor(sg, x, stride, width);
            x = peer && x;
        }
        return x;
    }

    // Full width: a single group vote. The masked term is always zero
    // (~0xffffffff == 0), so the predicate reduces to "x is non-zero".
    const bool pred = (~0xffffffff & (0x1 << sg.get_local_linear_id())) || x;
    return sycl::all_of_group(sg, pred);
}
/* use WARP_SIZE or WARP_32_SIZE*/
template <int width>
static __dpct_inline__ int warp_reduce_any(int x) {
    // Logical OR of `x` across `width` lanes. Returns non-zero when at least
    // one participating lane passed a non-zero value.
    const auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();

    if (width != ggml_sycl_get_physical_warp_size()) {
        // Narrower logical width: XOR-shuffle reduction inside each
        // `width`-lane segment.
#pragma unroll
        for (int stride = width / 2; stride > 0; stride >>= 1) {
            const int peer = dpct::permute_sub_group_by_xor(sg, x, stride, width);
            x = peer || x;
        }
        return x;
    }

    // Full width: a single group vote, predicated on the lane's own mask bit
    // and on `x` being non-zero.
    const bool pred = (0xffffffff & (0x1 << sg.get_local_linear_id())) && x;
    return sycl::any_of_group(sg, pred);
}
/* use WARP_SIZE or WARP_32_SIZE*/
template <int width>
static __dpct_inline__ float warp_reduce_max(float x) {
#pragma unroll
for (int offset = width / 2; offset > 0; offset >>= 1) {
@ -629,6 +704,42 @@ static const sycl::uint3 init_fastdiv_values(uint32_t d) {
return sycl::uint3(mp, L, d);
}
// Maximum number of bytes that can be copied in a single instruction.
// Set by test result.
static constexpr int ggml_sycl_get_max_cpy_bytes() {
    // Widest single-copy element, in bytes.
    constexpr int max_bytes_per_copy = 16;
    return max_bytes_per_copy;
}
// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes.
// Copies `nbytes` bytes from src to dst using elements of `alignment` bytes
// each (1, 2, 4, 8 or 16). With alignment == 0 the whole range is copied as a
// single element of `nbytes` bytes. Any other element size is rejected at
// compile time.
template <int nbytes, int alignment = 0>
static __dpct_inline__ void ggml_sycl_memcpy_1(void * dst, const void * src) {
    if constexpr (alignment != 0) {
        static_assert(nbytes % alignment == 0, "bad alignment");
    }
    constexpr int chunk_bytes = alignment == 0 ? nbytes : alignment;

#pragma unroll
    for (int k = 0; k < nbytes/chunk_bytes; ++k) {
        if constexpr (chunk_bytes == 1) {
            static_cast<char *>(dst)[k] = static_cast<const char *>(src)[k];
        } else if constexpr (chunk_bytes == 2) {
            static_cast<short *>(dst)[k] = static_cast<const short *>(src)[k];
        } else if constexpr (chunk_bytes == 4) {
            static_cast<int *>(dst)[k] = static_cast<const int *>(src)[k];
        } else if constexpr (chunk_bytes == 8) {
            static_cast<sycl::int2 *>(dst)[k] = static_cast<const sycl::int2 *>(src)[k];
        } else if constexpr (chunk_bytes == 16) {
            static_cast<sycl::int4 *>(dst)[k] = static_cast<const sycl::int4 *>(src)[k];
        } else {
            // Condition depends on the template parameter and is never true:
            // it only fires when this branch is instantiated.
            static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
        }
    }
}
// Builds a sycl::half2 from two values of the same type, converting each to
// sycl::half with static_cast.
template <typename T>
sycl::half2 __dpct_inline__ make_half2( T x, T y) {
    return sycl::half2(static_cast<sycl::half>(x), static_cast<sycl::half>(y));
}
static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_values) {
const uint32_t hi = sycl::mul_hi<unsigned>(n, fastdiv_values.x());
@ -636,6 +747,17 @@ static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_va
}
// Builds a sycl::float2 from two values of the same type, converting each to
// float with static_cast.
template <typename T>
sycl::float2 __dpct_inline__ make_float2( T x, T y) {
    return sycl::float2(static_cast<float>(x), static_cast<float>(y));
}
// Widens both components of a half2 to float (overload for non-const lvalues).
sycl::float2 __dpct_inline__ __half22float2(sycl::half2 &H) {
    return sycl::float2(static_cast<float>(H.x()), static_cast<float>(H.y()));
}
static __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3 fastdiv_values) {
const uint32_t div_val = fastdiv(n, fastdiv_values);
const uint32_t mod_val = n - div_val * fastdiv_values.z();
@ -659,5 +781,97 @@ static __dpct_inline__ float ggml_sycl_e8m0_to_fp32(uint8_t x) {
return result;
}
// Widens both components of a half2 to float (overload for const references).
sycl::float2 __dpct_inline__ __half22float2(const sycl::half2 &H) {
    return sycl::float2(static_cast<float>(H.x()), static_cast<float>(H.y()));
}
// Converts a single sycl::half to float.
float __dpct_inline__ __half2float(sycl::half H) {
    const float widened = static_cast<float>(H);
    return widened;
}
// Multiply-accumulate on scalars: acc += v * u.
static __dpct_inline__ void ggml_sycl_mad(float & acc, const float v, const float u) {
acc += v*u;
}
// Two-element dot product accumulated into a scalar:
// acc += v.x * u.x, then acc += v.y * u.y (in that order).
static __dpct_inline__ void ggml_sycl_mad(float & acc, const sycl::float2 v, const sycl::float2 u) {
acc += v.x() * u.x();
acc += v.y() * u.y();
}
// Two-element dot product of half2 operands accumulated into a float.
// With GGML_SYCL_F16 the product is formed in half precision and then
// converted to float; otherwise both operands are widened to float first and
// the two products are accumulated one after the other.
static __dpct_inline__ void ggml_sycl_mad(float & acc, const sycl::half2 v, const sycl::half2 u) {
#ifdef GGML_SYCL_F16
const sycl::float2 tmp = (v * u).template convert<float, sycl::rounding_mode::automatic>();
acc += tmp.x() + tmp.y();
#else
const sycl::float2 tmpv = __half22float2(v);
const sycl::float2 tmpu = __half22float2(u);
acc += tmpv.x() * tmpu.x();
acc += tmpv.y() * tmpu.y();
#endif // GGML_SYCL_F16
}
// Element-wise multiply-accumulate on half2: acc += v * u.
// With GGML_SYCL_F16 the operation is done directly on half2; otherwise the
// accumulator and both operands are widened to float, combined per component,
// and the result is converted back to half2 with make_half2.
static __dpct_inline__ void ggml_sycl_mad(sycl::half2 & acc, const sycl::half2 v, const sycl::half2 u) {
#ifdef GGML_SYCL_F16
acc += v*u;
#else
const sycl::float2 tmpv = __half22float2(v);
const sycl::float2 tmpu = __half22float2(u);
sycl::float2 tmpacc = __half22float2(acc);
// tmpacc.x += tmpv.x() * tmpu.x();
// tmpacc.y += tmpv.y() * tmpu.y();
sycl::float2 tmp1(tmpacc.x() + tmpv.x() * tmpu.x(), tmpacc.y() + tmpv.y() * tmpu.y());
acc = make_half2(tmp1.x(), tmp1.y());
#endif // GGML_SYCL_F16
}
// Compile-time unrolled loop. ggml_sycl_unroll<n>{}(f, args...) calls
// f(n - 1, args...), f(n - 2, args...), ..., f(0, args...) in that order.
template <int n>
struct ggml_sycl_unroll {
    template <typename Callable, typename... Extra>
    void operator()(const Callable & body, Extra... extra) const {
        body(n - 1, extra...);
        const ggml_sycl_unroll<n - 1> remaining{};
        remaining(body, extra...);
    }
};

// Recursion anchor: the last iteration runs with index 0.
template <>
struct ggml_sycl_unroll<1> {
    template <typename Callable, typename... Extra>
    void operator()(const Callable & body, Extra... extra) const {
        body(0, extra...);
    }
};
// Component-wise maximum of two half2 values.
// Each component is computed with sycl::fmax, passed through a one-element
// float vector and converted back to sycl::half before being stored in the
// matching component of the result.
static __dpct_inline__ sycl::half2 ggml_sycl_hmax2(const sycl::half2 a, const sycl::half2 b) {
sycl::half2 ret;
reinterpret_cast<sycl::half &>(ret.x()) =
sycl::vec<float, 1>(sycl::fmax(a[0], b[0])).convert<sycl::half, sycl::rounding_mode::automatic>()[0];
reinterpret_cast<sycl::half &>(ret.y()) =
sycl::vec<float, 1>(sycl::fmax(a[1], b[1])).convert<sycl::half, sycl::rounding_mode::automatic>()[0];
return ret;
}
// Maximum of two sycl::half values: both inputs are converted to float, the
// maximum is taken with sycl::fmax, and the result is converted back to half.
static __dpct_inline__ sycl::half ggml_sycl_hmax(const sycl::half a, const sycl::half b) {
return sycl::vec<float, 1>(
sycl::fmax(sycl::vec<sycl::half, 1>(a).convert<float, sycl::rounding_mode::automatic>()[0],
sycl::vec<sycl::half, 1>(b).convert<float, sycl::rounding_mode::automatic>()[0]))
.convert<sycl::half, sycl::rounding_mode::automatic>()[0];
}
// Per-component "greater than" mask for two half2 values, compared as float:
// the low 16 bits are all ones when a[0] > b[0], the high 16 bits are all
// ones when a[1] > b[1]; the remaining bits are zero.
static __dpct_inline__ uint32_t __hgt2_mask(const sycl::half2 a, const sycl::half2 b) {
    const bool low_gt  = float(a[0]) > float(b[0]);
    const bool high_gt = float(a[1]) > float(b[1]);

    uint32_t mask = 0;
    if (low_gt) {
        mask |= 0x0000FFFF;
    }
    if (high_gt) {
        mask |= 0xFFFF0000;
    }
    return mask;
}
// Remainder of n by the divisor packed into fastdiv_values.
// Expects fastdiv_values to hold <mp, L, divisor> in <x, y, z>
// (see init_fastdiv_values).
static __dpct_inline__ uint32_t fastmodulo(uint32_t n, const sycl::uint3 fastdiv_values) {
    const uint32_t quotient = fastdiv(n, fastdiv_values);
    return n - quotient * fastdiv_values.z();
}
// Reports whether fast FP16 arithmetic can be used. The `cc` argument is
// ignored and the answer is unconditionally true.
static bool fast_fp16_available(const int cc) {
    constexpr bool available = true; // Intel GPUs always support FP16.
    GGML_UNUSED(cc);
    return available;
}
#endif // GGML_SYCL_COMMON_HPP

View File

@ -482,6 +482,63 @@ static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t
});
}
// Device-side body that dequantizes two values of a (possibly non-contiguous)
// quantized tensor and writes them to a densely packed destination.
//   vx            : quantized source
//   y             : destination, indexed as a packed ne00 x ne01 x ne02 x ... array
//   ne00..ne02    : extents of dims 0..2
//   s01, s02, s03 : source strides of dims 1..3, in units of quantized blocks
// Work-item mapping: dim 2 walks ne00 (two values per work-item), group id of
// dim 1 selects i01, and group id of dim 0 encodes i02 + ne02 * i03.
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_nc(const void * __restrict__ vx, dst_t * __restrict__ y,
                                const int64_t ne00, const int64_t ne01, const int64_t ne02,
                                const int64_t s01, const int64_t s02, const int64_t s03) {
    auto item = sycl::ext::oneapi::this_work_item::get_nd_item<3>();

    const int64_t i00 = 2 * (int64_t(item.get_local_range(2)) * item.get_group(2) + item.get_local_id(2));
    if (i00 >= ne00) {
        return; // past the end of the row
    }

    const int64_t i1 = item.get_group(1);
    const int64_t i2 = item.get_group(0) % ne02;
    const int64_t i3 = item.get_group(0) / ne02;

    const int64_t src_block_base = i3*s03 + i2*s02 + i1*s01;

    const int64_t ib       = src_block_base + i00/qk; // block index
    const int64_t iqs      = (i00%qk)/qr;             // quant index
    const int64_t iybs     = i00 - i00%qk;            // y block start index
    const int64_t y_offset = qr == 1 ? 1 : qk/2;      // distance between the two outputs

    // dequantize
#ifdef GGML_SYCL_F16
    sycl::half2 v;
#else
    sycl::float2 v;
#endif
    dequantize_kernel(vx, ib, iqs, v);

    const int64_t dst_first = ((i3*ne02 + i2)*ne01 + i1)*ne00 + iybs + iqs;
    y[dst_first + 0]        = ggml_sycl_cast<dst_t>(v.x());
    y[dst_first + y_offset] = ggml_sycl_cast<dst_t>(v.y());
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_nc_sycl(const void * vx,
dst_t * y,
const int64_t ne00,
const int64_t ne01,
const int64_t ne02,
const int64_t ne03,
const int64_t s01,
const int64_t s02,
const int64_t s03,
dpct::queue_ptr stream) {
const dpct::dim3 num_blocks((ne00 + 2 * SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2 * SYCL_DEQUANTIZE_BLOCK_SIZE), ne01,
ne02 * ne03);
stream->parallel_for(sycl::nd_range<3>(num_blocks * sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
dequantize_block_nc<qk, qr, dequantize_kernel>(vx, y, ne00, ne01, ne02, s01, s02, s03);
});
}
template <typename src_t, typename dst_t>
static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03,
@ -662,7 +719,8 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
}
}
to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type) {
switch (type) {
case GGML_TYPE_F32:
return convert_unary_nc_sycl<float>;
@ -670,6 +728,16 @@ to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
case GGML_TYPE_BF16:
return convert_unary_nc_sycl<sycl::ext::oneapi::bfloat16>;
#endif
case GGML_TYPE_Q4_0:
return dequantize_block_nc_sycl<QK4_0, QR4_0, dequantize_q4_0>;
case GGML_TYPE_Q4_1:
return dequantize_block_nc_sycl<QK4_1, QR4_1, dequantize_q4_1>;
case GGML_TYPE_Q5_0:
return dequantize_block_nc_sycl<QK5_0, QR5_0, dequantize_q5_0>;
case GGML_TYPE_Q5_1:
return dequantize_block_nc_sycl<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0:
return dequantize_block_nc_sycl<QK8_0, QR8_0, dequantize_q8_0>;
default:
return nullptr;
}

View File

@ -29,6 +29,21 @@ using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne0
int64_t s01, int64_t s02, int64_t s03, dpct::queue_ptr queue);
typedef to_t_nc_sycl_t<sycl::half> to_fp16_nc_sycl_t;
to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type);
to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type);
template<typename dst_t, typename src_t>
inline dst_t ggml_sycl_cast(src_t x) {
if constexpr (std::is_same_v<dst_t, src_t>) {
return x;
} else if constexpr (std::is_same_v<dst_t, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::bfloat16(float(x));
} else if constexpr (std::is_same_v<src_t, sycl::ext::oneapi::bfloat16>) {
return static_cast<float>(x);
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
return int32_t(x);
} else {
return float(x);
}
}
#endif // GGML_SYCL_CONVERT_HPP

View File

@ -18,7 +18,7 @@ static void count_equal(const T *__restrict__ x, const T *__restrict__ y,
nequal += xi == yi;
}
nequal = warp_reduce_sum(nequal);
nequal = warp_reduce_sum<WARP_SIZE>(nequal);
if (item_ct1.get_local_id(2) != 0) {
return;

View File

@ -2997,6 +2997,778 @@ namespace dpct
return 0;
}
template <int n_nondefault_params, int n_default_params, typename T>
class args_selector;
/// args_selector is a helper class for extracting arguments from an
/// array of pointers to arguments or buffer of arguments to pass to a
/// kernel function.
///
/// \param R(Ts...) The type of the kernel
/// \param n_nondefault_params The number of nondefault parameters of the
/// kernel (excluding parameters like sycl::nd_item, etc.)
/// \param n_default_params The number of default parameters of the kernel
///
/// Example usage:
/// With the following kernel:
/// void foo(sycl::float2 *x, int n, sycl::nd_item<3> item_ct1, float
/// f=.1) {}
/// and with the declaration:
/// args_selector<2, 1, decltype(foo)> selector(kernelParams, extra);
/// we have:
/// selector.get<0>() returns a reference to sycl::float2*,
/// selector.get<1>() returns a reference to int,
/// selector.get<2>() returns a reference to float
template <int n_nondefault_params, int n_default_params, typename R,
typename... Ts>
class args_selector<n_nondefault_params, n_default_params, R(Ts...)> {
private:
void **kernel_params;
char *args_buffer;
/// Maps argument index i to its position in Ts...: indices below
/// n_nondefault_params map to themselves, the rest map onto the trailing
/// n_default_params entries of the parameter list.
template <int i> static constexpr int account_for_default_params() {
constexpr int n_total_params = sizeof...(Ts);
if constexpr (i >= n_nondefault_params) {
return n_total_params - n_default_params +
(i - n_nondefault_params);
} else {
return i;
}
}
public:
/// Get the type of the ith argument of R(Ts...)
/// \param [in] i Index of parameter to get
/// \returns Type of ith parameter
template <int i>
using arg_type = std::tuple_element_t<account_for_default_params<i>(),
std::tuple<Ts...>>;
static constexpr int params_num = sizeof...(Ts);
private:
/// Byte offset of the ith argument inside args_buffer, computed by laying
/// the arguments out one after another with each one rounded up to its
/// own alignment.
template <int i> static constexpr int get_offset() {
if constexpr (i == 0) {
// we can assume args_buffer is properly aligned to the
// first argument
return 0;
} else {
constexpr int prev_off = get_offset<i - 1>();
constexpr int prev_past_end =
prev_off + sizeof(arg_type<i - 1>);
using T = arg_type<i>;
// is the past-the-end of the i-1st element properly aligned
// with the ith element's alignment?
if constexpr (prev_past_end % alignof(T) == 0) {
return prev_past_end;
}
// otherwise bump prev_past_end to match alignment
else {
return prev_past_end +
(alignof(T) - (prev_past_end % alignof(T)));
}
}
}
/// Scans `extra` up to its terminating 0 entry and returns the pointer
/// stored right after the first entry whose value is 1; returns nullptr
/// when `extra` is null or no such entry exists.
static char *get_args_buffer(void **extra) {
if (!extra)
return nullptr;
for (; (std::size_t)*extra != 0; ++extra) {
if ((std::size_t)*extra == 1) {
return static_cast<char *>(*(extra + 1));
}
}
return nullptr;
}
public:
/// If kernel_params is nonnull, then args_selector will
/// extract arguments from kernel_params. Otherwise, it
/// will extract them from extra.
/// \param [in] kernel_params Array of pointers to arguments
/// or a null pointer.
/// \param [in] extra Array containing pointer to argument buffer.
args_selector(void **kernel_params, void **extra)
: kernel_params(kernel_params),
args_buffer(get_args_buffer(extra)) {}
/// Get a reference to the ith argument extracted from kernel_params
/// or extra.
/// \param [in] i Index of argument to get
/// \returns Reference to the ith argument
template <int i> arg_type<i> &get() {
if (kernel_params) {
return *static_cast<arg_type<i> *>(kernel_params[i]);
} else {
return *reinterpret_cast<arg_type<i> *>(args_buffer +
get_offset<i>());
}
}
}; // Copied from the DPCT header file
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
/// Utility class for launching SYCL kernels through kernel
/// function wrapper.
/// For example:
/// A SYCL kernel function:
/// void kernel_func(int *ptr, sycl::nd_item<3> item);
/// Kernel function wrapper:
/// void kernel_func_wrapper(int *ptr) {
/// sycl::queue queue = *dpct::kernel_launcher::_que;
/// unsigned int localMemSize = dpct::kernel_launcher::_local_mem_size;
/// sycl::nd_range<3> nr = dpct::kernel_launcher::_nr;
/// queue.parallel_for(
/// nr,
/// [=](sycl::nd_item<3> item_ct1) {
/// kernel_func(ptr, item_ct1);
/// });
/// }
/// Then launch the kernel through wrapper like:
/// typedef void(*fpt)(int *);
/// fpt fp = kernel_func_wrapper;
/// dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), 0, 0,
/// device_ptr);
/// If the origin function type is erased, then need to register it first:
/// void *fp = (void *)wrapper_register(&kernel_func_wrapper).get();
/// dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), args,
/// 0, 0);
class kernel_launcher {
/// Calls func with the arguments selector.get<0>(), selector.get<1>(), ...
/// for every index in the given index sequence.
template <typename FuncT, typename ArgSelector, std::size_t... Index>
static void launch_helper(FuncT &&func, ArgSelector &selector,
std::index_sequence<Index...>) {
func(selector.template get<Index>()...);
}
/// Stores the launch configuration in the thread-local members below so
/// that the kernel function wrapper can read it. A null que selects the
/// default queue; the nd_range is built from group_range * local_range
/// (global) and local_range (local).
static void set_execution_config(dim3 group_range, dim3 local_range,
unsigned int local_mem_size,
queue_ptr que) {
if (que) {
_que = que;
} else {
_que = &get_default_queue();
}
_nr = sycl::nd_range<3>(
static_cast<sycl::range<3>>(group_range * local_range),
static_cast<sycl::range<3>>(local_range));
_local_mem_size = local_mem_size;
};
/// Guards kernel_function_ptr_map.
static inline std::mutex kernel_function_ptr_map_mutex;
public:
/// Variables for storing execution configuration.
static inline thread_local sycl::queue *_que = nullptr;
static inline thread_local sycl::nd_range<3> _nr = sycl::nd_range<3>();
static inline thread_local unsigned int _local_mem_size = 0;
/// Map for retrieving launchable functor from a raw pointer.
static inline std::map<
const void *,
std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>>
kernel_function_ptr_map = {};
/// Registers a kernel function pointer with a corresponding launchable
/// functor.
/// \param [in] func Pointer to the kernel function.
/// \param [in] launcher Functor to handle kernel invocation.
static void register_kernel_ptr(
const void *func,
std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>
launcher) {
std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex);
kernel_function_ptr_map[func] = std::move(launcher);
}
/// Launches a kernel function with arguments provided directly through
/// kernel function wrapper.
/// \tparam FuncT Type of the kernel function wrapper.
/// \tparam ArgsT Types of kernel arguments.
/// \param [in] func Pointer to the kernel function wrapper.
/// \param [in] group_range SYCL group range.
/// \param [in] local_range SYCL local range.
/// \param [in] local_mem_size The size of local memory required by the
/// kernel function.
/// \param [in] que SYCL queue used to execute kernel.
/// \param [in] args Kernel arguments.
template <typename FuncT, typename... ArgsT>
static std::enable_if_t<std::is_invocable_v<FuncT *, ArgsT...>, void>
launch(FuncT *func, dim3 group_range, dim3 local_range,
unsigned int local_mem_size, queue_ptr que, ArgsT... args) {
set_execution_config(group_range, local_range, local_mem_size, que);
func(args...);
}
/// Launches a kernel function through registered kernel function
/// wrapper.
/// \param [in] func Pointer to the registered kernel function wrapper.
/// \param [in] group_range SYCL group range.
/// \param [in] local_range SYCL local range.
/// \param [in] args Array of pointers to kernel arguments.
/// \param [in] local_mem_size The size of local memory required by the
/// kernel function.
/// \param [in] que SYCL queue used to execute kernel.
/// \throws std::runtime_error if func has not been registered.
static void launch(const void *func, dim3 group_range, dim3 local_range,
void **args, unsigned int local_mem_size,
queue_ptr que) {
// The map mutex stays locked for the whole call, including the
// invocation of the registered functor below.
std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex);
auto Iter = kernel_function_ptr_map.find(func);
if (Iter == kernel_function_ptr_map.end()) {
throw std::runtime_error("dpct::launch() : no registered "
"kernel function wrapper found.");
}
(Iter->second)(group_range, local_range, args, local_mem_size, que);
}
/// Launches a kernel function with packed arguments through kernel
/// function wrapper.
/// \tparam FuncT Type of the kernel function wrapper.
/// \param [in] func Pointer to the kernel function wrapper.
/// \param [in] group_range SYCL group range.
/// \param [in] local_range SYCL local range.
/// \param [in] args Array of pointers to kernel arguments.
/// \param [in] local_mem_size The size of local memory required by the
/// kernel function.
/// \param [in] que SYCL queue used to execute kernel.
template <typename FuncT>
static std::enable_if_t<std::is_function_v<FuncT>, void>
launch(FuncT *func, dim3 group_range, dim3 local_range, void **args,
unsigned int local_mem_size, queue_ptr que) {
// Number of parameters in FuncT's signature.
constexpr size_t p_num = args_selector<0, 0, FuncT>::params_num;
set_execution_config(group_range, local_range, local_mem_size, que);
// Arguments are read from the pointer array `args`; no packed buffer
// is supplied (extra == nullptr).
args_selector<p_num, p_num, FuncT> selector(args, nullptr);
launch_helper(func, selector, std::make_index_sequence<p_num>{});
}
}; // Copied from the DPCT header file
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/kernel.hpp
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
// Reads `x` from another lane of the sub-group `g`.
// The sub-group is treated as consecutive segments of logical_sub_group_size
// lanes; the source lane is remote_local_id (taken modulo the segment size)
// counted from the start of the caller's own segment.
template <typename T>
T select_from_sub_group(
    sycl::sub_group g,
    T x,
    int remote_local_id,
    int logical_sub_group_size = 32) {
    // First lane of the logical segment that contains the calling lane.
    unsigned int segment_base = g.get_local_linear_id() /
                                logical_sub_group_size *
                                logical_sub_group_size;
    const auto source_lane =
        segment_base + remote_local_id % logical_sub_group_size;
    return sycl::select_from_group(g, x, source_lane);
}
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp
template <typename T>
void ldmatrix(uintptr_t addr, T* m, bool trans = false, unsigned mat = 0) {
auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
int lane = sg.get_local_linear_id();
int lane_group8_row = lane / 8;
int lane_group8_col = lane % 8;
if (!trans) {
// calculate the source lane
int src_lane = 2 * lane_group8_row;
if (lane_group8_col >= 4)
src_lane += 1;
// Broadcast the address from the source lane
auto recv_addr_uintp =
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
// Cast the received address from uintptr_t to the type of 'm'
auto recv_addr = reinterpret_cast<T*>(recv_addr_uintp);
// Non-transposed load
*m = recv_addr[lane_group8_col % 4];
} else {
// calculate the source lane
int src_lane = (lane % 4) * 2;
// Broadcast the address from the source lane
auto recv_addr_uintp_1 =
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
auto recv_addr_uintp_2 =
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane + 1);
// Cast the received address from uintptr_t to 'half *'
auto recv_addr_1 = reinterpret_cast<sycl::half*>(recv_addr_uintp_1);
auto recv_addr_2 = reinterpret_cast<sycl::half*>(recv_addr_uintp_2);
// Transposed load
int index = lane / 4;
sycl::half val0 = recv_addr_1[index];
sycl::half val1 = recv_addr_2[index];
// Combine the two 16-bits into one 32-bit value
sycl::half2 val = sycl::half2(val0, val1);
*m = *reinterpret_cast<T*>(&val);
}
}
/// Loads two matrix fragments by calling the single-matrix ldmatrix() with
/// matrix indices 0 and 1, in that order.
/// \param [in] addr Address forwarded unchanged to every single-matrix load.
/// \param [out] m1 Destination of the fragment of matrix 0.
/// \param [out] m2 Destination of the fragment of matrix 1.
/// \param [in] trans Forwarded unchanged to every single-matrix load.
template <typename T>
void ldmatrix(uintptr_t addr, T* m1, T* m2, bool trans = false) {
  T* const fragments[2] = {m1, m2};
  for (unsigned mat = 0; mat < 2; ++mat) {
    ldmatrix(addr, fragments[mat], trans, mat);
  }
}
/// Loads four matrix fragments by calling the single-matrix ldmatrix() with
/// matrix indices 0 to 3, in that order.
/// \param [in] addr Address forwarded unchanged to every single-matrix load.
/// \param [out] m1 Destination of the fragment of matrix 0.
/// \param [out] m2 Destination of the fragment of matrix 1.
/// \param [out] m3 Destination of the fragment of matrix 2.
/// \param [out] m4 Destination of the fragment of matrix 3.
/// \param [in] trans Forwarded unchanged to every single-matrix load.
template <typename T>
void ldmatrix(
    uintptr_t addr, T* m1, T* m2, T* m3, T* m4, bool trans = false) {
  T* const fragments[4] = {m1, m2, m3, m4};
  for (unsigned mat = 0; mat < 4; ++mat) {
    ldmatrix(addr, fragments[mat], trans, mat);
  }
}
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp
/// Trait that names the pack type used to carry the input matrix fragments
/// of the mma() function, selected by the element type of those fragments.
/// The primary template below maps every element type to a 32-bit unsigned
/// pack; mma() reinterprets its A and B fragment pointers as pointers to
/// this pack type.
/// \tparam [in] T The type of the input matrix fragments
template <typename T> struct MMAType {
  using PackType = uint32_t;
};
/// Each work item of a sub-group (limited to size 32) calling this function
/// calculates a subset fragment for the output matrix D using MAD operation
/// on A, B & C matrix fragments (D = A * B + C). Current supported shapes &
/// types:
/// - m8n8k4 (f32.f16.f16.f32)
/// - m8n8k16 (s32.s8.s8.s32)
/// - m16n8k8 (f32.f16.f16.f32 & f32.bf16.bf16.f32)
/// - m16n8k16 (f32.f16.f16.f32 & s32.s8.s8.s32)
/// - m16n8k32 (s32.s8.s8.s32)
/// Here, m, n & k define the shapes of A, B & C matrices respectively
/// (A = [m x k], B = [k x n], C = [m x n]).
/// A shape that passes the static_assert but whose element types match none
/// of the type checks inside its branch compiles to a no-op: D is not
/// written.
/// \tparam [in] M The rows of A, C & D matrices
/// \tparam [in] N The columns of B, C, D matrices
/// \tparam [in] K The columns & rows of A & B matrices respectively
/// \tparam [in] ABType The type of the input matrix (A & B) fragment
/// \tparam [in] CDType The type of the output matrix (C & D) fragment
/// \param [out] d_mat_frag Array of pointers to the elements of the output
/// matrix D fragment that receive the result of A * B + C
/// \param [in] a_mat_frag The fragment of the input matrix A to be
/// multiplied with B matrix fragment
/// \param [in] b_mat_frag The fragment of the input matrix B to be
/// multiplied with A matrix fragment
/// \param [in] c_mat_frag The fragment of the input matrix C to be added
/// with the result of A * B fragments
template <int M, int N, int K, typename ABType, typename CDType>
void mma(
    volatile void** d_mat_frag,
    void* a_mat_frag,
    void* b_mat_frag,
    void* c_mat_frag) {
  auto d = reinterpret_cast<volatile CDType**>(d_mat_frag);
  auto a =
      reinterpret_cast<typename MMAType<ABType>::PackType*>(a_mat_frag);
  auto b =
      reinterpret_cast<typename MMAType<ABType>::PackType*>(b_mat_frag);
  auto c = reinterpret_cast<CDType*>(c_mat_frag);

  auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
  int lane = sg.get_local_linear_id();

  static_assert(
      (M == 8 && N == 8 && K == 4) || (M == 8 && N == 8 && K == 16) ||
          (M == 16 && N == 8 && K == 8) || (M == 16 && N == 8 && K == 16) ||
          (M == 16 && N == 8 && K == 32),
      "Unsupported MMA shape!");

  // First lane of the group of 4 lanes that holds this work item's row / col
  // fragments; the loops below add i in [0, 4) to walk that group.
  short row_load_offset = 4 * (lane >> 2);
  short col_load_offset = 8 * (lane % 4);

  if constexpr (M == 8 && N == 8 && K == 4) {
    if constexpr (std::is_floating_point_v<CDType>) {
      col_load_offset = row_load_offset % 16;

      // Init D matrix with fragments of C matrix
      *d[0] = c[0];
      *d[1] = c[1];
      *d[2] = c[2];
      *d[3] = c[3];
      *d[4] = c[4];
      *d[5] = c[5];
      *d[6] = c[6];
      *d[7] = c[7];

      // Calculate the row and col offset indices to iterate through the row
      // & col fragments of A & B matrices
      int r_ind = (lane % 2) ? 1 : 0;
      int c_ind = ((lane % 4) / 2) ? 2 : 0;

      // Each sub-group is responsible for computing a fragment size of 8*8
      // elements of matrix D for each of 4 MMA computations.
      // Each work item computes 8 elements of matrix D by gathering
      // their corresponding col & row matrix fragments of length k (4)
      // from A & B matrices respectively using below mapping logic:
      // row0 = (i % 4) if (lane < 16) else (i % 4) + 4
      // col0 = (lane % 4)
      // As each row & col fragment of A & B matrices is distributed across
      // 4 work items, each iteration of below loop loads a partial fragment
      // of matrix A (row) and matrix B (col) using the row & col offsets.
      typename MMAType<ABType>::PackType recv_a[2], recv_b[2];

      for (int i = 0; i < 4; i++) {
        // Load partial fragment from col0 of matrix A ({a0, a1})
        recv_a[0] =
            dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
        // Load partial fragment from col0 of matrix A ({a2, a3})
        recv_a[1] =
            dpct::select_from_sub_group(sg, a[1], row_load_offset + i);

        // Load partial fragment from row0 of matrix B ({b0, b1})
        recv_b[0] =
            dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
        // Load partial fragment from row0 of matrix B ({b2, b3})
        recv_b[1] =
            dpct::select_from_sub_group(sg, b[1], col_load_offset + i);

        auto ra = reinterpret_cast<ABType*>(recv_a);
        auto rb = reinterpret_cast<ABType*>(recv_b);

        // Each work item calculates a partial product of A & B matrix
        // fragments and adds it to the corresponding D matrix fragment
        // (for even work item indices)
        // d0 += col0{ a0 } * row0{ b0 }
        // d1 += col0{ a0 } * row0{ b1 }
        // d2 += col1{ a2 } * row0{ b0 }
        // d3 += col1{ a2 } * row0{ b1 }
        // (for odd work item indices)
        // d0 += col0{ a1 } * row0{ b2 }
        // d1 += col0{ a1 } * row0{ b3 }
        // d2 += col1{ a3 } * row0{ b2 }
        // d3 += col1{ a3 } * row0{ b3 }
        *d[0] +=
            static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]);
        *d[1] += static_cast<float>(ra[r_ind]) *
            static_cast<float>(rb[c_ind + 1]);
        *d[2] += static_cast<float>(ra[r_ind + 2]) *
            static_cast<float>(rb[c_ind]);
        *d[3] += static_cast<float>(ra[r_ind + 2]) *
            static_cast<float>(rb[c_ind + 1]);

        // Load partial fragment from row1 of matrix B ({b0, b1})
        recv_b[0] =
            dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 16);
        // Load partial fragment from row1 of matrix B ({b2, b3})
        recv_b[1] =
            dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 16);

        // (for even work item indices)
        // d0 += col0{ a0 } * row1{ b0 }
        // d1 += col0{ a0 } * row1{ b1 }
        // d2 += col1{ a2 } * row1{ b0 }
        // d3 += col1{ a2 } * row1{ b1 }
        // (for odd work item indices)
        // d0 += col0{ a1 } * row1{ b2 }
        // d1 += col0{ a1 } * row1{ b3 }
        // d2 += col1{ a3 } * row1{ b2 }
        // d3 += col1{ a3 } * row1{ b3 }
        *d[4] +=
            static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]);
        *d[5] += static_cast<float>(ra[r_ind]) *
            static_cast<float>(rb[c_ind + 1]);
        *d[6] += static_cast<float>(ra[r_ind + 2]) *
            static_cast<float>(rb[c_ind]);
        *d[7] += static_cast<float>(ra[r_ind + 2]) *
            static_cast<float>(rb[c_ind + 1]);
      }
    }
  } else if constexpr (M == 8 && N == 8 && K == 16) {
    if constexpr (std::is_integral_v<ABType>) {
      // Init D matrix with fragments of C matrix
      *d[0] = c[0];
      *d[1] = c[1];

      // Each sub-group is responsible for computing a fragment size of 8*8
      // elements of matrix D.
      // Each work item computes 2 elements of matrix D by gathering
      // their corresponding row & col matrix fragments of length k (16)
      // from A & B matrices respectively using below mapping logic:
      // row0 = ((lane % 4) * 4) + i
      // col0 = (lane >> 2)
      // As each row & col fragment of A & B matrices is distributed across
      // 4 work items, each iteration of below loop loads a partial fragment
      // of matrix A (row) and matrix B (col) using the row & col offsets.
      for (int i = 0; i < 4; i++) {
        typename MMAType<ABType>::PackType recv_a, recv_b[2];

        // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
        recv_a = dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
        // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
        recv_b[0] =
            dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
        // Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})
        recv_b[1] =
            dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);

        // Element views of the received packs; named ra / rb so they do not
        // shadow the fragment pointers a / b used by the loads above.
        auto ra = reinterpret_cast<ABType*>(&recv_a);
        auto rb = reinterpret_cast<ABType*>(recv_b);

        // Each work item calculates a partial product of A & B matrix
        // fragments and adds it to the corresponding D matrix fragment
        // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
        // d1 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
        for (int j = 0; j < 4; j++) {
          *d[0] += ra[j] * rb[j];
          *d[1] += ra[j] * rb[j + 4];
        }
      }
    }
  } else if constexpr (M == 16 && N == 8 && K == 8) {
    if constexpr (std::is_floating_point_v<CDType>) {
      // Init D matrix fragment with C matrix fragment
      *d[0] = c[0];
      *d[1] = c[1];
      *d[2] = c[2];
      *d[3] = c[3];

      // Each sub-group is responsible for computing a fragment size of 16*8
      // elements of matrix D.
      // Each work item computes 4 elements of matrix D by gathering
      // their corresponding row & col matrix fragments of length k (8)
      // from A & B matrices respectively using below mapping logic:
      // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
      // col0 = (lane % 4) * 2 + (i & 0x1)
      // As each row & col fragment of A & B matrices is distributed across
      // 4 work items, each iteration of below loop loads a partial fragment
      // of matrix A (row) and matrix B (col) using the row & col offsets.
      for (int i = 0; i < 4; i++) {
        typename MMAType<ABType>::PackType recv_a[2], recv_b[2];

        // Load partial fragment from row0 of matrix A ({a0, a1})
        recv_a[0] =
            dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
        // Load partial fragment from row1 of matrix A ({a2, a3})
        recv_a[1] =
            dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
        // Load partial fragment from col0 of matrix B ({b0, b1})
        recv_b[0] =
            dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
        // Load partial fragment from col1 of matrix B ({b0, b1})
        recv_b[1] =
            dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);

        auto ra = reinterpret_cast<ABType*>(recv_a);
        auto rb = reinterpret_cast<ABType*>(recv_b);

        // Each work item calculates a partial product of A & B matrix
        // fragments and adds it to the corresponding D matrix fragment
        // d0 += row0{ a0, a1 } * col0{ b0, b1 }
        // d1 += row0{ a0, a1 } * col1{ b0, b1 }
        // d2 += row1{ a2, a3 } * col0{ b0, b1 }
        // d3 += row1{ a2, a3 } * col1{ b0, b1 }
        for (int j = 0; j < 2; j++) {
          *d[0] += static_cast<float>(ra[j]) * static_cast<float>(rb[j]);
          *d[1] +=
              static_cast<float>(ra[j]) * static_cast<float>(rb[j + 2]);
          *d[2] +=
              static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j]);
          *d[3] +=
              static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j + 2]);
        }
      }
    }
  } else if constexpr (M == 16 && N == 8 && K == 16) {
    if constexpr (std::is_floating_point_v<CDType>) {
      // Init D matrix fragment with C matrix fragment
      *d[0] = c[0];
      *d[1] = c[1];
      *d[2] = c[2];
      *d[3] = c[3];

      // Each sub-group is responsible for computing a fragment size of 16*8
      // elements of matrix D.
      // Each work item computes 4 elements of matrix D by gathering
      // their corresponding row & col matrix fragments of length k (16)
      // from A & B matrices respectively using below mapping logic:
      // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
      // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
      // As each row & col fragment of A & B matrices is distributed across
      // 4 work items, each iteration of below loop loads a partial fragment
      // of matrix A (row) and matrix B (col) using the row & col offsets.
      for (int i = 0; i < 4; i++) {
        typename MMAType<ABType>::PackType recv_a[4], recv_b[4];

        // Load partial fragment from row0 of matrix A ({a0, a1})
        recv_a[0] =
            dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
        // Load partial fragment from row0 of matrix A ({a2, a3})
        recv_a[1] =
            dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
        // Load partial fragment from row1 of matrix A ({a0, a1})
        recv_a[2] =
            dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
        // Load partial fragment from row1 of matrix A ({a2, a3})
        recv_a[3] =
            dpct::select_from_sub_group(sg, a[3], row_load_offset + i);

        // Load partial fragment from col0 of matrix B ({b0, b1})
        recv_b[0] =
            dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
        // Load partial fragment from col0 of matrix B ({b2, b3})
        recv_b[1] =
            dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
        // Load partial fragment from col1 of matrix B ({b0, b1})
        recv_b[2] =
            dpct::select_from_sub_group(sg, b[0], col_load_offset + 4 + i);
        // Load partial fragment from col1 of matrix B ({b2, b3})
        recv_b[3] =
            dpct::select_from_sub_group(sg, b[1], col_load_offset + 4 + i);

        auto ra = reinterpret_cast<ABType*>(recv_a);
        auto rb = reinterpret_cast<ABType*>(recv_b);

        // Each work item calculates a partial product of A & B matrix
        // fragments and adds it to the corresponding D matrix fragment
        // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
        // d1 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
        // d2 += row1{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
        // d3 += row1{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
        for (int j = 0; j < 4; j++) {
          *d[0] += static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]);
          *d[1] +=
              static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j + 4]);
          *d[2] +=
              static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j]);
          *d[3] += static_cast<CDType>(ra[j + 4]) *
              static_cast<CDType>(rb[j + 4]);
        }
      }
    } else if constexpr (std::is_integral_v<ABType>) {
      // Init D matrix with fragments of C matrix
      *d[0] = c[0];
      *d[1] = c[1];
      *d[2] = c[2];
      *d[3] = c[3];

      // Each sub-group is responsible for computing a fragment size of 16*8
      // elements of matrix D.
      // Each work item computes 4 elements of matrix D by gathering
      // their corresponding row & col matrix fragments of length k (16)
      // from A & B matrices respectively using below mapping logic:
      // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
      // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
      // As each row & col fragment of A & B matrices is distributed across
      // 4 work items, each iteration of below loop loads a partial fragment
      // of matrix A (row) and matrix B (col) using the row & col offsets.
      for (int i = 0; i < 4; i++) {
        typename MMAType<ABType>::PackType recv_a[2], recv_b[2];

        // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
        recv_a[0] =
            dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
        // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})
        recv_a[1] =
            dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
        // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
        recv_b[0] =
            dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
        // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})
        recv_b[1] =
            dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);

        auto ra = reinterpret_cast<ABType*>(recv_a);
        auto rb = reinterpret_cast<ABType*>(recv_b);

        // Each work item calculates a partial product of A & B matrix
        // fragments and adds it to the corresponding D matrix fragment
        // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
        // d1 += row0{ a0, a1, a2, a3 } * col1{ b4, b5, b6, b7 }
        // d2 += row1{ a4, a5, a6, a7 } * col0{ b0, b1, b2, b3 }
        // d3 += row1{ a4, a5, a6, a7 } * col1{ b4, b5, b6, b7 }
        // The element index is j so that it does not shadow the outer lane
        // index i.
        for (int j = 0; j < 4; j++) {
          *d[0] += ra[j] * rb[j];
          *d[1] += ra[j] * rb[j + 4];
          *d[2] += ra[j + 4] * rb[j];
          *d[3] += ra[j + 4] * rb[j + 4];
        }
      }
    }
  } else if constexpr (M == 16 && N == 8 && K == 32) {
    if constexpr (std::is_integral_v<ABType>) {
      // Init D matrix with fragments of C matrix
      *d[0] = c[0];
      *d[1] = c[1];
      *d[2] = c[2];
      *d[3] = c[3];

      // Each sub-group is responsible for computing a fragment size of 16*8
      // elements of matrix D.
      // Each work item computes 4 elements of matrix D by gathering
      // their corresponding row & col matrix fragments of length k (32)
      // from A & B matrices respectively using below mapping logic:
      // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
      // col0 = ((lane % 4) * 4) + (i & 0x3) &
      // col1 = ((lane % 4) * 4) + (i & 0x3)
      // As each row & col fragment of A & B matrices is distributed across
      // 4 work items, each iteration of below loop loads a partial fragment
      // of matrix A (row) and matrix B (col) using the row & col offsets.
      for (int i = 0; i < 4; i++) {
        typename MMAType<ABType>::PackType recv_a[2], recv_b[2];

        // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
        recv_a[0] =
            dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
        // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})
        recv_a[1] =
            dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
        // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
        recv_b[0] =
            dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
        // Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})
        recv_b[1] =
            dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);

        // Element views of the received packs; named ra / rb so they do not
        // shadow the fragment pointers a / b used by the loads above.
        auto ra = reinterpret_cast<ABType*>(recv_a);
        auto rb = reinterpret_cast<ABType*>(recv_b);

        // Each work item calculates a partial product of A & B matrix
        // fragments and adds it to the corresponding D matrix fragment
        // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
        // d1 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
        // d2 += row1{ a4, a5, a6, a7 } * col0{ b0, b1, b2, b3 }
        // d3 += row1{ a4, a5, a6, a7 } * col1{ b0, b1, b2, b3 }
        for (int j = 0; j < 4; j++) {
          *d[0] += ra[j] * rb[j];
          *d[1] += ra[j] * rb[j + 4];
          *d[2] += ra[j + 4] * rb[j];
          *d[3] += ra[j + 4] * rb[j + 4];
        }
      }

      for (int i = 0; i < 4; i++) {
        typename MMAType<ABType>::PackType recv_a[2], recv_b[2];

        // Load partial fragment from row0 of matrix A ({a8, a9, a10, a11})
        recv_a[0] =
            dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
        // Load partial fragment from row1 of matrix A ({a12, a13, a14,
        // a15})
        recv_a[1] =
            dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
        // Load partial fragment from col0 of matrix B ({b4, b5, b6, b7})
        recv_b[0] =
            dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
        // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})
        recv_b[1] =
            dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 4);

        auto ra = reinterpret_cast<ABType*>(recv_a);
        auto rb = reinterpret_cast<ABType*>(recv_b);

        // Each work item calculates a partial product of A & B matrix
        // fragments and adds it to the corresponding D matrix fragment
        // d0 += row0{ a8, a9, a10, a11 } * col0{ b4, b5, b6, b7 }
        // d1 += row0{ a8, a9, a10, a11 } * col1{ b4, b5, b6, b7 }
        // d2 += row1{ a12, a13, a14, a15 } * col0{ b4, b5, b6, b7 }
        // d3 += row1{ a12, a13, a14, a15 } * col1{ b4, b5, b6, b7 }
        for (int j = 0; j < 4; j++) {
          *d[0] += ra[j] * rb[j];
          *d[1] += ra[j] * rb[j + 4];
          *d[2] += ra[j + 4] * rb[j];
          *d[3] += ra[j + 4] * rb[j + 4];
        }
      }
    }
  }
}
} // COPY from DPCT head files
#endif // GGML_SYCL_DPCT_HELPER_HPP

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,55 @@
#include <sycl/sycl.hpp>
#include <sycl/ext/oneapi/work_group_static.hpp>
#include "dpct/helper.hpp"
#include "common.hpp"
#include "fattn-common.hpp"
#include "fattn-tile.hpp"
#include <cmath>
#include <float.h>
namespace syclex = sycl::ext::oneapi::experimental;
// Dispatches the tile flash-attention kernel on the head size of K
// (K->ne[0], with K = dst->src[1] and V = dst->src[2]).
// Every supported K head size requires V->ne[0] to match it, except 576,
// which requires V->ne[0] == 512. Any other K head size aborts.
void ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];

    const auto head_size_k = K->ne[0];

    if (head_size_k == 40) {
        GGML_ASSERT(V->ne[0] == K->ne[0]);
        ggml_sycl_flash_attn_ext_tile_case< 40,  40>(ctx, dst);
    } else if (head_size_k == 64) {
        GGML_ASSERT(V->ne[0] == K->ne[0]);
        ggml_sycl_flash_attn_ext_tile_case< 64,  64>(ctx, dst);
    } else if (head_size_k == 72) {
        GGML_ASSERT(V->ne[0] == K->ne[0]);
        ggml_sycl_flash_attn_ext_tile_case< 72,  72>(ctx, dst);
    } else if (head_size_k == 80) {
        GGML_ASSERT(V->ne[0] == K->ne[0]);
        ggml_sycl_flash_attn_ext_tile_case< 80,  80>(ctx, dst);
    } else if (head_size_k == 96) {
        GGML_ASSERT(V->ne[0] == K->ne[0]);
        ggml_sycl_flash_attn_ext_tile_case< 96,  96>(ctx, dst);
    } else if (head_size_k == 112) {
        GGML_ASSERT(V->ne[0] == K->ne[0]);
        ggml_sycl_flash_attn_ext_tile_case<112, 112>(ctx, dst);
    } else if (head_size_k == 128) {
        GGML_ASSERT(V->ne[0] == K->ne[0]);
        ggml_sycl_flash_attn_ext_tile_case<128, 128>(ctx, dst);
    } else if (head_size_k == 256) {
        GGML_ASSERT(V->ne[0] == K->ne[0]);
        ggml_sycl_flash_attn_ext_tile_case<256, 256>(ctx, dst);
    } else if (head_size_k == 576) {
        // The only shape where the K and V head sizes differ.
        GGML_ASSERT(V->ne[0] == 512);
        ggml_sycl_flash_attn_ext_tile_case<576, 512>(ctx, dst);
    } else {
        GGML_ABORT("Unsupported head size");
    }
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,667 @@
#ifndef GGML_SYCL_FATTN_VEC_HPP
#define GGML_SYCL_FATTN_VEC_HPP
#include <sycl/sycl.hpp>
#include <sycl/ext/oneapi/work_group_static.hpp>
#include <iostream>
#include <iomanip>
#include "dpct/helper.hpp"
#include "common.hpp"
#include "ggml.h"
#include "fattn-common.hpp"
#include <cmath>
#include <float.h>
namespace syclex = sycl::ext::oneapi::experimental;
// Host-side thread count query for the vec flash-attention kernel.
// Always returns 128; the cc argument is accepted but not consulted.
// NOTE(review): presumably has to stay equal to
// ggml_sycl_fattn_vec_get_nthreads_device() - confirm against the launch code.
static int ggml_sycl_fattn_vec_get_nthreads_host(const int cc) {
    GGML_UNUSED(cc);
    return 128;
}
// Compile-time thread count of the vec flash-attention kernel; read by
// flash_attn_ext_vec as its `nthreads` constant.
static constexpr int ggml_sycl_fattn_vec_get_nthreads_device() {
    constexpr int nthreads_per_block = 128;
    return nthreads_per_block;
}
// Currently LLVM with the amdgcn target does not support unrolling loops
// that contain a break that cannot be resolved at compile time.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template <int D,
int ncols,
int type_K,
int type_V,
bool use_logit_softcap,
int warp_size> // D == head size
static void flash_attn_ext_vec(const char* __restrict__ Q,
const char* __restrict__ K,
const char* __restrict__ V,
const char* __restrict__ mask,
const char* __restrict__ sinks,
const int* __restrict__ KV_max,
float* __restrict__ dst,
sycl::float2* __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int32_t ne00,
const sycl::uint3 ne01,
const int32_t ne02,
const int32_t ne03,
const int32_t nb01,
const int32_t nb02,
const int32_t nb03,
const int32_t ne10,
const int32_t ne11,
const int32_t ne12,
const int32_t ne13,
const int32_t nb11,
const int32_t nb12,
const int64_t nb13,
const int32_t nb21,
const int32_t nb22,
const int64_t nb23,
const int32_t ne31,
const int32_t ne32,
const int32_t ne33,
const int32_t nb31,
const int32_t nb32,
const int64_t nb33) {
#ifdef SYCL_FLASH_ATTN
// Skip unused kernel variants for faster compilation:
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
if (use_logit_softcap && !(D == 128 || D == 256)) {
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne00, ne01, ne02, ne03,
nb01, nb02, nb03,
ne10, ne11, ne12, ne13,
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
nb31, nb32, nb33);
return;
}
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
constexpr int nthreads_KQ_q = (D/4 < warp_size ? D/4 : warp_size);
constexpr int nthreads_V_q = (D/4 < warp_size ? D/4 : warp_size);
constexpr int nthreads = ggml_sycl_fattn_vec_get_nthreads_device();
constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q;
constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q;
static_assert(warp_size % nthreads_KQ == 0, "bad nthreads_K");
static_assert(warp_size % nthreads_V == 0, "bad nthreads_V");
constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4;
constexpr int V_cols_per_iter = warp_size / nthreads_V;
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ, warp_size>();
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
#ifdef GGML_SYCL_F16
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, sycl::half, V_rows_per_thread>();
#else
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
#endif // GGML_SYCL_F16
const int ic0 = item_ct1.get_group(2) * ncols; // Index of the Q/QKV column to work on.
const int sequence = item_ct1.get_group(0) / ne02;
const int head = item_ct1.get_group(0) - sequence * ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
Q += nb03*sequence + nb02* head + nb01*ic0;
K += nb13*sequence + nb12*(head / gqa_ratio);
V += nb23*sequence + nb22*(head / gqa_ratio);
const sycl::half * maskh = (const sycl::half *) (mask + nb33 * (sequence % ne33) + nb31 * ic0);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
static_assert(D % (2*warp_size) == 0, "D not divisible by 2*warp_size == 64.");
constexpr int nwarps = nthreads / warp_size;
const int tid = warp_size * item_ct1.get_local_id(1) + item_ct1.get_local_id(2);
__builtin_assume(tid < nthreads);
constexpr int ne_KQ = ncols*D;
constexpr int ne_combine = nwarps*V_cols_per_iter*D;
constexpr size_t lsm_size1 = ncols * warp_size;
constexpr size_t lsm_size2 = ncols * warp_size;
#ifdef GGML_SYCL_F16
sycl::half2 VKQ[ncols][(D / 2) / nthreads_V] = { { { 0.0f, 0.0f } } };
constexpr size_t lsm_size3 = (ne_KQ > ne_combine ? ne_KQ : ne_combine);
constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2)*sizeof(float) + lsm_size3*sizeof(sycl::half);
syclex::work_group_static<char[local_share_mem_size]> lsm;
float *KQ_max_shared = (float *)&lsm;
float *KQ_sum_shared = KQ_max_shared+lsm_size1;
sycl::half* KQ = (sycl::half*)(KQ_sum_shared + lsm_size2);
#else
sycl::float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
constexpr size_t lsm_size3 = (ne_KQ > ne_combine ? ne_KQ : ne_combine);
constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2 + lsm_size3)*sizeof(float);
syclex::work_group_static<char[local_share_mem_size]> lsm;
float *KQ_max_shared = (float *)&lsm;
float *KQ_sum_shared = KQ_max_shared+lsm_size1;
float* KQ = KQ_sum_shared + lsm_size2;
#endif // GGML_SYCL_F16
float KQ_max[ncols];
float KQ_sum[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
KQ_max[j] = -FLT_MAX/2.0f;
KQ_sum[j] = 0.0f;
}
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
#ifdef GGML_SYCL_F16
sycl::half2 Q_reg[ncols][(D / 2) / nthreads_KQ] = {{{0.0f, 0.0f}}}; // Will be initialized completely.
#else
sycl::float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
#endif // GGML_SYCL_F16
int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
sycl::float2 Q_ds[ncols][1 > D / (sizeof(int) * nthreads_KQ) ? 1 : D / (sizeof(int) * nthreads_KQ)];
if constexpr (Q_q8_1) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + item_ct1.get_local_id(1);
if (j0 + nwarps > ncols && j >= ncols) {
break;
}
// Reuse KQ as temporary storage for converting Q to q8_1:
int * tmp_q_i32 = (int *) &KQ[j*D];
sycl::float2 * tmp_q_ds = (sycl::float2 *) (tmp_q_i32 + D / sizeof(int));
// Set memory to zero if out of bounds:
if (ncols > 1 && ic0 + j >= int(ne01.z())) {
#pragma unroll
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += warp_size) {
const int i = i0 + item_ct1.get_local_id(2);
if (i0 + warp_size <= int(D/sizeof(int)) || i < int(D/sizeof(int))) {
tmp_q_i32[i] = 0;
}
}
if (item_ct1.get_local_id(2) < D/QK8_1) {
tmp_q_ds[item_ct1.get_local_id(2)] = sycl::float2(0.0f, 0.0f);
}
} else {
const float * Q_f = (const float *) (Q + j*nb01);
constexpr int nthreads_quantize = D/sizeof(int) < warp_size ? D/sizeof(int) : warp_size;
#pragma unroll
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_quantize) {
quantize_q8_1_to_shared<sycl::float2, nthreads_quantize, warp_size>
(Q_f + i0*sizeof(int), scale, tmp_q_i32 + i0, tmp_q_ds + i0/QI8_1);
}
}
}
item_ct1.barrier(sycl::access::fence_space::local_space);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
int * tmp_q_i32 = (int *) &KQ[j*D];
sycl::float2 * tmp_q_ds = (sycl::float2 *) (tmp_q_i32 + D / sizeof(int));
#pragma unroll
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_KQ) {
const int i =
i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_KQ);
Q_i32[j][i0/nthreads_KQ] = tmp_q_i32[i];
Q_ds[j][i0/nthreads_KQ] = tmp_q_ds[i/QI8_1];
}
}
item_ct1.barrier(sycl::access::fence_space::local_space);
} else {
#ifdef GGML_SYCL_F16
const sycl::half2 scale_h2 = sycl::half2(scale, scale);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
const sycl::float2 * Q_j = (const sycl::float2 *) (Q + j * nb01);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
const int i = i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) :
item_ct1.get_local_id(2) % nthreads_KQ) *
cpy_ne;
sycl::float2 tmp[cpy_ne] = {
{ 0.0f, 0.0f }
};
if (ncols == 1 || ic0 + j < int(ne01.z())) {
ggml_sycl_memcpy_1<cpy_nb>(tmp, &Q_j[i]);
ggml_sycl_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);
}
#pragma unroll
for (int i1 = 0; i1 < cpy_ne; ++i1) {
Q_reg[j][i0 / nthreads_KQ + i1] = sycl::half2(tmp[i1].x(), tmp[i1].y());
}
}
#pragma unroll
for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
Q_reg[j][k] *= scale_h2;
}
}
#else
#pragma unroll
for (int j = 0; j < ncols; ++j) {
const sycl::float2 * Q_j = (const sycl::float2 *) (Q + j*nb01);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
const int i = i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_KQ)*cpy_ne;
if (ncols == 1 || ic0 + j < int(ne01.z())) {
ggml_sycl_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]);
ggml_sycl_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);
}
}
#pragma unroll
for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
Q_reg[j][k].x() *= scale;
Q_reg[j][k].y() *= scale;
}
}
#endif // GGML_SYCL_F16
}
const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11;
K += item_ct1.get_group(1) * nthreads * nb11;
V += item_ct1.get_group(1) * nthreads * nb21;
maskh += item_ct1.get_group(1) * nthreads;
for (int k_VKQ_0 = item_ct1.get_group(1) * nthreads; k_VKQ_0 < k_VKQ_max;
k_VKQ_0 += item_ct1.get_group_range(1) * nthreads,
// Increment pointers after each loop:
K += item_ct1.get_group_range(1) * nthreads * nb11, V += item_ct1.get_group_range(1) * nthreads * nb21,
maskh += item_ct1.get_group_range(1) * nthreads) {
// Calculate KQ tile and keep track of new maximum KQ values:
float KQ_reg[ncols]={}; // KQ in registers.
float KQ_max_new[ncols]={};
#pragma unroll
for (int j = 0; j < ncols; ++j) {
KQ_max_new[j] = KQ_max[j];
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) {
const int i_KQ = item_ct1.get_local_id(1) * warp_size +
(nthreads_KQ == warp_size ? 0 : (item_ct1.get_local_id(2) & ~(nthreads_KQ - 1))) + i_KQ_0;
#pragma unroll
for (int j = 0; j < ncols; ++j) {
float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum<nthreads_KQ>(sum);
if (use_logit_softcap) {
sum = logit_softcap * sycl::tanh(sum);
}
if (mask) {
sum += slope * sycl::vec<sycl::half, 1>(maskh[j * ne11 + i_KQ])
.convert<float, sycl::rounding_mode::automatic>()[0];
}
KQ_max_new[j] = sycl::fmax((float) KQ_max_new[j], sum);
if (int(nthreads_KQ == warp_size ? item_ct1.get_local_id(2)
: item_ct1.get_local_id(2) %
nthreads_KQ) == i_KQ_0) {
KQ_reg[j] = sum;
}
}
}
#pragma unroll
for (int j = 0; j < ncols; ++j) {
#pragma unroll
for (int offset = nthreads_KQ; offset < warp_size; offset <<= 1) {
KQ_max_new[j] = sycl::fmax(
(float)KQ_max_new[j],
(float)dpct::permute_sub_group_by_xor(
sycl::ext::oneapi::this_work_item::get_sub_group(),
KQ_max_new[j],
offset,
warp_size));
}
const float KQ_max_scale = sycl::native::exp((float) (KQ_max[j] - KQ_max_new[j]));
KQ_max[j] = KQ_max_new[j];
KQ_reg[j] = sycl::native::exp((float) (KQ_reg[j] - KQ_max[j]));
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
KQ[j*nthreads + tid] = KQ_reg[j];
#ifdef GGML_SYCL_F16
const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
}
#else
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
VKQ[j][i_VKQ_0/nthreads_V].x() *= KQ_max_scale;
VKQ[j][i_VKQ_0/nthreads_V].y() *= KQ_max_scale;
}
#endif // GGML_SYCL_F16
}
sycl::group_barrier(sycl::ext::oneapi::this_work_item::get_sub_group());
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += V_cols_per_iter) {
const int k = item_ct1.get_local_id(1) * warp_size + k0 +
(nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V);
#ifdef GGML_SYCL_F16
sycl::half2 KQ_k[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
KQ_k[j] = sycl::half2(KQ[j * nthreads + k]);
}
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
sycl::half2 tmp[V_rows_per_thread / 2];
dequantize_V(V + k * nb21, tmp,
2 * i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) :
item_ct1.get_local_id(2) % nthreads_V) *
V_rows_per_thread);
#pragma unroll
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1] += tmp[i_VKQ_1]*KQ_k[j];
}
}
}
#else
float KQ_k[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
KQ_k[j] = KQ[j*nthreads + k];
}
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
sycl::float2 tmp[V_rows_per_thread/2];
dequantize_V(V + k*nb21, tmp,
2*i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V)*V_rows_per_thread);
#pragma unroll
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x() += tmp[i_VKQ_1].x()*KQ_k[j];
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y() += tmp[i_VKQ_1].y()*KQ_k[j];
}
}
}
#endif // GGML_SYCL_F16
}
}
if (sinks && item_ct1.get_group(1) == 0) {
const float sink = ((const float *) sinks)[head];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + item_ct1.get_local_id(1);
if (j0 + nwarps > ncols && j >= ncols) {
break;
}
const float kqmax_new_j = sycl::fmax(sink, (float) KQ_max[j]);
const float KQ_max_scale = sycl::native::exp((float) (KQ_max[j] - kqmax_new_j));
KQ_max[j] = kqmax_new_j;
KQ_sum[j] = KQ_sum[j] * KQ_max_scale +
(item_ct1.get_local_id(2) == 0 ? sycl::native::exp((float) (sink - KQ_max[j])) : 0.0f);
#ifdef GGML_SYCL_F16
const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
}
#else
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
VKQ[j][i_VKQ_0/nthreads_V].x() *= KQ_max_scale;
VKQ[j][i_VKQ_0/nthreads_V].y() *= KQ_max_scale;
}
#endif // GGML_SYCL_F16
}
}
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (item_ct1.get_local_id(1) == 0) {
KQ_max_shared[j*warp_size+item_ct1.get_local_id(2)] = -FLT_MAX / 2.0f;
KQ_sum_shared[j*warp_size+item_ct1.get_local_id(2)] = 0.0f;
}
}
item_ct1.barrier(sycl::access::fence_space::local_space);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (item_ct1.get_local_id(2) == 0) {
KQ_max_shared[j*warp_size+item_ct1.get_local_id(1)] = KQ_max[j];
}
}
item_ct1.barrier(sycl::access::fence_space::local_space);
#pragma unroll
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z())) {
break;
}
float kqmax_new = KQ_max_shared[j_VKQ*warp_size+item_ct1.get_local_id(2)];
kqmax_new = warp_reduce_max<warp_size>(kqmax_new);
const float kqmax_scale = sycl::native::exp((float) (KQ_max[j_VKQ] - kqmax_new));
KQ_max[j_VKQ] = kqmax_new;
#ifdef GGML_SYCL_F16
sycl::half2 * VKQ_tmp = (sycl::half2 *) KQ + item_ct1.get_local_id(1) * (V_cols_per_iter * D / 2) +
(nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V) * (D / 2);
const sycl::half2 kqmax_scale_h2 = sycl::half2(kqmax_scale, kqmax_scale);
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
VKQ[j_VKQ][i_VKQ_0/nthreads_V] *= kqmax_scale_h2;
}
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
const int i_VKQ =
i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V) *
(V_rows_per_thread / 2);
ggml_sycl_memcpy_1<V_rows_per_thread * sizeof(sycl::half)>(VKQ_tmp + i_VKQ,
&VKQ[j_VKQ][i_VKQ_0 / nthreads_V]);
}
#else
sycl::float2 * VKQ_tmp = (sycl::float2 *) KQ + item_ct1.get_local_id(1)*(V_cols_per_iter*D/2)
+ (nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V)*(D/2);
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
VKQ[j_VKQ][i_VKQ_0/nthreads_V].x() *= kqmax_scale;
VKQ[j_VKQ][i_VKQ_0/nthreads_V].y() *= kqmax_scale;
}
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
const int i_VKQ = i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V)*(V_rows_per_thread/2);
ggml_sycl_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
ggml_sycl_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
}
#endif // GGML_SYCL_F16
KQ_sum[j_VKQ] *= kqmax_scale;
KQ_sum[j_VKQ] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ]);
if (item_ct1.get_local_id(2) == 0) {
KQ_sum_shared[j_VKQ*warp_size+item_ct1.get_local_id(1)] = KQ_sum[j_VKQ];
}
item_ct1.barrier(sycl::access::fence_space::local_space);
if (nthreads <= D || tid < D) {
KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ*warp_size+item_ct1.get_local_id(2)];
KQ_sum[j_VKQ] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ]);
#pragma unroll
for (int i0 = 0; i0 < D; i0 += nthreads) {
float dst_val = 0;
#pragma unroll
for (int w = 0; w < nwarps; ++w) {
#pragma unroll
for (int v = 0; v < V_cols_per_iter; ++v) {
dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]);
}
}
if (item_ct1.get_group_range(1) == 1) {
dst_val /= KQ_sum[j_VKQ];
}
dst[(((sequence * int(ne01.z()) + ic0 + j_VKQ) * ne02 + head) * item_ct1.get_group_range(1) +
item_ct1.get_group(1)) *
D +
i0 + tid] = dst_val;
}
}
if (j_VKQ < ncols-1) {
item_ct1.barrier(sycl::access::fence_space::local_space);
}
}
if (item_ct1.get_group_range(1) != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z()))) {
dst_meta[((sequence * int(ne01.z()) + ic0 + tid) * ne02 + head) * item_ct1.get_group_range(1) +
item_ct1.get_group(1)] = make_float2(KQ_max[tid], KQ_sum[tid]);
}
#else
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne00, ne01, ne02, ne03,
nb01, nb02, nb03,
ne10, ne11, ne12, ne13,
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
nb31, nb32, nb33);
#endif // SYCL_FLASH_ATTN
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
// Launches the vector FlashAttention kernel for one fully specialised configuration
// (head size D, columns per block, K/V storage types, logit softcap on/off).
//
// ctx: SYCL backend context forwarded to launch_fattn.
// dst: the FLASH_ATTN_EXT tensor; also forwarded unchanged to launch_fattn.
template <int D, int cols_per_block, int type_K, int type_V, bool use_logit_softcap>
void ggml_sycl_flash_attn_ext_vec_case_impl(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    // Sub-group width used by the kernel; 16 gives better performance than WARP_32_SIZE.
    constexpr int    sg_size     = WARP_16_SIZE;
    // F16 kernel variants ask launch_fattn for F16 K/V data.
    constexpr bool   K_wants_f16 = type_K == GGML_TYPE_F16;
    constexpr bool   V_wants_f16 = type_V == GGML_TYPE_F16;
    // The vec kernel requests no extra shared/local memory through launch_fattn.
    constexpr size_t smem_bytes  = 0;

    // The work-group size is derived from the compute capability of the current device.
    const int device_cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc;
    const int n_threads = ggml_sycl_fattn_vec_get_nthreads_host(device_cc);
    const int n_subgrps = n_threads / sg_size;

    // NOTE(review): the meaning of the trailing `false` flag is defined by launch_fattn,
    // which is outside this block -- confirm against its declaration.
    launch_fattn<D, cols_per_block, 1,
                 flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap, sg_size>,
                 sg_size>(ctx, dst, n_subgrps, smem_bytes, D, K_wants_f16, V_wants_f16, false);
}
// Runtime -> compile-time dispatch for the vector FlashAttention kernel.
//
// Two runtime properties of `dst` select the template instantiation:
//   * Q->ne[1] == 1      -> cols_per_block = 1, otherwise cols_per_block = 2
//   * logit_softcap != 0 -> use_logit_softcap = true
//
// ctx: SYCL backend context, forwarded to the selected implementation.
// dst: the FLASH_ATTN_EXT tensor; src[0] is Q and op_params[2] holds logit_softcap as a float.
template <int D, int type_K, int type_V>
void ggml_sycl_flash_attn_ext_vec_case(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    // op_params is read through memcpy to avoid aliasing problems when reinterpreting it as float.
    float logit_softcap;
    memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));

    const bool with_softcap = logit_softcap != 0.0f;
    const bool single_col   = dst->src[0]->ne[1] == 1;

    if (single_col) {
        if (with_softcap) {
            ggml_sycl_flash_attn_ext_vec_case_impl<D, 1, type_K, type_V, true>(ctx, dst);
        } else {
            ggml_sycl_flash_attn_ext_vec_case_impl<D, 1, type_K, type_V, false>(ctx, dst);
        }
    } else {
        if (with_softcap) {
            ggml_sycl_flash_attn_ext_vec_case_impl<D, 2, type_K, type_V, true>(ctx, dst);
        } else {
            ggml_sycl_flash_attn_ext_vec_case_impl<D, 2, type_K, type_V, false>(ctx, dst);
        }
    }
}
// Explicit instantiation of ggml_sycl_flash_attn_ext_vec_case for one
// (head size D, K type, V type) combination.
// The expansion deliberately carries no trailing semicolon: the user writes it, optionally
// prefixed with `extern` to turn the instantiation definition into a declaration.
// The last line must NOT end in a line continuation, otherwise whatever follows the macro
// in the file is spliced into its replacement list.
#define DECL_FATTN_VEC_CASE(D, type_K, type_V) \
    template void ggml_sycl_flash_attn_ext_vec_case \
    <D, type_K, type_V>(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
// Declares (extern template) the vec-kernel instantiations for a fixed head size D and
// K type, paired with each of the six V types listed below.
// Every line already ends in `;`, so invocations are written without a semicolon.
// The final line has no line continuation so the macro ends here and does not absorb
// the first invocation that follows it.
#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \
    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16); \
    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \
    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \
    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0);
// Extern template declarations for every supported head size (64, 128, 256) and K type.
// Each invocation covers all V types handled by EXTERN_DECL_FATTN_VEC_CASES, so including this
// header does not instantiate the kernels itself.
// NOTE(review): the matching definitions are presumably the DECL_FATTN_VEC_CASE lines in the
// generated template-instance sources -- confirm every combination listed here has one.
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
#endif // GGML_SYCL_FATTN_VEC_HPP

View File

@ -0,0 +1,225 @@
//
// MIT license
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
#include <sycl/sycl.hpp>
#include "dpct/helper.hpp"
#include "common.hpp"
#include "fattn-common.hpp"
#include "fattn-tile.hpp"
#include "fattn-vec.hpp"
#include "fattn.hpp"
// One entry of the vec-kernel dispatch table.
// Expands to a block that launches ggml_sycl_flash_attn_ext_vec_case<D, type_K, type_V> and
// returns from the enclosing function when the runtime tensors match:
//   * Q->ne[0] equals D, and
//   * K/V have exactly the requested type, or are F32 while the F16 variant is requested.
// The macro relies on `Q`, `K`, `V`, `ctx` and `dst` being in scope at the point of use.
// The closing brace is the end of the macro: it must not be followed by a line continuation,
// or the next preprocessor directive gets spliced into this definition.
#define FATTN_VEC_CASE(D, type_K, type_V) \
    { \
        const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
        const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
        if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \
            ggml_sycl_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
            return; \
        } \
    }
// Expands FATTN_VEC_CASE for every head size the vec kernel is instantiated for (64, 128, 256)
// with a fixed K/V type pair.
// The last line carries no line continuation so the macro does not absorb the code after it.
#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
    FATTN_VEC_CASE( 64, type_K, type_V) \
    FATTN_VEC_CASE(128, type_K, type_V) \
    FATTN_VEC_CASE(256, type_K, type_V)
// Dispatches the FLASH_ATTN_EXT op `dst` to the matching vec-kernel template instantiation.
//
// Each FATTN_VEC_CASES_ALL_D line tries head sizes 64/128/256 for one (K type, V type) pair and
// returns from this function on the first match, so reaching the end means no instantiation
// fits the tensors and the function aborts.
// The macros refer to the locals Q, K, V and to the parameters ctx, dst by name: do not rename.
static void ggml_sycl_flash_attn_ext_vec(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_tensor * Q = dst->src[0];
ggml_tensor * K = dst->src[1];
ggml_tensor * V = dst->src[2];
#ifdef GGML_SYCL_FA_ALL_QUANTS
// Full build: every K type combined with every V type (grouped by V type below).
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
#else
// Default build: only matching K/V type pairs for F16, Q4_0 and Q8_0.
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
#endif // GGML_SYCL_FA_ALL_QUANTS
// No case matched and returned above.
GGML_ABORT("Not match KV type in vec");
}
// Best FlashAttention kernel for a specific GPU:
// Result of ggml_sycl_get_best_fattn_kernel; ggml_sycl_flash_attn_ext switches on it.
enum best_fattn_kernel {
BEST_FATTN_KERNEL_NONE = 0, // no kernel applies: the op is reported as unsupported, executing it aborts
BEST_FATTN_KERNEL_VEC = 100, // run ggml_sycl_flash_attn_ext_vec
BEST_FATTN_KERNEL_TILE = 200, // run ggml_sycl_flash_attn_ext_tile
};
// Chooses which FlashAttention kernel family should execute the FLASH_ATTN_EXT op `dst`.
//
// device: currently unused by the selection logic.
// dst:    the op tensor; src[0..3] are Q, K, V and the optional mask, op_params[1] is max_bias.
//
// Returns BEST_FATTN_KERNEL_NONE when FlashAttention is compiled out, disabled at runtime, or
// the head sizes / K,V types / mask layout are not handled; otherwise VEC or TILE.
static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
    GGML_UNUSED(device);
#ifndef SYCL_FLASH_ATTN
    // Built without FlashAttention: nothing is supported.
    GGML_UNUSED(dst);
    return BEST_FATTN_KERNEL_NONE;
#endif // SYCL_FLASH_ATTN

    // Global enable flag.
    if (g_ggml_sycl_enable_flash_attention == 0) {
        return BEST_FATTN_KERNEL_NONE;
    }

    const ggml_tensor * Q    = dst->src[0];
    const ggml_tensor * K    = dst->src[1];
    const ggml_tensor * V    = dst->src[2];
    const ggml_tensor * mask = dst->src[3];

    const int gqa_ratio = Q->ne[2] / K->ne[2];
    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);

    float max_bias = 0.0f;
    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));

    // The GQA optimisation needs grouped heads, a mask, no ALiBi bias, a KV length that is a
    // multiple of FATTN_KQ_STRIDE and 16-byte aligned strides on all non-quantized inputs.
    bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
    const ggml_tensor * const inputs[] = { Q, K, V, mask };
    for (const ggml_tensor * input : inputs) {
        if (input == nullptr || ggml_is_quantized(input->type)) {
            continue;
        }
        for (size_t dim = 1; dim < GGML_MAX_DIMS; ++dim) {
            if (input->nb[dim] % 16 != 0) {
                gqa_opt_applies = false;
                break;
            }
        }
    }

    // Head-size whitelist: either K and V share one of the listed sizes, or the 576/512 pair,
    // which additionally requires the GQA optimisation.
    const auto head_size_K = K->ne[0];
    const auto head_size_V = V->ne[0];
    if (head_size_K == 576) {
        if (head_size_V != 512 || !gqa_opt_applies) {
            return BEST_FATTN_KERNEL_NONE;
        }
    } else {
        const bool known_head_size =
            head_size_K ==  40 || head_size_K ==  64 || head_size_K ==  72 || head_size_K ==  80 ||
            head_size_K ==  96 || head_size_K == 112 || head_size_K == 128 || head_size_K == 256;
        if (!known_head_size || head_size_V != head_size_K) {
            return BEST_FATTN_KERNEL_NONE;
        }
    }

#ifdef GGML_SYCL_FA_ALL_QUANTS
    constexpr bool all_quants = true;
#else
    constexpr bool all_quants = false;
#endif // GGML_SYCL_FA_ALL_QUANTS

    // Without the full set of quantized kernels, K and V must share a type ...
    if (!all_quants && K->type != V->type) {
        return BEST_FATTN_KERNEL_NONE;
    }
    // ... and only F32/F16/Q4_0/Q8_0 are available for K.
    switch (K->type) {
        case GGML_TYPE_F32:
        case GGML_TYPE_F16:
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q8_0:
            break;
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
            if (!all_quants) {
                return BEST_FATTN_KERNEL_NONE;
            }
            break;
        default:
            return BEST_FATTN_KERNEL_NONE;
    }

    // Only masks with ne[2] == 1 are handled.
    if (mask && mask->ne[2] != 1) {
        return BEST_FATTN_KERNEL_NONE;
    }

    // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;

    // Todo: Use the XMX kernel if possible:

    // If there are no tensor cores available, use the generic tile kernel:
    if (can_use_vector_kernel) {
        const bool kv_unquantized = !ggml_is_quantized(K->type) && !ggml_is_quantized(V->type);
        // Unquantized K/V: vec only for a single Q column when the GQA optimisation is off.
        // Otherwise (any quantized K or V): vec for up to two Q columns.
        const bool prefer_vec = kv_unquantized ? (Q->ne[1] == 1 && !gqa_opt_applies)
                                               : (Q->ne[1] <= 2);
        if (prefer_vec) {
            return BEST_FATTN_KERNEL_VEC;
        }
    }
    return BEST_FATTN_KERNEL_TILE;
}
// Executes the FLASH_ATTN_EXT op `dst` on the device owned by `ctx`.
// Selects the kernel family with ggml_sycl_get_best_fattn_kernel and forwards to it;
// aborts when no kernel is available for this op.
void ggml_sycl_flash_attn_ext(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    ggml_sycl_set_device(ctx.device);

    const best_fattn_kernel kernel = ggml_sycl_get_best_fattn_kernel(ggml_sycl_get_device(), dst);

    if (kernel == BEST_FATTN_KERNEL_NONE) {
        GGML_ABORT("Not support Flash-Attention");
    }
    if (kernel == BEST_FATTN_KERNEL_TILE) {
        ggml_sycl_flash_attn_ext_tile(ctx, dst);
        return;
    }
    if (kernel == BEST_FATTN_KERNEL_VEC) {
        ggml_sycl_flash_attn_ext_vec(ctx, dst);
        return;
    }
}
bool ggml_sycl_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
return ggml_sycl_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;
}

View File

@ -0,0 +1,22 @@
//
// MIT license
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Public entry points of the SYCL FlashAttention implementation.
#ifndef GGML_SYCL_FATTN_HPP
#define GGML_SYCL_FATTN_HPP
#include "common.hpp"
// Executes the FLASH_ATTN_EXT op `dst` on the device of `ctx`.
// Aborts if no kernel handles the op, so check ggml_sycl_flash_attn_ext_supported() first.
void ggml_sycl_flash_attn_ext(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
// Returns true when a FlashAttention kernel is available for the op `dst` on `device`.
bool ggml_sycl_flash_attn_ext_supported(int device, const ggml_tensor * dst);
#endif // GGML_SYCL_FATTN_HPP

View File

@ -62,6 +62,8 @@ int g_ggml_sycl_disable_graph = 0;
int g_ggml_sycl_disable_dnn = 0;
int g_ggml_sycl_prioritize_dmmv = 0;
int g_ggml_sycl_use_async_mem_op = 0;
int g_ggml_sycl_enable_flash_attention = 1;
static ggml_sycl_device_info ggml_sycl_init() {
ggml_sycl_device_info info = {};
@ -94,11 +96,12 @@ static ggml_sycl_device_info ggml_sycl_init() {
info.devices[i].cc =
100 * prop.get_major_version() + 10 * prop.get_minor_version();
info.devices[i].nsm = prop.get_max_compute_units();
info.devices[i].nsm = prop.get_max_compute_units() / 16; //16: Number of Xe Cores
info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
info.devices[i].smpbo = prop.get_local_mem_size();
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units();
}
for (int id = 0; id < info.device_count; ++id) {
@ -211,7 +214,37 @@ static void ggml_check_sycl() try {
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
#ifdef SYCL_FLASH_ATTN
g_ggml_sycl_enable_flash_attention = get_sycl_env("GGML_SYCL_ENABLE_FLASH_ATTN", 1);
#else
g_ggml_sycl_enable_flash_attention = 0;
#endif
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
GGML_LOG_INFO("Build with Macros:\n");
#if defined(GGML_SYCL_FORCE_MMQ)
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
#else
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n");
#endif
#if defined(GGML_SYCL_F16)
GGML_LOG_INFO(" GGML_SYCL_F16: yes\n");
#else
GGML_LOG_INFO(" GGML_SYCL_F16: no\n");
#endif
#if defined(GGML_SYCL_GRAPH)
GGML_LOG_INFO(" GGML_SYCL_GRAPH: yes\n");
#else
GGML_LOG_INFO(" GGML_SYCL_GRAPH: no\n");
#endif
#if defined(GGML_SYCL_DNNL)
GGML_LOG_INFO(" GGML_SYCL_DNNL: yes\n");
#else
GGML_LOG_INFO(" GGML_SYCL_DNNL: no\n");
#endif
GGML_LOG_INFO("Running with Environment Variables:\n");
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
@ -226,16 +259,12 @@ static void ggml_check_sycl() try {
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
#endif
GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
GGML_LOG_INFO("Build with Macros:\n");
#if defined(GGML_SYCL_FORCE_MMQ)
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
#ifdef SYCL_FLASH_ATTN
GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d\n", g_ggml_sycl_enable_flash_attention);
#else
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n");
#endif
#if defined(GGML_SYCL_F16)
GGML_LOG_INFO(" GGML_SYCL_F16: yes\n");
#else
GGML_LOG_INFO(" GGML_SYCL_F16: no\n");
GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d disabled by compile flag\n",
g_ggml_sycl_enable_flash_attention);
#endif
/* NOT REMOVE, keep it for next optimize for XMX.
@ -3012,7 +3041,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
}
#if GGML_SYCL_DNNL
// oneDNN handles strided data and does not need overhead of get_to_fp16_nc_sycl
// oneDNN handles strided data and does not need overhead of ggml_get_to_fp16_nc_sycl
const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
src1_f16_alloc.alloc(ne_src1);
const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
@ -3021,7 +3050,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
# else
const int64_t ne_src1 = ggml_nelements(src1);
src1_f16_alloc.alloc(ne_src1);
const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
const to_fp16_nc_sycl_t to_fp16_nc_sycl = ggml_get_to_fp16_nc_sycl(src1->type);
GGML_ASSERT(to_fp16_nc_sycl != nullptr);
to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
#endif
@ -4158,6 +4187,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_ARANGE:
ggml_sycl_arange(ctx, dst);
break;
case GGML_OP_FLASH_ATTN_EXT:
ggml_sycl_flash_attn_ext(ctx, dst);
break;
default:
return false;
}
@ -4862,6 +4894,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return op->type == GGML_TYPE_F32;
case GGML_OP_ARANGE:
return op->type == GGML_TYPE_F32;
case GGML_OP_FLASH_ATTN_EXT:
return ggml_sycl_flash_attn_ext_supported(device, op);
default:
return false;
}

View File

@ -73,4 +73,7 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
#define MUL_MAT_SRC1_COL_STRIDE 128
#define QK_WARP_SIZE 32
#define WARP_32_SIZE 32
#define WARP_16_SIZE 16
#endif // GGML_SYCL_PRESETS_HPP

View File

@ -102,7 +102,7 @@ static void soft_max_f32(const float * x,
max_val = sycl::max(max_val, val);
}
// find the max value in the block
max_val = warp_reduce_max(max_val);
max_val = warp_reduce_max<WARP_SIZE>(max_val);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
@ -116,7 +116,7 @@ static void soft_max_f32(const float * x,
item_ct1.barrier();
max_val = buf_iw[lane_id];
max_val = warp_reduce_max(max_val);
max_val = warp_reduce_max<WARP_SIZE>(max_val);
}
float tmp = 0.0f; // partial sum
@ -133,7 +133,7 @@ static void soft_max_f32(const float * x,
vals[col] = val;
}
// find the sum of exps in the block
tmp = warp_reduce_sum(tmp);
tmp = warp_reduce_sum<WARP_SIZE>(tmp);
if (block_size > WARP_SIZE) {
item_ct1.barrier();
if (warp_id == 0) {
@ -153,7 +153,7 @@ static void soft_max_f32(const float * x,
for (size_t i = 1; i < nreduce; i += 1) {
tmp += buf_iw[lane_id + i * WARP_SIZE];
}
tmp = warp_reduce_sum(tmp);
tmp = warp_reduce_sum<WARP_SIZE>(tmp);
}
if (sinks) {
tmp += sycl::native::exp(sinks[i02] - max_val);
@ -191,7 +191,7 @@ static void soft_max_back_f32(const float *grad, const float *dstf, float *dst,
dgf_dot += dstf[col]*grad[col];
}
dgf_dot = warp_reduce_sum(dgf_dot);
dgf_dot = warp_reduce_sum<WARP_SIZE>(dgf_dot);
for (int col = tid; col < ncols; col += WARP_SIZE) {
dst[col] = scale * (grad[col] - dgf_dot) * dstf[col];

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.hpp"
DECL_FATTN_TILE_CASE(112, 112);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.hpp"
DECL_FATTN_TILE_CASE(128, 128);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.hpp"
DECL_FATTN_TILE_CASE(256, 256);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.hpp"
DECL_FATTN_TILE_CASE(40, 40);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.hpp"
DECL_FATTN_TILE_CASE(576, 512);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.hpp"
DECL_FATTN_TILE_CASE(64, 64);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.hpp"
DECL_FATTN_TILE_CASE(72, 72);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.hpp"
DECL_FATTN_TILE_CASE(80, 80);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.hpp"
DECL_FATTN_TILE_CASE(96, 96);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.hpp"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.hpp"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.hpp"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.hpp"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.hpp"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.hpp"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.hpp"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.hpp"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.hpp"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.hpp"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.hpp"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.hpp"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.hpp"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_F16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.hpp"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.hpp"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.hpp"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);

Some files were not shown because too many files have changed in this diff Show More