Merge branch 'ggml-org:master' into master
This commit is contained in:
commit
0d83366970
|
|
@ -11,6 +11,8 @@
|
|||
/common/base64.hpp.* @ggerganov
|
||||
/common/build-info.* @ggerganov
|
||||
/common/chat.* @pwilkin
|
||||
/common/chat-auto*.* @pwilkin
|
||||
/common/chat-diff-analyzer.* @pwilkin
|
||||
/common/chat-peg-parser.* @aldehir
|
||||
/common/common.* @ggerganov
|
||||
/common/console.* @ggerganov
|
||||
|
|
@ -89,12 +91,13 @@
|
|||
/src/llama-vocab.* @CISC
|
||||
/src/models/ @CISC
|
||||
/tests/ @ggerganov
|
||||
/tests/test-chat-.* @pwilkin
|
||||
/tests/test-chat.* @pwilkin
|
||||
/tools/batched-bench/ @ggerganov
|
||||
/tools/cli/ @ngxson
|
||||
/tools/completion/ @ggerganov
|
||||
/tools/mtmd/ @ngxson
|
||||
/tools/perplexity/ @ggerganov
|
||||
/tools/parser/ @pwilkin
|
||||
/tools/quantize/ @ggerganov
|
||||
/tools/rpc/ @rgerganov
|
||||
/tools/server/* @ngxson @ggerganov # no subdir
|
||||
|
|
|
|||
|
|
@ -47,10 +47,10 @@ add_library(${TARGET} STATIC
|
|||
arg.cpp
|
||||
arg.h
|
||||
base64.hpp
|
||||
chat-parser.cpp
|
||||
chat-parser.h
|
||||
chat-parser-xml-toolcall.h
|
||||
chat-parser-xml-toolcall.cpp
|
||||
chat-auto-parser-generator.cpp
|
||||
chat-auto-parser-helpers.cpp
|
||||
chat-auto-parser.h
|
||||
chat-diff-analyzer.cpp
|
||||
chat-peg-parser.cpp
|
||||
chat-peg-parser.h
|
||||
chat.cpp
|
||||
|
|
|
|||
|
|
@ -1279,13 +1279,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
}
|
||||
).set_env("LLAMA_ARG_SWA_FULL"));
|
||||
add_opt(common_arg(
|
||||
{"--ctx-checkpoints", "--swa-checkpoints"}, "N",
|
||||
{"-ctxcp", "--ctx-checkpoints", "--swa-checkpoints"}, "N",
|
||||
string_format("max number of context checkpoints to create per slot (default: %d)"
|
||||
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_ctx_checkpoints),
|
||||
[](common_params & params, int value) {
|
||||
params.n_ctx_checkpoints = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
add_opt(common_arg(
|
||||
{"-cpent", "--checkpoint-every-n-tokens"}, "N",
|
||||
string_format("create a checkpoint every n tokens during prefill (processing), -1 to disable (default: %d)", params.checkpoint_every_nt),
|
||||
[](common_params & params, int value) {
|
||||
params.checkpoint_every_nt = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_CHECKPOINT_EVERY_NT").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
add_opt(common_arg(
|
||||
{"-cram", "--cache-ram"}, "N",
|
||||
string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)"
|
||||
|
|
@ -2827,6 +2834,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.webui_config_json = read_file(value);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_CONFIG_FILE"));
|
||||
add_opt(common_arg(
|
||||
{"--webui-mcp-proxy"},
|
||||
{"--no-webui-mcp-proxy"},
|
||||
string_format("experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: %s)", params.webui_mcp_proxy ? "enabled" : "disabled"),
|
||||
[](common_params & params, bool value) {
|
||||
params.webui_mcp_proxy = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_MCP_PROXY"));
|
||||
add_opt(common_arg(
|
||||
{"--webui"},
|
||||
{"--no-webui"},
|
||||
|
|
|
|||
|
|
@ -0,0 +1,413 @@
|
|||
#include "chat-auto-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "chat.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
// Helper to iterate over tools/functions
|
||||
static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
|
||||
for (const auto & tool : tools) {
|
||||
if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
|
||||
continue;
|
||||
}
|
||||
fn(tool);
|
||||
}
|
||||
}
|
||||
|
||||
namespace autoparser {
|
||||
|
||||
parser_build_context::parser_build_context(common_chat_peg_builder & p, const templates_params & inputs) :
|
||||
p(p),
|
||||
inputs(inputs),
|
||||
reasoning_parser(p.eps()) {}
|
||||
|
||||
common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl,
|
||||
const struct templates_params & inputs) {
|
||||
// Run differential analysis to extract template structure
|
||||
struct autoparser autoparser;
|
||||
autoparser.analyze_template(tmpl);
|
||||
return generate_parser(tmpl, inputs, autoparser);
|
||||
}
|
||||
|
||||
common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl,
|
||||
const struct templates_params & inputs,
|
||||
const autoparser & autoparser) {
|
||||
// Build the parser using the analysis results
|
||||
auto parser = autoparser.build_parser(inputs);
|
||||
|
||||
// Create the result structure
|
||||
common_chat_params data;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = autoparser.preserved_tokens;
|
||||
data.parser = parser.save();
|
||||
|
||||
// Build grammar if tools are present
|
||||
bool has_tools =
|
||||
autoparser.tools.format.mode != tool_format::NONE && inputs.tools.is_array() && !inputs.tools.empty();
|
||||
std::string trigger_marker = !autoparser.tools.format.section_start.empty() ? autoparser.tools.format.section_start :
|
||||
autoparser.tools.format.per_call_start;
|
||||
bool include_grammar =
|
||||
has_tools && ((inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO && !trigger_marker.empty()) ||
|
||||
inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED);
|
||||
|
||||
if (include_grammar) {
|
||||
data.grammar_lazy = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
auto schema = function.at("parameters");
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
// Set grammar triggers based on tool section markers (fall back to per-call markers)
|
||||
if (data.grammar_lazy) { // only do triggers on lazy grammar
|
||||
data.grammar_triggers = {
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, trigger_marker }
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
common_peg_arena autoparser::build_parser(const templates_params & inputs) const {
|
||||
if (!analysis_complete) {
|
||||
throw std::invalid_argument("Cannot call build_parser on autoparser without performing analysis first, call analyze_template(...)");
|
||||
}
|
||||
return build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
// If the template uses Python dict format (single-quoted strings in JSON structures),
|
||||
// pre-register a json-string rule that accepts both quote styles. This must happen
|
||||
// before any call to p.json() so that all JSON parsing inherits the flexible rule.
|
||||
if (tools.format.uses_python_dicts) {
|
||||
p.rule("json-string", [&]() { return p.choice({ p.double_quoted_string(), p.single_quoted_string() }); });
|
||||
}
|
||||
|
||||
parser_build_context ctx(p, inputs);
|
||||
bool extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
bool enable_thinking = inputs.enable_thinking;
|
||||
|
||||
ctx.extracting_reasoning = extract_reasoning && enable_thinking && reasoning.mode != reasoning_mode::NONE;
|
||||
ctx.content = &content;
|
||||
|
||||
// Build reasoning parser
|
||||
ctx.reasoning_parser = reasoning.build_parser(ctx);
|
||||
|
||||
bool has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
bool has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
|
||||
|
||||
if (has_response_format) {
|
||||
return ctx.reasoning_parser + p.space() +
|
||||
p.content(p.schema(p.json(), "response-format", inputs.json_schema)) + p.end();
|
||||
}
|
||||
|
||||
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) {
|
||||
return tools.build_parser(ctx);
|
||||
}
|
||||
|
||||
return content.build_parser(ctx);
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) const {
|
||||
auto & p = ctx.p;
|
||||
|
||||
if (!ctx.extracting_reasoning) {
|
||||
return p.eps();
|
||||
}
|
||||
|
||||
bool thinking_forced_open = (mode == reasoning_mode::FORCED_OPEN);
|
||||
bool thinking_forced_closed = (mode == reasoning_mode::FORCED_CLOSED);
|
||||
|
||||
if (thinking_forced_open || thinking_forced_closed) {
|
||||
// Thinking is forced open OR forced closed with enable_thinking=true
|
||||
// In both cases, expect only the closing tag (opening was in template)
|
||||
return p.reasoning(p.until(end)) + end;
|
||||
}
|
||||
if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) {
|
||||
// Standard tag-based reasoning OR tools-only mode (reasoning appears with tools)
|
||||
// Both use the same tag-based pattern if markers are available
|
||||
if (!start.empty() && !end.empty()) {
|
||||
return p.optional(start + p.reasoning(p.until(end)) + end);
|
||||
}
|
||||
} else if (mode == reasoning_mode::DELIMITER) {
|
||||
return p.optional(p.reasoning(p.until(end)) + end);
|
||||
}
|
||||
|
||||
return p.eps();
|
||||
}
|
||||
|
||||
common_peg_parser analyze_content::build_parser(parser_build_context & ctx) const {
|
||||
auto & p = ctx.p;
|
||||
|
||||
if (is_always_wrapped()) {
|
||||
if (ctx.extracting_reasoning) {
|
||||
return ctx.reasoning_parser + start + p.content(p.until(end)) + end + p.end();
|
||||
}
|
||||
return p.content(p.until(start)) + start + p.content(p.until(end)) + end + p.end();
|
||||
}
|
||||
return ctx.reasoning_parser + p.content(p.rest()) + p.end();
|
||||
}
|
||||
|
||||
common_peg_parser analyze_content::build_optional_wrapped(parser_build_context & ctx) const {
|
||||
auto & p = ctx.p;
|
||||
|
||||
if (is_always_wrapped()) {
|
||||
return p.optional(start + p.content(p.until(end)) + end);
|
||||
}
|
||||
return p.eps();
|
||||
}
|
||||
|
||||
common_peg_parser analyze_tools::build_parser(parser_build_context & ctx) const {
|
||||
switch (format.mode) {
|
||||
case tool_format::JSON_NATIVE:
|
||||
return build_tool_parser_json_native(ctx);
|
||||
case tool_format::TAG_WITH_JSON:
|
||||
return build_tool_parser_tag_json(ctx);
|
||||
case tool_format::TAG_WITH_TAGGED:
|
||||
return build_tool_parser_tag_tagged(ctx);
|
||||
default:
|
||||
GGML_ABORT("Unable to create tool parser");
|
||||
}
|
||||
}
|
||||
|
||||
common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_context & ctx) const {
|
||||
auto & p = ctx.p;
|
||||
const auto & inputs = ctx.inputs;
|
||||
bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
// Build effective field names with dot notation if function_field is set
|
||||
std::string name_field = format.name_field;
|
||||
std::string args_field = format.args_field;
|
||||
|
||||
if (!format.function_field.empty() && format.function_field != "function" &&
|
||||
name_field.find('.') == std::string::npos) {
|
||||
name_field = format.function_field + "." + name_field;
|
||||
args_field = format.function_field + "." + args_field;
|
||||
}
|
||||
|
||||
auto tools_parser = p.standard_json_tools(
|
||||
format.section_start, format.section_end, inputs.tools, inputs.parallel_tool_calls,
|
||||
inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED, name_field, args_field, format.tools_array_wrapped,
|
||||
format.fun_name_is_key, format.id_field, format.gen_id_field, format.parameter_order);
|
||||
|
||||
// Handle content wrappers if present
|
||||
if (ctx.content && ctx.content->is_always_wrapped()) {
|
||||
auto wrapped_content = ctx.content->build_optional_wrapped(ctx);
|
||||
return ctx.reasoning_parser + wrapped_content + tools_parser + p.end();
|
||||
}
|
||||
|
||||
std::string tool_start = "{";
|
||||
if (!format.section_start.empty()) {
|
||||
tool_start = format.section_start;
|
||||
} else if (!format.per_call_start.empty()) {
|
||||
tool_start = format.per_call_start;
|
||||
}
|
||||
|
||||
return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(p.until(tool_start)))) + tools_parser +
|
||||
p.end();
|
||||
}
|
||||
|
||||
common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context & ctx) const {
|
||||
auto & p = ctx.p;
|
||||
const auto & inputs = ctx.inputs;
|
||||
bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
common_peg_parser tool_choice = p.choice();
|
||||
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & func = tool.at("function");
|
||||
std::string name = func.at("name");
|
||||
const auto & schema = func.at("parameters");
|
||||
|
||||
// Build call_id parser based on position (if supported)
|
||||
common_peg_parser call_id_section = p.eps();
|
||||
if (call_id.pos == call_id_position::BETWEEN_FUNC_AND_ARGS && !call_id.prefix.empty() &&
|
||||
!call_id.suffix.empty()) {
|
||||
call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix))) + call_id.suffix;
|
||||
}
|
||||
|
||||
auto func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
|
||||
call_id_section + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema));
|
||||
if (!function.close.empty()) {
|
||||
func_parser = func_parser + function.close;
|
||||
}
|
||||
func_parser = p.atomic(func_parser);
|
||||
|
||||
tool_choice |= p.rule("tool-" + name, func_parser);
|
||||
});
|
||||
|
||||
auto require_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
common_peg_parser tool_calls = p.eps();
|
||||
|
||||
if (!format.per_call_start.empty()) {
|
||||
auto wrapped_call = format.per_call_start + tool_choice + format.per_call_end;
|
||||
if (inputs.parallel_tool_calls) {
|
||||
tool_calls = p.trigger_rule("tool-call", wrapped_call + p.zero_or_more(p.space() + wrapped_call));
|
||||
} else {
|
||||
tool_calls = p.trigger_rule("tool-call", wrapped_call);
|
||||
}
|
||||
if (!format.section_start.empty()) {
|
||||
tool_calls = p.trigger_rule("tool-calls",
|
||||
p.literal(format.section_start) + p.space() + tool_calls + p.space() +
|
||||
(format.section_end.empty() ? p.end() : p.literal(format.section_end)));
|
||||
}
|
||||
} else {
|
||||
std::string separator = ", "; // Default
|
||||
if (inputs.parallel_tool_calls) {
|
||||
tool_calls = p.trigger_rule("tool-call", format.section_start + tool_choice +
|
||||
p.zero_or_more(separator + tool_choice) + format.section_end);
|
||||
} else {
|
||||
tool_calls = p.trigger_rule("tool-call", format.section_start + tool_choice + format.section_end);
|
||||
}
|
||||
}
|
||||
|
||||
if (!require_calls) {
|
||||
tool_calls = p.optional(tool_calls);
|
||||
}
|
||||
|
||||
std::string trigger_marker = !format.section_start.empty() ? format.section_start : format.per_call_start;
|
||||
auto content_before_tools = trigger_marker.empty() ? p.eps() : p.until(trigger_marker);
|
||||
return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(content_before_tools))) + tool_calls +
|
||||
p.end();
|
||||
}
|
||||
|
||||
common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_context & ctx) const {
|
||||
auto & p = ctx.p;
|
||||
const auto & inputs = ctx.inputs;
|
||||
bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
common_peg_parser tool_choice = p.choice();
|
||||
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & func = tool.at("function");
|
||||
std::string name = func.at("name");
|
||||
const auto & params = func.at("parameters");
|
||||
|
||||
if (!params.contains("properties") || !params.at("properties").is_object()) {
|
||||
return;
|
||||
}
|
||||
|
||||
const auto & properties = params.at("properties");
|
||||
std::set<std::string> required;
|
||||
if (params.contains("required") && params.at("required").is_array()) {
|
||||
params.at("required").get_to(required);
|
||||
}
|
||||
|
||||
// Build parser for each argument
|
||||
std::vector<common_peg_parser> arg_parsers;
|
||||
for (const auto & [param_name, param_schema] : properties.items()) {
|
||||
bool is_required = required.find(param_name) != required.end();
|
||||
std::string type = "object";
|
||||
auto type_obj = param_schema.contains("type") ? param_schema.at("type") : json::object();
|
||||
if (type_obj.is_string()) {
|
||||
type_obj.get_to(type);
|
||||
} else if (type_obj.is_object()) {
|
||||
if (type_obj.contains("type") && type_obj.at("type").is_string()) {
|
||||
type_obj.at("type").get_to(type);
|
||||
}
|
||||
}
|
||||
|
||||
auto arg = p.tool_arg(
|
||||
p.tool_arg_open(arguments.name_prefix + p.tool_arg_name(p.literal(param_name)) +
|
||||
arguments.name_suffix) +
|
||||
arguments.value_prefix +
|
||||
(type == "string" ? p.tool_arg_string_value(p.schema(p.until(arguments.value_suffix),
|
||||
"tool-" + name + "-arg-" + param_name + "-schema",
|
||||
param_schema, true)) :
|
||||
p.tool_arg_json_value(p.schema(
|
||||
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, format.uses_python_dicts)) +
|
||||
p.space()) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)));
|
||||
|
||||
if (is_required) {
|
||||
arg_parsers.push_back(p.rule("tool-" + name + "-arg-" + param_name, arg));
|
||||
} else {
|
||||
arg_parsers.push_back(p.optional(p.rule("tool-" + name + "-arg-" + param_name, arg)));
|
||||
}
|
||||
}
|
||||
|
||||
// Build arg sequence with space() between consecutive args
|
||||
common_peg_parser args_seq = p.eps();
|
||||
for (size_t i = 0; i < arg_parsers.size(); i++) {
|
||||
if (i > 0) {
|
||||
args_seq = args_seq + p.space();
|
||||
}
|
||||
args_seq = args_seq + arg_parsers[i];
|
||||
}
|
||||
|
||||
// Build call_id parser based on position (if supported)
|
||||
common_peg_parser call_id_section = p.eps();
|
||||
if (call_id.pos == call_id_position::BETWEEN_FUNC_AND_ARGS && !call_id.prefix.empty() &&
|
||||
!call_id.suffix.empty()) {
|
||||
call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix))) + call_id.suffix;
|
||||
}
|
||||
|
||||
auto func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
|
||||
call_id_section + p.space() + args_seq;
|
||||
|
||||
if (!function.close.empty()) {
|
||||
func_parser = func_parser + p.space() + p.tool_close(p.literal(function.close));
|
||||
} else if (!format.per_call_end.empty()) {
|
||||
// When there's no func_close but there is a per_call_end marker, use peek() to ensure
|
||||
// we only emit tool_close when we can actually see the closing marker. This prevents
|
||||
// premature closing during partial parsing when we've seen e.g. "</" which could be
|
||||
// either "</tool_call>" (end) or "<arg_key>" prefix that failed to match.
|
||||
func_parser = func_parser + p.tool_close(p.peek(p.literal(format.per_call_end)));
|
||||
} else {
|
||||
func_parser =
|
||||
func_parser + p.tool_close(p.space()); // force this to process tool closing callbacks in mapper
|
||||
}
|
||||
|
||||
func_parser = p.atomic(func_parser);
|
||||
tool_choice |= p.rule("tool-" + name, func_parser);
|
||||
});
|
||||
|
||||
auto require_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
common_peg_parser tool_calls = p.eps();
|
||||
|
||||
if (!format.per_call_start.empty()) {
|
||||
auto wrapped_call = format.per_call_start + p.space() + tool_choice + p.space() + format.per_call_end;
|
||||
if (inputs.parallel_tool_calls) {
|
||||
tool_calls = p.trigger_rule("tool-call", wrapped_call + p.zero_or_more(p.space() + wrapped_call));
|
||||
} else {
|
||||
tool_calls = p.trigger_rule("tool-call", wrapped_call);
|
||||
}
|
||||
if (!format.section_start.empty()) {
|
||||
tool_calls = p.trigger_rule("tool-calls",
|
||||
p.literal(format.section_start) + p.space() + tool_calls + p.space() +
|
||||
(format.section_end.empty() ? p.end() : p.literal(format.section_end)));
|
||||
}
|
||||
} else {
|
||||
std::string separator = ", "; // Default
|
||||
|
||||
if (inputs.parallel_tool_calls) {
|
||||
tool_calls = p.trigger_rule("tool-call", format.section_start + p.space() + tool_choice +
|
||||
p.zero_or_more(separator + tool_choice) + p.space() +
|
||||
format.section_end);
|
||||
} else {
|
||||
tool_calls = p.trigger_rule(
|
||||
"tool-call", format.section_start + p.space() + tool_choice + p.space() + format.section_end);
|
||||
}
|
||||
}
|
||||
|
||||
if (!require_tools) {
|
||||
tool_calls = p.optional(tool_calls);
|
||||
}
|
||||
|
||||
std::string trigger_marker = !format.section_start.empty() ? format.section_start : format.per_call_start;
|
||||
auto content_before_tools = trigger_marker.empty() ? p.eps() : p.until(trigger_marker);
|
||||
return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(content_before_tools))) + tool_calls +
|
||||
p.end();
|
||||
}
|
||||
|
||||
} // namespace autoparser
|
||||
|
|
@ -0,0 +1,347 @@
|
|||
#include "chat-auto-parser-helpers.h"
|
||||
|
||||
#include "chat-auto-parser.h"
|
||||
#include "chat.h"
|
||||
#include "log.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
#include <cctype>
|
||||
#include <numeric>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
std::string trim_whitespace(const std::string & str) {
|
||||
size_t start = 0;
|
||||
while (start < str.length() && std::isspace(static_cast<unsigned char>(str[start]))) {
|
||||
start++;
|
||||
}
|
||||
|
||||
if (start == str.length()) {
|
||||
return "";
|
||||
}
|
||||
|
||||
size_t end = str.length() - 1;
|
||||
while (end > start && std::isspace(static_cast<unsigned char>(str[end]))) {
|
||||
end--;
|
||||
}
|
||||
|
||||
return str.substr(start, end - start + 1);
|
||||
}
|
||||
|
||||
std::string trim_leading_whitespace(const std::string & str) {
|
||||
size_t start = 0;
|
||||
while (start < str.length() && std::isspace(static_cast<unsigned char>(str[start]))) {
|
||||
start++;
|
||||
}
|
||||
|
||||
return str.substr(start);
|
||||
}
|
||||
|
||||
std::string trim_trailing_whitespace(const std::string & str) {
|
||||
if (str.empty()) {
|
||||
return "";
|
||||
}
|
||||
|
||||
size_t end = str.length() - 1;
|
||||
while (end > 0 && std::isspace(static_cast<unsigned char>(str[end]))) {
|
||||
end--;
|
||||
}
|
||||
|
||||
// If first char is also whitespace, return empty string
|
||||
if (end == 0 && std::isspace(static_cast<unsigned char>(str[0]))) {
|
||||
return "";
|
||||
}
|
||||
|
||||
return str.substr(0, end + 1);
|
||||
}
|
||||
|
||||
std::string trim_trailing_newlines(const std::string & str) {
|
||||
size_t end = str.length();
|
||||
while (end > 0 && str[end - 1] == '\n') {
|
||||
end--;
|
||||
}
|
||||
|
||||
return str.substr(0, end);
|
||||
}
|
||||
|
||||
static size_t common_prefix_len(const std::string & left, const std::string & right) {
|
||||
size_t prefix_len = 0;
|
||||
size_t min_len = std::min(left.length(), right.length());
|
||||
while (prefix_len < min_len && left[prefix_len] == right[prefix_len]) {
|
||||
prefix_len++;
|
||||
}
|
||||
return prefix_len;
|
||||
}
|
||||
|
||||
static size_t common_suffix_len(const std::string & left, const std::string & right) {
|
||||
size_t suffix_len = 0;
|
||||
size_t min_len = std::min(left.length(), right.length());
|
||||
while (suffix_len < min_len && left[left.length() - 1 - suffix_len] == right[right.length() - 1 - suffix_len]) {
|
||||
suffix_len++;
|
||||
}
|
||||
return suffix_len;
|
||||
}
|
||||
|
||||
diff_split calculate_diff_split(const std::string & left, const std::string & right) {
|
||||
diff_split result;
|
||||
|
||||
auto left_seg = segmentize_markers(left);
|
||||
auto right_seg = segmentize_markers(right);
|
||||
|
||||
if (left_seg.empty()) {
|
||||
result.right = right;
|
||||
return result;
|
||||
}
|
||||
if (right_seg.empty()) {
|
||||
result.left = left;
|
||||
return result;
|
||||
}
|
||||
|
||||
auto left_start = left_seg.begin();
|
||||
auto left_end = --left_seg.end();
|
||||
auto right_start = right_seg.begin();
|
||||
auto right_end = --right_seg.end();
|
||||
|
||||
auto test = [&] () {
|
||||
return left_start != left_end && right_start != right_end;
|
||||
};
|
||||
|
||||
bool left_fully_consumed = false;
|
||||
bool right_fully_consumed = false;
|
||||
|
||||
while (test()) {
|
||||
bool advanced = false;
|
||||
if (*left_start == *right_start) {
|
||||
result.prefix.append(left_start->value);
|
||||
left_start++;
|
||||
right_start++;
|
||||
advanced = true;
|
||||
}
|
||||
if (*left_end == *right_end) {
|
||||
result.suffix = left_end->value + result.suffix;
|
||||
if (left_start != left_end) {
|
||||
left_end--;
|
||||
} else {
|
||||
left_fully_consumed = true;
|
||||
}
|
||||
if (right_start != right_end) {
|
||||
right_end--;
|
||||
} else {
|
||||
right_fully_consumed = true;
|
||||
}
|
||||
advanced = true;
|
||||
}
|
||||
if (!advanced) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (left_start == left_end && right_start != right_end) {
|
||||
if (*left_start == *right_end) {
|
||||
result.suffix = right_end->value + result.suffix;
|
||||
right_end--;
|
||||
left_fully_consumed = true;
|
||||
} else if (*left_start == *right_start) {
|
||||
result.prefix.append(right_start->value);
|
||||
right_start++;
|
||||
left_fully_consumed = true;
|
||||
}
|
||||
} else if (right_start == right_end && left_start != left_end) {
|
||||
if (*left_end == *right_start) {
|
||||
result.suffix = left_end->value + result.suffix;
|
||||
left_end--;
|
||||
right_fully_consumed = true;
|
||||
} else if (*left_start == *right_start) {
|
||||
result.prefix.append(left_start->value);
|
||||
left_start++;
|
||||
right_fully_consumed = true;
|
||||
}
|
||||
} else if (left_start == left_end && right_start == right_end && *left_start == *right_start && left_start->type == segment_type::MARKER) {
|
||||
result.prefix.append(right_start->value);
|
||||
left_fully_consumed = true;
|
||||
right_fully_consumed = true;
|
||||
}
|
||||
|
||||
auto eat_segment = [](std::string & str, segment & seg) -> std::string { return str.append(seg.value); };
|
||||
|
||||
bool can_have_text_suffix = left_end->type == segment_type::TEXT && right_end->type == segment_type::TEXT;
|
||||
bool can_have_text_prefix = right_start->type == segment_type::TEXT && left_start->type == segment_type::TEXT;
|
||||
|
||||
std::string remainder_left = std::accumulate(left_start, left_fully_consumed ? left_end : ++left_end, std::string(), eat_segment);
|
||||
std::string remainder_right = std::accumulate(right_start, right_fully_consumed ? right_end : ++right_end, std::string(), eat_segment);
|
||||
|
||||
size_t suffix_len = can_have_text_suffix ? common_suffix_len(remainder_left, remainder_right) : 0;
|
||||
// avoid overlaps between prefix and suffix
|
||||
size_t prefix_len = can_have_text_prefix ? common_prefix_len(remainder_left.substr(0, remainder_left.size() - suffix_len),
|
||||
remainder_right.substr(0, remainder_right.size() - suffix_len)) : 0;
|
||||
|
||||
result.prefix.append(remainder_left.substr(0, prefix_len));
|
||||
result.suffix = remainder_left.substr(remainder_left.length() - suffix_len, suffix_len) + result.suffix;
|
||||
result.left = remainder_left.substr(prefix_len, remainder_left.length() - prefix_len - suffix_len);
|
||||
result.right = remainder_right.substr(prefix_len, remainder_right.length() - prefix_len - suffix_len);
|
||||
|
||||
if (result.left == "" && result.right == "") {
|
||||
// degenerate case, no diff
|
||||
result.prefix = left;
|
||||
result.suffix = "";
|
||||
// pick prefix = all as representation
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the prefix of `full` up until the first occurrence of the common prefix of `left` and `right`
|
||||
std::string until_common_prefix(const std::string & full, const std::string & left, const std::string & right) {
|
||||
// Find the common prefix of left and right
|
||||
size_t common_prefix_len = 0;
|
||||
size_t min_len = std::min(left.length(), right.length());
|
||||
while (common_prefix_len < min_len && left[common_prefix_len] == right[common_prefix_len]) {
|
||||
common_prefix_len++;
|
||||
}
|
||||
|
||||
// If there's no common prefix, return empty string
|
||||
if (common_prefix_len == 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Find the common prefix in the full string
|
||||
std::string common_prefix = left.substr(0, common_prefix_len);
|
||||
size_t pos = full.find(common_prefix);
|
||||
|
||||
// If not found, return empty string
|
||||
if (pos == std::string::npos) {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Return everything before the common prefix
|
||||
return full.substr(0, pos);
|
||||
}
|
||||
|
||||
// Returns the suffix of `full` after the last occurrence of the common suffix of `left` and `right`
|
||||
std::string after_common_suffix(const std::string & full, const std::string & left, const std::string & right) {
|
||||
// Find the common suffix of left and right (compare from the end)
|
||||
size_t common_suffix_len = 0;
|
||||
size_t min_len = std::min(left.length(), right.length());
|
||||
while (common_suffix_len < min_len &&
|
||||
left[left.length() - 1 - common_suffix_len] == right[right.length() - 1 - common_suffix_len]) {
|
||||
common_suffix_len++;
|
||||
}
|
||||
|
||||
// If there's no common suffix, return empty string
|
||||
if (common_suffix_len == 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Extract the common suffix
|
||||
std::string common_suffix = left.substr(left.length() - common_suffix_len);
|
||||
|
||||
// Find the last occurrence of the common suffix in the full string
|
||||
size_t pos = full.rfind(common_suffix);
|
||||
|
||||
// If not found, return empty string
|
||||
if (pos == std::string::npos) {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Return everything after the common suffix
|
||||
return full.substr(pos + common_suffix_len);
|
||||
}
|
||||
|
||||
// TODO: segmentize will treat a JSON array inside tags as a tag: <calls>[{ "fun": { ... } }]</calls> will be three markers
|
||||
// not too worried about that because it hasn't turned out as a problem anywhere, but noting here in case it will
|
||||
// Might have to put some restrictions on tag contents as well (like "no { }")
|
||||
std::vector<segment> segmentize_markers(const std::string & text) {
|
||||
std::vector<segment> retval;
|
||||
bool in_marker = false;
|
||||
char marker_opener = '\0';
|
||||
|
||||
auto is_marker_opener = [](char c) -> bool { return c == '<' || c == '['; };
|
||||
auto is_marker_closer = [](char op, char c) -> bool { return (op == '<' && c == '>') || (op == '[' && c == ']'); };
|
||||
|
||||
size_t last_border = 0;
|
||||
|
||||
for (size_t cur_pos = 0; cur_pos < text.length(); cur_pos++) {
|
||||
if (!in_marker && is_marker_opener(text[cur_pos])) {
|
||||
if (last_border < cur_pos) {
|
||||
retval.push_back(segment(segment_type::TEXT, text.substr(last_border, cur_pos - last_border)));
|
||||
}
|
||||
last_border = cur_pos;
|
||||
in_marker = true;
|
||||
marker_opener = text[cur_pos];
|
||||
} else if (in_marker && is_marker_closer(marker_opener, text[cur_pos])) {
|
||||
// no need to check because last_border will always be smaller
|
||||
retval.push_back(segment(segment_type::MARKER, text.substr(last_border, cur_pos - last_border + 1)));
|
||||
last_border = cur_pos + 1;
|
||||
in_marker = false;
|
||||
marker_opener = '\0';
|
||||
}
|
||||
}
|
||||
if (last_border < text.length()) {
|
||||
retval.push_back(segment(segment_type::TEXT, text.substr(last_border)));
|
||||
}
|
||||
return retval;
|
||||
}
|
||||
|
||||
std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segments) {
|
||||
std::vector<segment> result;
|
||||
for (const auto & seg : segments) {
|
||||
if (!trim_whitespace(seg.value).empty()) {
|
||||
result.push_back(seg);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
namespace autoparser {
|
||||
|
||||
std::string apply_template(const common_chat_template & tmpl, const template_params & params) {
|
||||
templates_params tmpl_params;
|
||||
tmpl_params.messages = params.messages;
|
||||
tmpl_params.tools = params.tools;
|
||||
tmpl_params.add_generation_prompt = params.add_generation_prompt;
|
||||
tmpl_params.enable_thinking = params.enable_thinking;
|
||||
|
||||
if (params.extra_context) {
|
||||
tmpl_params.extra_context = *params.extra_context;
|
||||
}
|
||||
tmpl_params.extra_context["enable_thinking"] = params.enable_thinking;
|
||||
|
||||
try {
|
||||
return common_chat_template_direct_apply(tmpl, tmpl_params);
|
||||
} catch (const std::exception & e) {
|
||||
LOG_DBG("Template application failed: %s\n", e.what());
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<compare_variants_result> compare_variants(
|
||||
const common_chat_template & tmpl,
|
||||
const template_params & params_A,
|
||||
const std::function<void(template_params &)> & params_modifier) {
|
||||
// Create variant B by copying A
|
||||
template_params params_B = params_A;
|
||||
|
||||
// Apply modifier to create variant B
|
||||
if (params_modifier) {
|
||||
params_modifier(params_B);
|
||||
}
|
||||
|
||||
// Apply template to both variants
|
||||
std::string output_A = apply_template(tmpl, params_A);
|
||||
std::string output_B = apply_template(tmpl, params_B);
|
||||
|
||||
// Check for template application failures
|
||||
if (output_A.empty() || output_B.empty()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Calculate diff and return result with both outputs
|
||||
compare_variants_result result;
|
||||
result.diff = calculate_diff_split(output_A, output_B);
|
||||
result.output_A = output_A;
|
||||
result.output_B = output_B;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace autoparser
|
||||
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
#pragma once
|
||||
|
||||
#include "chat-auto-parser.h"
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
std::string trim_whitespace(const std::string & str);
|
||||
std::string trim_leading_whitespace(const std::string & str);
|
||||
std::string trim_trailing_whitespace(const std::string & str);
|
||||
std::string trim_trailing_newlines(const std::string & str);
|
||||
|
||||
// calculate a diff split (longest common prefix, longest common suffix excluding prefix,
|
||||
// mismatched part on the left, mismatched part on the right) between two strings
|
||||
// account for markers - align prefix and suffix endings so that they end on markers
|
||||
// * eg.:
|
||||
// calculate_diff_split("<html><body><div></div></body></html>", "<html><body><p>Something</p></body><html>") ->
|
||||
// { "prefix": "<html><body>" (not: "<html><body><"), "suffix": "</body></html>", "left": "<div></div>", "right": "<p>Something</p>" }
|
||||
// calculate_diff_split("<html><body>Something</body></html>", "<html><body></body><html>") ->
|
||||
// { "prefix": "<html><body>", "suffix": "</body></html>", "left": "Something", "right": "" }
|
||||
diff_split calculate_diff_split(const std::string & left, const std::string & right);
|
||||
|
||||
// Returns the prefix of `full` up until the first occurrence of the common prefix of `left` and `right`
|
||||
// Returns empty string if there's no common prefix
|
||||
// * eg.:
|
||||
// until_common_prefix("really want a FUNCTION call", "FUNCTION alpha", "FUNCTION beta") -> "really want a "
|
||||
// until_common_prefix("<tool_call>", "<something>", "<something_else>") -> ""
|
||||
// until_common_prefix("some text", "1234", "abcd") -> ""
|
||||
// until_common_prefix("one arg two args three args four", "argument alpha", "argument beta") -> "one ""
|
||||
std::string until_common_prefix(const std::string & full, const std::string & left, const std::string & right);
|
||||
|
||||
// Returns the suffix of `full` after the last occurrence of the common suffix of `left` and `right`
|
||||
// Returns empty string if there's no common suffix
|
||||
// Mirror function of `until_common_prefix`
|
||||
// * eg.:
|
||||
// after_common_suffix("really want a FUNCTION call", "first FUNCTION", "second FUNCTION") -> " call"
|
||||
// after_common_suffix("one arg two-args three args four", "alpha-args", "beta-args") -> " three args four"
|
||||
std::string after_common_suffix(const std::string & full, const std::string & left, const std::string & right);
|
||||
|
||||
// Segmentize text into markers and non-marker fragments
|
||||
// * eg.:
|
||||
// segmentize_markers("<html><head><title>The site title</title><body><div>Here's some <b>content</b></div></body></html>" ->
|
||||
// [ (MARKER, "<html>"), (MARKER, "<head>"), (MARKER, "<title>"), (TEXT, "The site title"), (MARKER, "</title>"),
|
||||
// (MARKER, "<body>"), (MARKER, "<div>"), (TEXT, "Here's some "), (MARKER, "<b>"), (TEXT, "content"), (MARKER, "</b>"),
|
||||
// (MARKER, "</div>"), (MARKER, "</body>"), (MARKER, "</html>")
|
||||
// ]
|
||||
// segmentize_markers("<|tool_call|>[args]{ are here }[/args]<|tool_call_end|>") ->
|
||||
// [ (MARKER, "<|tool_call|>"), (MARKER, "[args]"), (TEXT, "{ are here }"), (MARKER, "[/args]"), (MARKER, "<|tool_call_end|>") ]
|
||||
std::vector<segment> segmentize_markers(const std::string & text);
|
||||
|
||||
// Prune whitespace-only segments from a vector of segments
|
||||
// * eg.:
|
||||
// segmentize_markers("<tool_call>\n<function=foo>\n<arg=bar>\n \n</arg>\n</function>\n</tool_call>") ->
|
||||
// X = [ (MARKER, "<tool_call>"), (TEXT, "\n"), (MARKER, "<function=foo>"), (TEXT, "\n"), (MARKER, "<arg=bar>"), (TEXT, "\n \n"),
|
||||
// (MARKER, "</arg>"), (TEXT, "\n"), (MARKER, "</function>"), (TEXT, "\n"), (MARKER, "</tool_call>") ]
|
||||
// prune_whitespace_segments(X) -> [ (MARKER, "<tool_call>"), (MARKER, "<function=foo>"), (MARKER, "<arg=bar>"), (MARKER, "</arg>"),
|
||||
// (MARKER, "</function>"), (MARKER, "</tool_call>") ]
|
||||
std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segments);
|
||||
|
||||
namespace autoparser {
|
||||
|
||||
// Apply a template with the given parameters, returning the rendered string (empty on failure)
|
||||
std::string apply_template(const common_chat_template & tmpl, const template_params & params);
|
||||
|
||||
// Factorized differential comparison function
|
||||
// Takes base params and a single modifier lambda to create variant B
|
||||
// Returns compare_variants_result containing diff and both outputs, or std::nullopt on failure
|
||||
std::optional<compare_variants_result> compare_variants(
|
||||
const common_chat_template & tmpl,
|
||||
const template_params & params_A,
|
||||
const std::function<void(template_params &)> & params_modifier);
|
||||
|
||||
} // namespace autoparser
|
||||
|
|
@ -0,0 +1,433 @@
|
|||
#pragma once
|
||||
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "jinja/caps.h"
|
||||
#include "peg-parser.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
class common_chat_peg_builder;
|
||||
|
||||
// ============================================================================
|
||||
// Parameters for template application (low-level, used by diff analysis)
|
||||
// ============================================================================
|
||||
struct template_params {
|
||||
json messages;
|
||||
json tools;
|
||||
bool add_generation_prompt = false;
|
||||
bool enable_thinking = true;
|
||||
std::optional<json> extra_context = std::nullopt;
|
||||
};
|
||||
|
||||
struct diff_split {
|
||||
std::string prefix;
|
||||
std::string suffix;
|
||||
std::string left;
|
||||
std::string right;
|
||||
|
||||
bool operator==(struct diff_split & other) const {
|
||||
return prefix == other.prefix && suffix == other.suffix && left == other.left && right == other.right;
|
||||
}
|
||||
};
|
||||
|
||||
// Result of compare_variants containing diff and original outputs
|
||||
struct compare_variants_result {
|
||||
diff_split diff;
|
||||
std::string output_A;
|
||||
std::string output_B;
|
||||
};
|
||||
|
||||
namespace autoparser {
|
||||
|
||||
// ============================================================================
|
||||
// High-level params for parser generation
|
||||
// ============================================================================
|
||||
|
||||
struct templates_params {
|
||||
json messages;
|
||||
json tools;
|
||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
json json_schema;
|
||||
bool parallel_tool_calls = true;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
|
||||
bool stream = true;
|
||||
std::string grammar;
|
||||
bool add_generation_prompt = false;
|
||||
bool enable_thinking = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
json extra_context;
|
||||
bool add_bos = false;
|
||||
bool add_eos = false;
|
||||
bool is_inference = true;
|
||||
bool add_inference = false;
|
||||
bool mark_input = true; // whether to mark input strings in the jinja context
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Analysis Result Enums
|
||||
// ============================================================================
|
||||
|
||||
// Reasoning handling mode (derived from R1-R3 comparisons)
|
||||
enum class reasoning_mode {
|
||||
NONE, // No reasoning markers detected
|
||||
TAG_BASED, // Standard tag-based: <think>...</think>
|
||||
DELIMITER, // Delimiter-based: [BEGIN FINAL RESPONSE] (reasoning ends at delimiter)
|
||||
FORCED_OPEN, // Template ends with open reasoning tag (empty start, non-empty end)
|
||||
FORCED_CLOSED, // Template ends with open reasoning tag on enabled thinking but
|
||||
// with both opened and closed tag for disabled thinking
|
||||
TOOLS_ONLY // Only reason on tool calls, not on normal content
|
||||
};
|
||||
|
||||
inline std::ostream & operator<<(std::ostream & os, const reasoning_mode & mode) {
|
||||
switch (mode) {
|
||||
case reasoning_mode::NONE:
|
||||
return os << "NONE";
|
||||
case reasoning_mode::TAG_BASED:
|
||||
return os << "TAG_BASED";
|
||||
case reasoning_mode::DELIMITER:
|
||||
return os << "DELIMITER";
|
||||
case reasoning_mode::FORCED_OPEN:
|
||||
return os << "FORCED_OPEN";
|
||||
case reasoning_mode::FORCED_CLOSED:
|
||||
return os << "FORCED_CLOSED";
|
||||
case reasoning_mode::TOOLS_ONLY:
|
||||
return os << "TOOLS_ONLY";
|
||||
default:
|
||||
return os << "UNKNOWN";
|
||||
}
|
||||
}
|
||||
|
||||
// Content wrapping mode (derived from C1 comparison)
|
||||
enum class content_mode {
|
||||
PLAIN, // No content markers
|
||||
ALWAYS_WRAPPED, // Content always wrapped with markers
|
||||
WRAPPED_WITH_REASONING, // Content wrapped only when reasoning present
|
||||
};
|
||||
|
||||
inline std::ostream & operator<<(std::ostream & os, const content_mode & mode) {
|
||||
switch (mode) {
|
||||
case content_mode::PLAIN:
|
||||
return os << "PLAIN";
|
||||
case content_mode::ALWAYS_WRAPPED:
|
||||
return os << "ALWAYS_WRAPPED";
|
||||
case content_mode::WRAPPED_WITH_REASONING:
|
||||
return os << "WRAPPED_WITH_REASONING";
|
||||
default:
|
||||
return os << "UNKNOWN";
|
||||
}
|
||||
}
|
||||
|
||||
// Call ID position in tool calls (for non-JSON formats)
|
||||
enum class call_id_position {
|
||||
NONE, // No call ID support detected
|
||||
PRE_FUNC_NAME, // Call ID before function name: [CALL_ID]id[FUNC]name{args}
|
||||
BETWEEN_FUNC_AND_ARGS, // Call ID between function and args: [FUNC]name[CALL_ID]id{args}
|
||||
POST_ARGS, // Call ID after arguments: [FUNC]name{args}[CALL_ID]id
|
||||
};
|
||||
|
||||
inline std::ostream & operator<<(std::ostream & os, const call_id_position & pos) {
|
||||
switch (pos) {
|
||||
case call_id_position::NONE:
|
||||
return os << "NONE";
|
||||
case call_id_position::PRE_FUNC_NAME:
|
||||
return os << "PRE_FUNC_NAME";
|
||||
case call_id_position::BETWEEN_FUNC_AND_ARGS:
|
||||
return os << "BETWEEN_FUNC_AND_ARGS";
|
||||
case call_id_position::POST_ARGS:
|
||||
return os << "POST_ARGS";
|
||||
default:
|
||||
return os << "UNKNOWN";
|
||||
}
|
||||
}
|
||||
|
||||
// Tool call format classification (derived from T1-T5, A1-A3 comparisons)
|
||||
enum class tool_format {
|
||||
NONE, // No tool support detected
|
||||
JSON_NATIVE, // Pure JSON: {"name": "X", "arguments": {...}}
|
||||
TAG_WITH_JSON, // Tag-based with JSON args: <function=X>{...}</function>
|
||||
TAG_WITH_TAGGED, // Tag-based with tagged args: <param=key>value</param>
|
||||
};
|
||||
|
||||
inline std::ostream & operator<<(std::ostream & os, const tool_format & format) {
|
||||
switch (format) {
|
||||
case tool_format::NONE:
|
||||
return os << "NONE";
|
||||
case tool_format::JSON_NATIVE:
|
||||
return os << "JSON_NATIVE";
|
||||
case tool_format::TAG_WITH_JSON:
|
||||
return os << "TAG_WITH_JSON";
|
||||
case tool_format::TAG_WITH_TAGGED:
|
||||
return os << "TAG_WITH_TAGGED";
|
||||
default:
|
||||
return os << "UNKNOWN";
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Sub-structs for tool analysis
|
||||
// ============================================================================
|
||||
|
||||
struct tool_format_analysis {
|
||||
tool_format mode = tool_format::NONE;
|
||||
|
||||
std::string section_start; // e.g., "<tool_call>", "[TOOL_CALLS]", ""
|
||||
std::string section_end; // e.g., "</tool_call>", ""
|
||||
std::string per_call_start; // e.g., "<|tool_call_begin|>", "" (for multi-call templates)
|
||||
std::string per_call_end; // e.g., "<|tool_call_end|>", ""
|
||||
|
||||
bool fun_name_is_key = false; // In JSON format function name is JSON key, i.e. { "<funname>": { ... arguments ... } }
|
||||
bool tools_array_wrapped = false; // Tool calls wrapped in JSON array [...]
|
||||
bool uses_python_dicts = false; // Tool call args use Python dict format (single-quoted strings)
|
||||
|
||||
std::string function_field = "function";
|
||||
std::string name_field = "name";
|
||||
std::string args_field = "arguments";
|
||||
std::string id_field;
|
||||
std::string gen_id_field;
|
||||
std::vector<std::string> parameter_order;
|
||||
};
|
||||
|
||||
struct tool_function_analysis {
|
||||
std::string name_prefix; // e.g., "<function=", "\"name\": \"", "functions."
|
||||
std::string name_suffix; // e.g., ">", "\"", ":0"
|
||||
std::string close; // e.g., "</function>", "" (for tag-based)
|
||||
};
|
||||
|
||||
struct tool_arguments_analysis {
|
||||
std::string start; // e.g., "<|tool_call_argument_begin|>", "<args>"
|
||||
std::string end; // e.g., "<|tool_call_argument_end|>", "</args>"
|
||||
std::string name_prefix; // e.g., "<param=", "<arg_key>", "\""
|
||||
std::string name_suffix; // e.g., ">", "</arg_key>", "\":"
|
||||
std::string value_prefix; // e.g., "", "<arg_value>", ""
|
||||
std::string value_suffix; // e.g., "</param>", "</arg_value>", ""
|
||||
std::string separator; // e.g., "", "\n", ","
|
||||
};
|
||||
|
||||
struct tool_id_analysis {
|
||||
call_id_position pos = call_id_position::NONE;
|
||||
|
||||
std::string prefix; // e.g., "[CALL_ID]" (marker before call ID value)
|
||||
std::string suffix; // e.g., "" (marker after call ID value, before next section)
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Parser build context (shared interface for build_parser methods)
|
||||
// ============================================================================
|
||||
|
||||
struct analyze_content;
|
||||
|
||||
struct parser_build_context {
|
||||
common_chat_peg_builder & p;
|
||||
const templates_params & inputs;
|
||||
common_peg_parser reasoning_parser;
|
||||
bool extracting_reasoning = false;
|
||||
const analyze_content * content = nullptr;
|
||||
|
||||
parser_build_context(common_chat_peg_builder & p, const templates_params & inputs);
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Base class for analyzers with parser building
|
||||
// ============================================================================
|
||||
|
||||
struct analyze_base {
|
||||
virtual ~analyze_base() = default;
|
||||
virtual common_peg_parser build_parser(parser_build_context & ctx) const = 0;
|
||||
|
||||
protected:
|
||||
const common_chat_template * tmpl = nullptr;
|
||||
|
||||
analyze_base() = default;
|
||||
explicit analyze_base(const common_chat_template & tmpl) : tmpl(&tmpl) {}
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Reasoning analyzer
|
||||
// ============================================================================
|
||||
|
||||
struct analyze_reasoning : analyze_base {
|
||||
reasoning_mode mode = reasoning_mode::NONE;
|
||||
|
||||
std::string start; // e.g., "<think>", "[THINK]", "<|START_THINKING|>", ""
|
||||
std::string end; // e.g., "</think>", "[BEGIN FINAL RESPONSE]", "<|END_THINKING|>"
|
||||
|
||||
analyze_reasoning() = default;
|
||||
analyze_reasoning(const common_chat_template & tmpl, bool supports_tools);
|
||||
|
||||
common_peg_parser build_parser(parser_build_context & ctx) const override;
|
||||
|
||||
private:
|
||||
// Look for reasoning markers in rendered content
|
||||
void compare_reasoning_presence();
|
||||
|
||||
// Compare generation prompt with enable_thinking=true vs false
|
||||
void compare_thinking_enabled();
|
||||
|
||||
// Check if reasoning is always possible or only in tool calls
|
||||
void compare_reasoning_scope();
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Content analyzer
|
||||
// ============================================================================
|
||||
|
||||
struct analyze_content : analyze_base {
|
||||
content_mode mode = content_mode::PLAIN;
|
||||
|
||||
std::string start; // e.g., "<response>", ">>>all\n", ""
|
||||
std::string end; // e.g., "</response>", ""
|
||||
|
||||
bool requires_nonnull_content = false;
|
||||
|
||||
analyze_content() = default;
|
||||
analyze_content(const common_chat_template & tmpl, const analyze_reasoning & reasoning);
|
||||
|
||||
common_peg_parser build_parser(parser_build_context & ctx) const override;
|
||||
|
||||
bool is_always_wrapped() const;
|
||||
common_peg_parser build_optional_wrapped(parser_build_context & ctx) const;
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Tool analyzer
|
||||
// ============================================================================
|
||||
|
||||
struct analyze_tools : analyze_base {
|
||||
tool_format_analysis format;
|
||||
tool_function_analysis function;
|
||||
tool_arguments_analysis arguments;
|
||||
tool_id_analysis call_id;
|
||||
|
||||
analyze_tools() = default;
|
||||
analyze_tools(const common_chat_template & tmpl,
|
||||
const jinja::caps & caps,
|
||||
const analyze_reasoning & reasoning);
|
||||
|
||||
common_peg_parser build_parser(parser_build_context & ctx) const override;
|
||||
|
||||
private:
|
||||
// Extract tool calling 'haystack' for further analysis and delegate further analysis based on format
|
||||
void analyze_tool_calls(const analyze_reasoning & reasoning);
|
||||
|
||||
// Analyze format based on position of function and argument name in needle
|
||||
void analyze_tool_call_format(const std::string & haystack,
|
||||
const std::string & fun_name_needle,
|
||||
const std::string & arg_name_needle,
|
||||
const analyze_reasoning & reasoning);
|
||||
|
||||
// Analyze specifics of JSON native format (entire tool call is a JSON object)
|
||||
void analyze_tool_call_format_json_native(const std::string & clean_haystack,
|
||||
const std::string & fun_name_needle,
|
||||
const std::string & arg_name_needle);
|
||||
|
||||
// Analyze specifics of non-JSON native format (tags for function name or for function name and arguments)
|
||||
void analyze_tool_call_format_non_json(const std::string & clean_haystack,
|
||||
const std::string & fun_name_needle);
|
||||
|
||||
// Check for and extract specific per-call markers for non-native-JSON templates with parallel call support
|
||||
void check_per_call_markers();
|
||||
|
||||
// Extract function name markers
|
||||
void extract_function_markers();
|
||||
|
||||
// Delegates to separate functions for: separator analysis, argument name analysis, argument value analysis
|
||||
void analyze_arguments();
|
||||
|
||||
// Extract argument name markers
|
||||
void extract_argument_name_markers();
|
||||
|
||||
// Extract argument value markers
|
||||
void extract_argument_value_markers();
|
||||
|
||||
// Extract argument separator, if specified (eg. <arg=foo>...</arg><sep><arg=bar>...</arg>)
|
||||
void extract_argument_separator();
|
||||
|
||||
// Extract argument wrapper markers, if present (eg. '<args><arg=foo>...</arg><arg=bar>...</arg></args>')
|
||||
void extract_args_markers();
|
||||
|
||||
// Extract call ID markers, if present
|
||||
void extract_call_id_markers();
|
||||
|
||||
// Per-format tool parser builders
|
||||
common_peg_parser build_tool_parser_json_native(parser_build_context & ctx) const;
|
||||
common_peg_parser build_tool_parser_tag_json(parser_build_context & ctx) const;
|
||||
common_peg_parser build_tool_parser_tag_tagged(parser_build_context & ctx) const;
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Main autoparser class
|
||||
// ============================================================================
|
||||
|
||||
struct autoparser {
|
||||
jinja::caps jinja_caps;
|
||||
analyze_reasoning reasoning;
|
||||
analyze_content content;
|
||||
analyze_tools tools;
|
||||
bool analysis_complete = false;
|
||||
|
||||
// Preserved tokens for tokenizer (union of all non-empty markers)
|
||||
std::vector<std::string> preserved_tokens;
|
||||
|
||||
autoparser() = default;
|
||||
|
||||
// Run full differential analysis on a template
|
||||
void analyze_template(const common_chat_template & tmpl);
|
||||
|
||||
// Build the PEG parser for this template
|
||||
common_peg_arena build_parser(const templates_params & inputs) const;
|
||||
|
||||
private:
|
||||
// Collect tokens from entire analysis to preserve
|
||||
void collect_preserved_tokens();
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Parser generator
|
||||
// ============================================================================
|
||||
|
||||
class peg_generator {
|
||||
public:
|
||||
static common_chat_params generate_parser(const common_chat_template & tmpl,
|
||||
const struct templates_params & inputs);
|
||||
|
||||
static common_chat_params generate_parser(const common_chat_template & tmpl,
|
||||
const struct templates_params & inputs,
|
||||
const autoparser & autoparser);
|
||||
};
|
||||
|
||||
} // namespace autoparser
|
||||
|
||||
enum segment_type { TEXT, MARKER };
|
||||
|
||||
inline std::ostream & operator<<(std::ostream & os, const segment_type & type) {
|
||||
switch (type) {
|
||||
case segment_type::TEXT:
|
||||
return os << "TEXT";
|
||||
case segment_type::MARKER:
|
||||
return os << "MARKER";
|
||||
default:
|
||||
return os << "UNKNOWN";
|
||||
}
|
||||
}
|
||||
|
||||
struct segment {
|
||||
segment_type type;
|
||||
std::string value;
|
||||
|
||||
segment(segment_type type, std::string value) : type(type), value(std::move(value)) {}
|
||||
|
||||
bool operator==(const segment & other) const {
|
||||
return type == other.type && value == other.value;
|
||||
}
|
||||
|
||||
bool operator!=(const segment & other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
};
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,879 +0,0 @@
|
|||
#include "chat.h"
|
||||
#include "chat-parser.h"
|
||||
#include "common.h"
|
||||
#include "json-partial.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "log.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
class xml_toolcall_syntax_exception : public std::runtime_error {
|
||||
public:
|
||||
xml_toolcall_syntax_exception(const std::string & message) : std::runtime_error(message) {}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
inline void sort_uniq(std::vector<T> &vec) {
|
||||
std::sort(vec.begin(), vec.end());
|
||||
vec.erase(std::unique(vec.begin(), vec.end()), vec.end());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline bool all_space(const T &str) {
|
||||
return std::all_of(str.begin(), str.end(), [](unsigned char ch) { return std::isspace(ch); });
|
||||
}
|
||||
|
||||
static size_t utf8_truncate_safe(const std::string_view s) {
|
||||
size_t len = s.size();
|
||||
if (len == 0) return 0;
|
||||
size_t i = len;
|
||||
for (size_t back = 0; back < 4 && i > 0; ++back) {
|
||||
--i;
|
||||
unsigned char c = s[i];
|
||||
if ((c & 0x80) == 0) {
|
||||
return len;
|
||||
} else if ((c & 0xC0) == 0xC0) {
|
||||
size_t expected_len = 0;
|
||||
if ((c & 0xE0) == 0xC0) expected_len = 2;
|
||||
else if ((c & 0xF0) == 0xE0) expected_len = 3;
|
||||
else if ((c & 0xF8) == 0xF0) expected_len = 4;
|
||||
else return i;
|
||||
if (len - i >= expected_len) {
|
||||
return len;
|
||||
} else {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}
|
||||
return len - std::min(len, size_t(3));
|
||||
}
|
||||
|
||||
inline void utf8_truncate_safe_resize(std::string &s) {
|
||||
s.resize(utf8_truncate_safe(s));
|
||||
}
|
||||
|
||||
inline std::string_view utf8_truncate_safe_view(const std::string_view s) {
|
||||
return s.substr(0, utf8_truncate_safe(s));
|
||||
}
|
||||
|
||||
static std::optional<common_chat_msg_parser::find_regex_result> try_find_2_literal_splited_by_spaces(common_chat_msg_parser & builder, const std::string & literal1, const std::string & literal2) {
|
||||
if (literal1.size() == 0) return builder.try_find_literal(literal2);
|
||||
const auto saved_pos = builder.pos();
|
||||
while (auto res = builder.try_find_literal(literal1)) {
|
||||
builder.consume_spaces();
|
||||
const auto match_len = std::min(literal2.size(), builder.input().size() - builder.pos());
|
||||
if (builder.input().compare(builder.pos(), match_len, literal2, 0, match_len) == 0) {
|
||||
if (res->prelude.size() != res->groups[0].begin - saved_pos) {
|
||||
res->prelude = builder.str({saved_pos, res->groups[0].begin});
|
||||
}
|
||||
builder.move_to(builder.pos() + match_len);
|
||||
res->groups[0].end = builder.pos();
|
||||
GGML_ASSERT(res->groups[0].begin != res->groups[0].end);
|
||||
return res;
|
||||
}
|
||||
builder.move_to(res->groups[0].begin + 1);
|
||||
}
|
||||
builder.move_to(saved_pos);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
/**
|
||||
* make a GBNF that accept any strings except those containing any of the forbidden strings.
|
||||
*/
|
||||
std::string make_gbnf_excluding(std::vector<std::string> forbids) {
|
||||
constexpr auto charclass_escape = [](unsigned char c) -> std::string {
|
||||
if (c == '\\' || c == ']' || c == '^' || c == '-') {
|
||||
std::string s = "\\";
|
||||
s.push_back((char)c);
|
||||
return s;
|
||||
}
|
||||
if (isprint(c)) {
|
||||
return std::string(1, (char)c);
|
||||
}
|
||||
char buf[16];
|
||||
snprintf(buf, 15, "\\x%02X", c);
|
||||
return std::string(buf);
|
||||
};
|
||||
constexpr auto build_expr = [charclass_escape](auto self, const std::vector<std::string>& forbids, int l, int r, int depth) -> std::string {
|
||||
std::vector<std::pair<unsigned char, std::pair<int,int>>> children;
|
||||
int i = l;
|
||||
while (i < r) {
|
||||
const std::string &s = forbids[i];
|
||||
if ((int)s.size() == depth) {
|
||||
++i;
|
||||
continue;
|
||||
}
|
||||
unsigned char c = (unsigned char)s[depth];
|
||||
int j = i;
|
||||
while (j < r && (int)forbids[j].size() > depth &&
|
||||
(unsigned char)forbids[j][depth] == c) {
|
||||
++j;
|
||||
}
|
||||
children.push_back({c, {i, j}});
|
||||
i = j;
|
||||
}
|
||||
std::vector<std::string> alts;
|
||||
if (!children.empty()) {
|
||||
std::string cls;
|
||||
for (auto &ch : children) cls += charclass_escape(ch.first);
|
||||
alts.push_back(std::string("[^") + cls + "]");
|
||||
}
|
||||
for (auto &ch : children) {
|
||||
std::string childExpr = self(self, forbids, ch.second.first, ch.second.second, depth+1);
|
||||
if (!childExpr.empty()) {
|
||||
std::string quoted_ch = "\"";
|
||||
if (ch.first == '\\') quoted_ch += "\\\\";
|
||||
else if (ch.first == '"') quoted_ch += "\\\"";
|
||||
else if (isprint(ch.first)) quoted_ch.push_back(ch.first);
|
||||
else {
|
||||
char buf[16];
|
||||
snprintf(buf, 15, "\\x%02X", ch.first);
|
||||
quoted_ch += buf;
|
||||
}
|
||||
quoted_ch += "\"";
|
||||
std::string branch = quoted_ch + std::string(" ") + childExpr;
|
||||
alts.push_back(branch);
|
||||
}
|
||||
}
|
||||
if (alts.empty()) return "";
|
||||
std::ostringstream oss;
|
||||
oss << "( ";
|
||||
for (size_t k = 0; k < alts.size(); ++k) {
|
||||
if (k) oss << " | ";
|
||||
oss << alts[k];
|
||||
}
|
||||
oss << " )";
|
||||
return oss.str();
|
||||
};
|
||||
if (forbids.empty()) return "( . )*";
|
||||
sort(forbids.begin(), forbids.end());
|
||||
std::string expr = build_expr(build_expr, forbids, 0, forbids.size(), 0);
|
||||
if (expr.empty()) {
|
||||
std::string cls;
|
||||
for (auto &s : forbids) if (!s.empty()) cls += charclass_escape((unsigned char)s[0]);
|
||||
expr = std::string("( [^") + cls + "] )";
|
||||
}
|
||||
if (forbids.size() == 1)
|
||||
return expr + "*";
|
||||
else
|
||||
return std::string("( ") + expr + " )*";
|
||||
}
|
||||
|
||||
/**
|
||||
* Build grammar for xml-style tool call
|
||||
* form.scope_start and form.scope_end can be empty.
|
||||
* Requires data.format for model-specific hacks.
|
||||
*/
|
||||
void build_grammar_xml_tool_call(common_chat_params & data, const json & tools, const struct xml_tool_call_format & form) {
|
||||
GGML_ASSERT(!form.tool_start.empty());
|
||||
GGML_ASSERT(!form.tool_sep.empty());
|
||||
GGML_ASSERT(!form.key_start.empty());
|
||||
GGML_ASSERT(!form.val_end.empty());
|
||||
GGML_ASSERT(!form.tool_end.empty());
|
||||
|
||||
std::string key_val_sep = form.key_val_sep;
|
||||
if (form.key_val_sep2) {
|
||||
key_val_sep += "\n";
|
||||
key_val_sep += *form.key_val_sep2;
|
||||
}
|
||||
GGML_ASSERT(!key_val_sep.empty());
|
||||
|
||||
if (tools.is_array() && !tools.empty()) {
|
||||
data.grammar = build_grammar([&](const common_grammar_builder &builder) {
|
||||
auto string_arg_val = form.last_val_end ?
|
||||
builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end, *form.last_val_end})) :
|
||||
builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end}));
|
||||
|
||||
std::vector<std::string> tool_rules;
|
||||
for (const auto & tool : tools) {
|
||||
if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
|
||||
LOG_WRN("Skipping tool without function: %s", tool.dump(2).c_str());
|
||||
continue;
|
||||
}
|
||||
const auto & function = tool.at("function");
|
||||
if (!function.contains("name") || !function.at("name").is_string()) {
|
||||
LOG_WRN("Skipping invalid function (invalid name): %s", function.dump(2).c_str());
|
||||
continue;
|
||||
}
|
||||
if (!function.contains("parameters") || !function.at("parameters").is_object()) {
|
||||
LOG_WRN("Skipping invalid function (invalid parameters): %s", function.dump(2).c_str());
|
||||
continue;
|
||||
}
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
|
||||
struct parameter_rule {
|
||||
std::string symbol_name;
|
||||
bool is_required;
|
||||
};
|
||||
std::vector<parameter_rule> arg_rules;
|
||||
if (!parameters.contains("properties") || !parameters.at("properties").is_object()) {
|
||||
LOG_WRN("Skipping invalid function (invalid properties): %s", function.dump(2).c_str());
|
||||
continue;
|
||||
} else {
|
||||
std::vector<std::string> requiredParameters;
|
||||
if (parameters.contains("required")) {
|
||||
try { parameters.at("required").get_to(requiredParameters); }
|
||||
catch (const std::runtime_error&) {
|
||||
LOG_WRN("Invalid function required parameters, ignoring: %s", function.at("required").dump(2).c_str());
|
||||
}
|
||||
}
|
||||
sort_uniq(requiredParameters);
|
||||
for (const auto & [key, value] : parameters.at("properties").items()) {
|
||||
std::string quoted_key = key;
|
||||
bool required = std::binary_search(requiredParameters.begin(), requiredParameters.end(), key);
|
||||
if (form.key_start.back() == '"' && key_val_sep[0] == '"') {
|
||||
quoted_key = gbnf_format_literal(key);
|
||||
quoted_key = quoted_key.substr(1, quoted_key.size() - 2);
|
||||
}
|
||||
arg_rules.push_back(parameter_rule {builder.add_rule("func-" + name + "-kv-" + key,
|
||||
gbnf_format_literal(form.key_start) + " " +
|
||||
gbnf_format_literal(quoted_key) + " " +
|
||||
gbnf_format_literal(key_val_sep) + " " +
|
||||
((value.contains("type") && value["type"].is_string() && value["type"] == "string" && (!form.raw_argval || *form.raw_argval)) ?
|
||||
(form.raw_argval ?
|
||||
string_arg_val :
|
||||
"( " + string_arg_val + " | " + builder.add_schema(name + "-arg-" + key, value) + " )"
|
||||
) :
|
||||
builder.add_schema(name + "-arg-" + key, value)
|
||||
)
|
||||
), required});
|
||||
}
|
||||
}
|
||||
|
||||
auto next_arg_with_sep = builder.add_rule(name + "-last-arg-end", form.last_val_end ? gbnf_format_literal(*form.last_val_end) : gbnf_format_literal(form.val_end));
|
||||
decltype(next_arg_with_sep) next_arg = "\"\"";
|
||||
for (auto i = arg_rules.size() - 1; /* i >= 0 && */ i < arg_rules.size(); --i) {
|
||||
std::string include_this_arg = arg_rules[i].symbol_name + " " + next_arg_with_sep;
|
||||
next_arg = builder.add_rule(name + "-arg-after-" + std::to_string(i), arg_rules[i].is_required ?
|
||||
include_this_arg : "( " + include_this_arg + " ) | " + next_arg
|
||||
);
|
||||
include_this_arg = gbnf_format_literal(form.val_end) + " " + include_this_arg;
|
||||
next_arg_with_sep = builder.add_rule(name + "-arg-after-" + std::to_string(i) + "-with-sep", arg_rules[i].is_required ?
|
||||
include_this_arg : "( " + include_this_arg + " ) | " + next_arg_with_sep
|
||||
);
|
||||
}
|
||||
|
||||
std::string quoted_name = name;
|
||||
if (form.tool_start.back() == '"' && form.tool_sep[0] == '"') {
|
||||
quoted_name = gbnf_format_literal(name);
|
||||
quoted_name = quoted_name.substr(1, quoted_name.size() - 2);
|
||||
}
|
||||
quoted_name = gbnf_format_literal(quoted_name);
|
||||
// Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
|
||||
if (data.format == COMMON_CHAT_FORMAT_KIMI_K2) {
|
||||
quoted_name = "\"functions.\" " + quoted_name + " \":\" [0-9]+";
|
||||
}
|
||||
tool_rules.push_back(builder.add_rule(name + "-call",
|
||||
gbnf_format_literal(form.tool_start) + " " +
|
||||
quoted_name + " " +
|
||||
gbnf_format_literal(form.tool_sep) + " " +
|
||||
next_arg
|
||||
));
|
||||
}
|
||||
|
||||
auto tool_call_once = builder.add_rule("root-tool-call-once", string_join(tool_rules, " | "));
|
||||
auto tool_call_more = builder.add_rule("root-tool-call-more", gbnf_format_literal(form.tool_end) + " " + tool_call_once);
|
||||
auto call_end = builder.add_rule("root-call-end", form.last_tool_end ? gbnf_format_literal(*form.last_tool_end) : gbnf_format_literal(form.tool_end));
|
||||
auto tool_call_multiple_with_end = builder.add_rule("root-tool-call-multiple-with-end", tool_call_once + " " + tool_call_more + "* " + call_end);
|
||||
builder.add_rule("root",
|
||||
(form.scope_start.empty() ? "" : gbnf_format_literal(form.scope_start) + " ") +
|
||||
tool_call_multiple_with_end + "?" +
|
||||
(form.scope_end.empty() ? "" : " " + gbnf_format_literal(form.scope_end))
|
||||
);
|
||||
});
|
||||
|
||||
// grammar trigger for tool call
|
||||
data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, form.scope_start + form.tool_start });
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
|
||||
* Throws xml_toolcall_syntax_exception if there is invalid syntax and cannot recover the original status for common_chat_msg_parser.
|
||||
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||
*/
|
||||
inline bool parse_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form) {
|
||||
GGML_ASSERT(!form.tool_start.empty());
|
||||
GGML_ASSERT(!form.key_start.empty());
|
||||
GGML_ASSERT(!form.key_val_sep.empty());
|
||||
GGML_ASSERT(!form.val_end.empty());
|
||||
GGML_ASSERT(!form.tool_end.empty());
|
||||
|
||||
// Helper to choose return false or throw error
|
||||
constexpr auto return_error = [](common_chat_msg_parser & builder, auto &start_pos, const bool &recovery) {
|
||||
LOG_DBG("Failed to parse XML-Style tool call at position: %s\n", gbnf_format_literal(builder.consume_rest().substr(0, 20)).c_str());
|
||||
if (recovery) {
|
||||
builder.move_to(start_pos);
|
||||
return false;
|
||||
} else throw xml_toolcall_syntax_exception("Tool call parsing failed with unrecoverable errors. Try using a grammar to constrain the model’s output.");
|
||||
};
|
||||
// Drop substring from needle to end from a JSON
|
||||
constexpr auto partial_json = [](std::string &json_str, std::string_view needle = "XML_TOOL_CALL_PARTIAL_FLAG") {
|
||||
auto pos = json_str.rfind(needle);
|
||||
if (pos == std::string::npos) {
|
||||
return false;
|
||||
}
|
||||
for (auto i = pos + needle.size(); i < json_str.size(); ++i) {
|
||||
unsigned char ch = static_cast<unsigned char>(json_str[i]);
|
||||
if (ch != '\'' && ch != '"' && ch != '}' && ch != ':' && !std::isspace(ch)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (pos != 0 && json_str[pos - 1] == '"') {
|
||||
--pos;
|
||||
}
|
||||
json_str.resize(pos);
|
||||
return true;
|
||||
};
|
||||
// Helper to generate a partial argument JSON
|
||||
constexpr auto gen_partial_json = [partial_json](auto set_partial_arg, auto &arguments, auto &builder, auto &function_name) {
|
||||
auto rest = builder.consume_rest();
|
||||
utf8_truncate_safe_resize(rest);
|
||||
set_partial_arg(rest, "XML_TOOL_CALL_PARTIAL_FLAG");
|
||||
auto tool_str = arguments.dump();
|
||||
if (partial_json(tool_str)) {
|
||||
if (builder.add_tool_call(function_name, "", tool_str)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
LOG_DBG("Failed to parse partial XML-Style tool call, fallback to non-partial: %s\n", tool_str.c_str());
|
||||
};
|
||||
// Helper to find a close (because there may be form.last_val_end or form.last_tool_end)
|
||||
constexpr auto try_find_close = [](
|
||||
common_chat_msg_parser & builder,
|
||||
const std::string & end,
|
||||
const std::optional<std::string> & alt_end,
|
||||
const std::string & end_next,
|
||||
const std::optional<std::string> & alt_end_next
|
||||
) {
|
||||
auto saved_pos = builder.pos();
|
||||
auto tc = builder.try_find_literal(end);
|
||||
auto val_end_size = end.size();
|
||||
if (alt_end) {
|
||||
auto pos_1 = builder.pos();
|
||||
builder.move_to(saved_pos);
|
||||
auto tc2 = try_find_2_literal_splited_by_spaces(builder, *alt_end, end_next);
|
||||
if (alt_end_next) {
|
||||
builder.move_to(saved_pos);
|
||||
auto tc3 = try_find_2_literal_splited_by_spaces(builder, *alt_end, *alt_end_next);
|
||||
if (tc3 && (!tc2 || tc2->prelude.size() > tc3->prelude.size())) {
|
||||
tc2 = tc3;
|
||||
}
|
||||
}
|
||||
if (tc2 && (!tc || tc->prelude.size() > tc2->prelude.size())) {
|
||||
tc = tc2;
|
||||
tc->groups[0].end = std::min(builder.input().size(), tc->groups[0].begin + alt_end->size());
|
||||
builder.move_to(tc->groups[0].end);
|
||||
val_end_size = alt_end->size();
|
||||
} else {
|
||||
builder.move_to(pos_1);
|
||||
}
|
||||
}
|
||||
return std::make_pair(val_end_size, tc);
|
||||
};
|
||||
// Helper to find a val_end or last_val_end, returns matched pattern size
|
||||
const auto try_find_val_end = [try_find_close, &builder, &form]() {
|
||||
return try_find_close(builder, form.val_end, form.last_val_end, form.tool_end, form.last_tool_end);
|
||||
};
|
||||
// Helper to find a tool_end or last_tool_end, returns matched pattern size
|
||||
const auto try_find_tool_end = [try_find_close, &builder, &form]() {
|
||||
return try_find_close(builder, form.tool_end, form.last_tool_end, form.scope_end, std::nullopt);
|
||||
};
|
||||
|
||||
bool recovery = true;
|
||||
const auto start_pos = builder.pos();
|
||||
if (!all_space(form.scope_start)) {
|
||||
if (auto tc = builder.try_find_literal(form.scope_start)) {
|
||||
if (all_space(tc->prelude)) {
|
||||
if (form.scope_start.size() != tc->groups[0].end - tc->groups[0].begin)
|
||||
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.scope_start));
|
||||
} else {
|
||||
builder.move_to(start_pos);
|
||||
return false;
|
||||
}
|
||||
} else return false;
|
||||
}
|
||||
while (auto tc = builder.try_find_literal(form.tool_start)) {
|
||||
if (!all_space(tc->prelude)) {
|
||||
LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
|
||||
gbnf_format_literal(form.tool_start).c_str(),
|
||||
gbnf_format_literal(tc->prelude).c_str()
|
||||
);
|
||||
builder.move_to(tc->groups[0].begin - tc->prelude.size());
|
||||
break;
|
||||
}
|
||||
|
||||
// Find tool name
|
||||
auto func_name = builder.try_find_literal(all_space(form.tool_sep) ? form.key_start : form.tool_sep);
|
||||
if (!func_name) {
|
||||
auto [sz, tc] = try_find_tool_end();
|
||||
func_name = tc;
|
||||
}
|
||||
if (!func_name) {
|
||||
// Partial tool name not supported
|
||||
throw common_chat_msg_partial_exception("incomplete tool_call");
|
||||
}
|
||||
// If the model generate multiple tool call and the first tool call has no argument
|
||||
if (func_name->prelude.find(form.tool_end) != std::string::npos || (form.last_tool_end ? func_name->prelude.find(*form.last_tool_end) != std::string::npos : false)) {
|
||||
builder.move_to(func_name->groups[0].begin - func_name->prelude.size());
|
||||
auto [sz, tc] = try_find_tool_end();
|
||||
func_name = tc;
|
||||
}
|
||||
|
||||
// Parse tool name
|
||||
builder.move_to(all_space(form.tool_sep) ? func_name->groups[0].begin : func_name->groups[0].end);
|
||||
std::string function_name = string_strip(func_name->prelude);
|
||||
// Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
|
||||
if (builder.syntax().format == COMMON_CHAT_FORMAT_KIMI_K2) {
|
||||
if (string_starts_with(function_name, "functions.")) {
|
||||
static const std::regex re(":\\d+$");
|
||||
if (std::regex_search(function_name, re)) {
|
||||
function_name = function_name.substr(10, function_name.rfind(":") - 10);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Argument JSON
|
||||
json arguments = json::object();
|
||||
|
||||
// Helper to generate a partial argument JSON
|
||||
const auto gen_partial_args = [&](auto set_partial_arg) {
|
||||
gen_partial_json(set_partial_arg, arguments, builder, function_name);
|
||||
};
|
||||
|
||||
// Parse all arg_key/arg_value pairs
|
||||
while (auto tc = builder.try_find_literal(form.key_start)) {
|
||||
if (!all_space(tc->prelude)) {
|
||||
LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
|
||||
gbnf_format_literal(form.key_start).c_str(),
|
||||
gbnf_format_literal(tc->prelude).c_str()
|
||||
);
|
||||
builder.move_to(tc->groups[0].begin - tc->prelude.size());
|
||||
break;
|
||||
}
|
||||
if (tc->groups[0].end - tc->groups[0].begin != form.key_start.size()) {
|
||||
auto tool_call_arg = arguments.dump();
|
||||
if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
|
||||
tool_call_arg.resize(tool_call_arg.size() - 1);
|
||||
}
|
||||
builder.add_tool_call(function_name, "", tool_call_arg);
|
||||
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_start));
|
||||
}
|
||||
|
||||
// Parse arg_key
|
||||
auto key_res = builder.try_find_literal(form.key_val_sep);
|
||||
if (!key_res) {
|
||||
gen_partial_args([&](auto &rest, auto &needle) {arguments[rest + needle] = "";});
|
||||
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.key_val_sep) + " after " + gbnf_format_literal(form.key_start));
|
||||
}
|
||||
if (key_res->groups[0].end - key_res->groups[0].begin != form.key_val_sep.size()) {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key_res->prelude + needle] = "";});
|
||||
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_val_sep));
|
||||
}
|
||||
auto &key = key_res->prelude;
|
||||
recovery = false;
|
||||
|
||||
// Parse arg_value
|
||||
if (form.key_val_sep2) {
|
||||
if (auto tc = builder.try_find_literal(*form.key_val_sep2)) {
|
||||
if (!all_space(tc->prelude)) {
|
||||
LOG_DBG("Failed to parse XML-Style tool call: Unexcepted %s between %s and %s\n",
|
||||
gbnf_format_literal(tc->prelude).c_str(),
|
||||
gbnf_format_literal(form.key_val_sep).c_str(),
|
||||
gbnf_format_literal(*form.key_val_sep2).c_str()
|
||||
);
|
||||
return return_error(builder, start_pos, false);
|
||||
}
|
||||
if (tc->groups[0].end - tc->groups[0].begin != form.key_val_sep2->size()) {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(*form.key_val_sep2));
|
||||
}
|
||||
} else {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(*form.key_val_sep2) + " after " + gbnf_format_literal(form.key_val_sep));
|
||||
}
|
||||
}
|
||||
auto val_start = builder.pos();
|
||||
|
||||
// Test if arg_val is a partial JSON
|
||||
std::optional<common_json> value_json = std::nullopt;
|
||||
if (!form.raw_argval || !*form.raw_argval) {
|
||||
try { value_json = builder.try_consume_json(); }
|
||||
catch (const std::runtime_error&) { builder.move_to(val_start); }
|
||||
// TODO: Delete this when json_partial adds top-level support for null/true/false
|
||||
if (builder.pos() == val_start) {
|
||||
const static std::regex number_regex(R"([0-9-][0-9]*(\.\d*)?([eE][+-]?\d*)?)");
|
||||
builder.consume_spaces();
|
||||
std::string_view sv = utf8_truncate_safe_view(builder.input());
|
||||
sv.remove_prefix(builder.pos());
|
||||
std::string rest = "a";
|
||||
if (sv.size() < 6) rest = sv;
|
||||
if (string_starts_with("null", rest) || string_starts_with("true", rest) || string_starts_with("false", rest) || std::regex_match(sv.begin(), sv.end(), number_regex)) {
|
||||
value_json = {123, {"123", "123"}};
|
||||
builder.consume_rest();
|
||||
} else {
|
||||
builder.move_to(val_start);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If it is a JSON and followed by </arg_value>, parse as json
|
||||
// cannot support streaming because it may be a plain text starting with JSON
|
||||
if (value_json) {
|
||||
auto json_end = builder.pos();
|
||||
builder.consume_spaces();
|
||||
if (builder.pos() == builder.input().size()) {
|
||||
if (form.raw_argval && !*form.raw_argval && (value_json->json.is_string() || value_json->json.is_object() || value_json->json.is_array())) {
|
||||
arguments[key] = value_json->json;
|
||||
auto json_str = arguments.dump();
|
||||
if (!value_json->healing_marker.json_dump_marker.empty()) {
|
||||
GGML_ASSERT(std::string::npos != json_str.rfind(value_json->healing_marker.json_dump_marker));
|
||||
json_str.resize(json_str.rfind(value_json->healing_marker.json_dump_marker));
|
||||
} else {
|
||||
GGML_ASSERT(json_str.back() == '}');
|
||||
json_str.resize(json_str.size() - 1);
|
||||
}
|
||||
builder.add_tool_call(function_name, "", json_str);
|
||||
} else {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||
}
|
||||
LOG_DBG("Possible JSON arg_value: %s\n", value_json->json.dump().c_str());
|
||||
throw common_chat_msg_partial_exception("JSON arg_value detected. Waiting for more tokens for validations.");
|
||||
}
|
||||
builder.move_to(json_end);
|
||||
auto [val_end_size, tc] = try_find_val_end();
|
||||
if (tc && all_space(tc->prelude) && value_json->healing_marker.marker.empty()) {
|
||||
if (tc->groups[0].end - tc->groups[0].begin != val_end_size) {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||
LOG_DBG("Possible terminated JSON arg_value: %s\n", value_json->json.dump().c_str());
|
||||
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.val_end) + (form.last_val_end ? gbnf_format_literal(*form.last_val_end) : ""));
|
||||
} else arguments[key] = value_json->json;
|
||||
} else builder.move_to(val_start);
|
||||
}
|
||||
|
||||
// If not, parse as plain text
|
||||
if (val_start == builder.pos()) {
|
||||
if (auto [val_end_size, value_plain] = try_find_val_end(); value_plain) {
|
||||
auto &value_str = value_plain->prelude;
|
||||
if (form.trim_raw_argval) value_str = string_strip(value_str);
|
||||
if (value_plain->groups[0].end - value_plain->groups[0].begin != val_end_size) {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key] = value_str + needle;});
|
||||
throw common_chat_msg_partial_exception(
|
||||
"Expected " + gbnf_format_literal(form.val_end) +
|
||||
" after " + gbnf_format_literal(form.key_val_sep) +
|
||||
(form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
|
||||
);
|
||||
}
|
||||
arguments[key] = value_str;
|
||||
} else {
|
||||
if (form.trim_raw_argval) {
|
||||
gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = string_strip(rest) + needle;});
|
||||
} else {
|
||||
gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = rest + needle;});
|
||||
}
|
||||
throw common_chat_msg_partial_exception(
|
||||
"Expected " + gbnf_format_literal(form.val_end) +
|
||||
" after " + gbnf_format_literal(form.key_val_sep) +
|
||||
(form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Consume closing tag
|
||||
if (auto [tool_end_size, tc] = try_find_tool_end(); tc) {
|
||||
if (!all_space(tc->prelude)) {
|
||||
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||
gbnf_format_literal(form.tool_end).c_str(),
|
||||
gbnf_format_literal(tc->prelude).c_str()
|
||||
);
|
||||
return return_error(builder, start_pos, recovery);
|
||||
}
|
||||
if (tc->groups[0].end - tc->groups[0].begin == tool_end_size) {
|
||||
// Add the parsed tool call
|
||||
if (!builder.add_tool_call(function_name, "", arguments.dump())) {
|
||||
throw common_chat_msg_partial_exception("Failed to add XML-Style tool call");
|
||||
}
|
||||
recovery = false;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
auto tool_call_arg = arguments.dump();
|
||||
if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
|
||||
tool_call_arg.resize(tool_call_arg.size() - 1);
|
||||
}
|
||||
builder.add_tool_call(function_name, "", tool_call_arg);
|
||||
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.tool_end) + " after " + gbnf_format_literal(form.val_end));
|
||||
}
|
||||
if (auto tc = builder.try_find_literal(form.scope_end)) {
|
||||
if (!all_space(tc->prelude)) {
|
||||
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||
gbnf_format_literal(form.scope_end).c_str(),
|
||||
gbnf_format_literal(tc->prelude).c_str()
|
||||
);
|
||||
return return_error(builder, start_pos, recovery);
|
||||
}
|
||||
} else {
|
||||
if (all_space(form.scope_end)) return true;
|
||||
builder.consume_spaces();
|
||||
if (builder.pos() == builder.input().size())
|
||||
throw common_chat_msg_partial_exception("incomplete tool calls");
|
||||
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||
gbnf_format_literal(form.scope_end).c_str(),
|
||||
gbnf_format_literal(builder.consume_rest()).c_str()
|
||||
);
|
||||
return return_error(builder, start_pos, recovery);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
|
||||
* May cause std::runtime_error if there is invalid syntax because partial valid tool call is already sent out to client.
|
||||
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||
*/
|
||||
bool common_chat_msg_parser::try_consume_xml_tool_calls(const struct xml_tool_call_format & form) {
|
||||
auto pos = pos_;
|
||||
auto tsize = result_.tool_calls.size();
|
||||
try { return parse_xml_tool_calls(*this, form); }
|
||||
catch (const xml_toolcall_syntax_exception&) {}
|
||||
move_to(pos);
|
||||
result_.tool_calls.resize(tsize);
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse content uses reasoning and XML-Style tool call
|
||||
* TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
|
||||
*/
|
||||
inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>") {
|
||||
constexpr auto rstrip = [](std::string &s) {
|
||||
s.resize(std::distance(s.begin(), std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base()));
|
||||
};
|
||||
// Erase substring from l to r, along with additional spaces nearby
|
||||
constexpr auto erase_spaces = [](auto &str, size_t l, size_t r) {
|
||||
while (/* l > -1 && */ --l < str.size() && std::isspace(static_cast<unsigned char>(str[l])));
|
||||
++l;
|
||||
while (++r < str.size() && std::isspace(static_cast<unsigned char>(str[r])));
|
||||
if (l < r) str[l] = '\n';
|
||||
if (l + 1 < r) str[l + 1] = '\n';
|
||||
if (l != 0) l += 2;
|
||||
str.erase(l, r - l);
|
||||
return l;
|
||||
};
|
||||
constexpr auto trim_suffix = [](std::string &content, std::initializer_list<std::string_view> list) {
|
||||
auto best_match = content.size();
|
||||
for (auto pattern: list) {
|
||||
if (pattern.size() == 0) continue;
|
||||
for (auto match_idx = content.size() - std::min(pattern.size(), content.size()); content.size() > match_idx; match_idx++) {
|
||||
auto match_len = content.size() - match_idx;
|
||||
if (content.compare(match_idx, match_len, pattern.data(), match_len) == 0 && best_match > match_idx) {
|
||||
best_match = match_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (content.size() > best_match) {
|
||||
content.erase(best_match);
|
||||
}
|
||||
};
|
||||
const auto trim_potential_partial_word = [&start_think, &end_think, &form, trim_suffix](std::string &content) {
|
||||
return trim_suffix(content, {
|
||||
start_think, end_think, form.scope_start, form.tool_start, form.tool_sep, form.key_start,
|
||||
form.key_val_sep, form.key_val_sep2 ? form.key_val_sep2->c_str() : "",
|
||||
form.val_end, form.last_val_end ? form.last_val_end->c_str() : "",
|
||||
form.tool_end, form.last_tool_end ? form.last_tool_end->c_str() : "",
|
||||
form.scope_end
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
// Trim leading spaces without affecting keyword matching
|
||||
static const common_regex spaces_regex("\\s*");
|
||||
{
|
||||
auto tc = builder.consume_regex(spaces_regex);
|
||||
auto spaces = builder.str(tc.groups[0]);
|
||||
auto s1 = spaces.size();
|
||||
trim_potential_partial_word(spaces);
|
||||
auto s2 = spaces.size();
|
||||
builder.move_to(builder.pos() - (s1 - s2));
|
||||
}
|
||||
|
||||
// Parse content
|
||||
bool reasoning_unclosed = builder.syntax().thinking_forced_open;
|
||||
std::string unclosed_reasoning_content("");
|
||||
for (;;) {
|
||||
auto tc = try_find_2_literal_splited_by_spaces(builder, form.scope_start, form.tool_start);
|
||||
std::string content;
|
||||
std::string tool_call_start;
|
||||
|
||||
if (tc) {
|
||||
content = std::move(tc->prelude);
|
||||
tool_call_start = builder.str(tc->groups[0]);
|
||||
LOG_DBG("Matched tool start: %s\n", gbnf_format_literal(tool_call_start).c_str());
|
||||
} else {
|
||||
content = builder.consume_rest();
|
||||
utf8_truncate_safe_resize(content);
|
||||
}
|
||||
|
||||
// Handle unclosed think block
|
||||
if (reasoning_unclosed) {
|
||||
if (auto pos = content.find(end_think); pos == std::string::npos && builder.pos() != builder.input().size()) {
|
||||
unclosed_reasoning_content += content;
|
||||
if (!(form.allow_toolcall_in_think && tc)) {
|
||||
unclosed_reasoning_content += tool_call_start;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
reasoning_unclosed = false;
|
||||
std::string reasoning_content;
|
||||
if (pos == std::string::npos) {
|
||||
reasoning_content = std::move(content);
|
||||
} else {
|
||||
reasoning_content = content.substr(0, pos);
|
||||
content.erase(0, pos + end_think.size());
|
||||
}
|
||||
if (builder.pos() == builder.input().size() && all_space(content)) {
|
||||
rstrip(reasoning_content);
|
||||
trim_potential_partial_word(reasoning_content);
|
||||
rstrip(reasoning_content);
|
||||
if (reasoning_content.empty()) {
|
||||
rstrip(unclosed_reasoning_content);
|
||||
trim_potential_partial_word(unclosed_reasoning_content);
|
||||
rstrip(unclosed_reasoning_content);
|
||||
if (unclosed_reasoning_content.empty()) continue;
|
||||
}
|
||||
}
|
||||
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
|
||||
builder.add_content(start_think);
|
||||
builder.add_content(unclosed_reasoning_content);
|
||||
builder.add_content(reasoning_content);
|
||||
if (builder.pos() != builder.input().size() || !all_space(content))
|
||||
builder.add_content(end_think);
|
||||
} else {
|
||||
builder.add_reasoning_content(unclosed_reasoning_content);
|
||||
builder.add_reasoning_content(reasoning_content);
|
||||
}
|
||||
unclosed_reasoning_content.clear();
|
||||
}
|
||||
}
|
||||
|
||||
// Handle multiple think block
|
||||
bool toolcall_in_think = false;
|
||||
for (auto think_start = content.find(start_think); think_start != std::string::npos; think_start = content.find(start_think, think_start)) {
|
||||
if (auto think_end = content.find(end_think, think_start + start_think.size()); think_end != std::string::npos) {
|
||||
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
|
||||
auto reasoning_content = content.substr(think_start + start_think.size(), think_end - think_start - start_think.size());
|
||||
builder.add_reasoning_content(reasoning_content);
|
||||
think_start = erase_spaces(content, think_start, think_end + end_think.size() - 1);
|
||||
} else {
|
||||
think_start = think_end + end_think.size() - 1;
|
||||
}
|
||||
} else {
|
||||
// This <tool_call> start is in thinking block, skip this tool call
|
||||
// This <tool_call> start is in thinking block
|
||||
if (form.allow_toolcall_in_think) {
|
||||
unclosed_reasoning_content = content.substr(think_start + start_think.size());
|
||||
} else {
|
||||
unclosed_reasoning_content = content.substr(think_start + start_think.size()) + tool_call_start;
|
||||
}
|
||||
reasoning_unclosed = true;
|
||||
content.resize(think_start);
|
||||
toolcall_in_think = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
|
||||
rstrip(content);
|
||||
// Handle unclosed </think> token from content: delete all </think> token
|
||||
if (auto pos = content.rfind(end_think); pos != std::string::npos) {
|
||||
while (pos != std::string::npos) {
|
||||
pos = erase_spaces(content, pos, pos + end_think.size() - 1);
|
||||
pos = content.rfind(end_think, pos);
|
||||
}
|
||||
}
|
||||
// Strip if needed
|
||||
if (content.size() > 0 && std::isspace(static_cast<unsigned char>(content[0]))) {
|
||||
content = string_strip(content);
|
||||
}
|
||||
}
|
||||
|
||||
// remove potential partial suffix
|
||||
if (builder.pos() == builder.input().size() && builder.is_partial()) {
|
||||
if (unclosed_reasoning_content.empty()) {
|
||||
rstrip(content);
|
||||
trim_potential_partial_word(content);
|
||||
rstrip(content);
|
||||
} else {
|
||||
rstrip(unclosed_reasoning_content);
|
||||
trim_potential_partial_word(unclosed_reasoning_content);
|
||||
rstrip(unclosed_reasoning_content);
|
||||
}
|
||||
}
|
||||
|
||||
// consume unclosed_reasoning_content if allow_toolcall_in_think is set
|
||||
if (form.allow_toolcall_in_think && !unclosed_reasoning_content.empty()) {
|
||||
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
|
||||
builder.add_reasoning_content(unclosed_reasoning_content);
|
||||
} else {
|
||||
if (content.empty()) {
|
||||
content = start_think + unclosed_reasoning_content;
|
||||
} else {
|
||||
content += "\n\n" + start_think;
|
||||
content += unclosed_reasoning_content;
|
||||
}
|
||||
}
|
||||
unclosed_reasoning_content.clear();
|
||||
}
|
||||
|
||||
// Add content
|
||||
if (!content.empty()) {
|
||||
// If there are multiple content blocks
|
||||
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content && builder.result().content.size() != 0) {
|
||||
builder.add_content("\n\n");
|
||||
}
|
||||
builder.add_content(content);
|
||||
}
|
||||
|
||||
// This <tool_call> start is in thinking block and toolcall_in_think not set, skip this tool call
|
||||
if (toolcall_in_think && !form.allow_toolcall_in_think) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// There is no tool call and all content is parsed
|
||||
if (!tc) {
|
||||
GGML_ASSERT(builder.pos() == builder.input().size());
|
||||
GGML_ASSERT(unclosed_reasoning_content.empty());
|
||||
if (!form.allow_toolcall_in_think) GGML_ASSERT(!reasoning_unclosed);
|
||||
break;
|
||||
}
|
||||
|
||||
builder.move_to(tc->groups[0].begin);
|
||||
if (builder.try_consume_xml_tool_calls(form)) {
|
||||
auto end_of_tool = builder.pos();
|
||||
builder.consume_spaces();
|
||||
if (builder.pos() != builder.input().size()) {
|
||||
builder.move_to(end_of_tool);
|
||||
if (!builder.result().content.empty()) {
|
||||
builder.add_content("\n\n");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static const common_regex next_char_regex(".");
|
||||
auto c = builder.str(builder.consume_regex(next_char_regex).groups[0]);
|
||||
rstrip(c);
|
||||
builder.add_content(c);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse content uses reasoning and XML-Style tool call
|
||||
*/
|
||||
void common_chat_msg_parser::consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think, const std::string & end_think) {
|
||||
parse_msg_with_xml_tool_calls(*this, form, start_think, end_think);
|
||||
}
|
||||
|
|
@ -1,45 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include "chat.h"
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
|
||||
// Sample config:
|
||||
// MiniMax-M2 (left): <minimax:tool_call>\n<invoke name="tool-name">\n<parameter name="key">value</parameter>\n...</invoke>\n...</minimax:tool_call>
|
||||
// GLM 4.5 (right): <tool_call>function_name\n<arg_key>key</arg_key>\n<arg_value>value</arg_value>\n</tool_call>
|
||||
struct xml_tool_call_format {
|
||||
std::string scope_start; // <minimax:tool_call>\n // \n // can be empty
|
||||
std::string tool_start; // <invoke name=\" // <tool_call>
|
||||
std::string tool_sep; // \">\n // \n // can be empty only for parse_xml_tool_calls
|
||||
std::string key_start; // <parameter name=\" // <arg_key>
|
||||
std::string key_val_sep; // \"> // </arg_key>\n<arg_value>
|
||||
std::string val_end; // </parameter>\n // </arg_value>\n
|
||||
std::string tool_end; // </invoke>\n // </tool_call>\n
|
||||
std::string scope_end; // </minimax:tool_call> // // can be empty
|
||||
// Set this if there can be dynamic spaces inside key_val_sep.
|
||||
// e.g. key_val_sep=</arg_key> key_val_sep2=<arg_value> for GLM4.5
|
||||
std::optional<std::string> key_val_sep2 = std::nullopt;
|
||||
// Set true if argval should only be raw string. e.g. Hello "world" hi
|
||||
// Set false if argval should only be json string. e.g. "Hello \"world\" hi"
|
||||
// Defaults to std::nullopt, both will be allowed.
|
||||
std::optional<bool> raw_argval = std::nullopt;
|
||||
std::optional<std::string> last_val_end = std::nullopt;
|
||||
std::optional<std::string> last_tool_end = std::nullopt;
|
||||
bool trim_raw_argval = false;
|
||||
bool allow_toolcall_in_think = false;
|
||||
};
|
||||
|
||||
// make a GBNF that accept any strings except those containing any of the forbidden strings.
|
||||
std::string make_gbnf_excluding(std::vector<std::string> forbids);
|
||||
|
||||
/**
|
||||
* Build grammar for xml-style tool call
|
||||
* form.scope_start and form.scope_end can be empty.
|
||||
* Requires data.format for model-specific hacks.
|
||||
*/
|
||||
void build_grammar_xml_tool_call(common_chat_params & data, const nlohmann::ordered_json & tools, const struct xml_tool_call_format & form);
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,133 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include "chat.h"
|
||||
#include "chat-parser-xml-toolcall.h"
|
||||
#include "json-partial.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
class common_chat_msg_partial_exception : public std::runtime_error {
|
||||
public:
|
||||
common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {}
|
||||
};
|
||||
|
||||
class common_chat_msg_parser {
|
||||
std::string input_;
|
||||
bool is_partial_;
|
||||
common_chat_parser_params syntax_; // TODO: rename to params
|
||||
std::string healing_marker_;
|
||||
|
||||
size_t pos_ = 0;
|
||||
common_chat_msg result_;
|
||||
|
||||
public:
|
||||
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
|
||||
const std::string & input() const { return input_; }
|
||||
size_t pos() const { return pos_; }
|
||||
const std::string & healing_marker() const { return healing_marker_; }
|
||||
const bool & is_partial() const { return is_partial_; }
|
||||
const common_chat_msg & result() const { return result_; }
|
||||
const common_chat_parser_params & syntax() const { return syntax_; }
|
||||
|
||||
void move_to(size_t pos) {
|
||||
if (pos > input_.size()) {
|
||||
throw std::runtime_error("Invalid position!");
|
||||
}
|
||||
pos_ = pos;
|
||||
}
|
||||
void move_back(size_t n) {
|
||||
if (pos_ < n) {
|
||||
throw std::runtime_error("Can't move back that far!");
|
||||
}
|
||||
pos_ -= n;
|
||||
}
|
||||
|
||||
// Get the substring of the input at the given range
|
||||
std::string str(const common_string_range & rng) const;
|
||||
|
||||
// Appends to the result.content field
|
||||
void add_content(const std::string & content);
|
||||
|
||||
// Appends to the result.reasoning_content field
|
||||
void add_reasoning_content(const std::string & reasoning_content);
|
||||
|
||||
// Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything.
|
||||
bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);
|
||||
|
||||
// Adds a tool call using the "name", "id" and "arguments" fields of the json object
|
||||
bool add_tool_call(const nlohmann::ordered_json & tool_call);
|
||||
|
||||
// Adds an array of tool calls using their "name", "id" and "arguments" fields.
|
||||
bool add_tool_calls(const nlohmann::ordered_json & arr);
|
||||
|
||||
// Adds a tool call using the short form: { "tool_name": { "arg1": val, "arg2": val } }
|
||||
bool add_tool_call_short_form(const nlohmann::ordered_json & tool_call);
|
||||
|
||||
void finish();
|
||||
|
||||
bool consume_spaces();
|
||||
|
||||
void consume_literal(const std::string & literal);
|
||||
|
||||
bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);
|
||||
|
||||
std::string consume_rest();
|
||||
|
||||
struct find_regex_result {
|
||||
std::string prelude;
|
||||
std::vector<common_string_range> groups;
|
||||
};
|
||||
|
||||
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);
|
||||
|
||||
bool try_consume_literal(const std::string & literal);
|
||||
|
||||
std::optional<find_regex_result> try_find_literal(const std::string & literal);
|
||||
|
||||
find_regex_result consume_regex(const common_regex & regex);
|
||||
|
||||
std::optional<find_regex_result> try_consume_regex(const common_regex & regex);
|
||||
|
||||
std::optional<common_json> try_consume_json();
|
||||
common_json consume_json();
|
||||
|
||||
struct consume_json_result {
|
||||
nlohmann::ordered_json value;
|
||||
bool is_partial;
|
||||
};
|
||||
|
||||
/*
|
||||
Consume (possibly partial) json and converts specific subtrees to (possibly truncated) JSON strings.
|
||||
|
||||
By default, object keys can't be truncated, nor can string values (their corresponding key is removed,
|
||||
e.g. `{"foo": "bar", "baz": "b` -> `{"foo": "bar"}`
|
||||
|
||||
But one can allow subpaths to be kept truncated, and possibly json-dumped to truncated json strings
|
||||
- with `content_paths={{"foo"}}` -> `{"foo": "b` -> {"foo": "b"}`
|
||||
- with `args_paths={{"foo"}}` -> `{"foo": {"b` -> `{"foo": "{b"}`
|
||||
*/
|
||||
consume_json_result consume_json_with_dumped_args(
|
||||
const std::vector<std::vector<std::string>> & args_paths = {},
|
||||
const std::vector<std::vector<std::string>> & content_paths = {}
|
||||
);
|
||||
std::optional<consume_json_result> try_consume_json_with_dumped_args(
|
||||
const std::vector<std::vector<std::string>> & args_paths = {},
|
||||
const std::vector<std::vector<std::string>> & content_paths = {}
|
||||
);
|
||||
|
||||
/**
|
||||
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
|
||||
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||
*/
|
||||
bool try_consume_xml_tool_calls(const struct xml_tool_call_format & form);
|
||||
|
||||
// Parse content uses reasoning and XML-Style tool call
|
||||
void consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>");
|
||||
|
||||
void clear_tools();
|
||||
};
|
||||
|
|
@ -1,13 +1,17 @@
|
|||
#include "chat-peg-parser.h"
|
||||
|
||||
#include "chat-auto-parser.h"
|
||||
#include "ggml.h"
|
||||
#include "peg-parser.h"
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
using json = nlohmann::json;
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
static std::string_view trim_trailing_space(std::string_view sv, int max = -1) {
|
||||
int count = 0;
|
||||
while (!sv.empty() && std::isspace(static_cast<unsigned char>(sv.back()))) {
|
||||
if (max != -1 && count <= max) {
|
||||
if (max != -1 && count >= max) {
|
||||
break;
|
||||
}
|
||||
sv.remove_suffix(1);
|
||||
|
|
@ -16,109 +20,753 @@ static std::string_view trim_trailing_space(std::string_view sv, int max = -1) {
|
|||
return sv;
|
||||
}
|
||||
|
||||
void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result) {
|
||||
static std::string_view trim_leading_space(std::string_view sv, int max = -1) {
|
||||
int count = 0;
|
||||
while (!sv.empty() && std::isspace(static_cast<unsigned char>(sv.front()))) {
|
||||
if (max != -1 && count >= max) {
|
||||
break;
|
||||
}
|
||||
sv.remove_prefix(1);
|
||||
count++;
|
||||
}
|
||||
return sv;
|
||||
}
|
||||
|
||||
static std::string_view trim(std::string_view sv) {
|
||||
return trim_trailing_space(trim_leading_space(sv, 1));
|
||||
}
|
||||
|
||||
// Count the number of unclosed '{' braces in a JSON-like string,
|
||||
// properly skipping braces inside quoted strings.
|
||||
static int json_brace_depth(const std::string & s) {
|
||||
int depth = 0;
|
||||
bool in_string = false;
|
||||
bool escaped = false;
|
||||
for (char c : s) {
|
||||
if (escaped) {
|
||||
escaped = false;
|
||||
continue;
|
||||
}
|
||||
if (c == '\\' && in_string) {
|
||||
escaped = true;
|
||||
continue;
|
||||
}
|
||||
if (c == '"') {
|
||||
in_string = !in_string;
|
||||
continue;
|
||||
}
|
||||
if (!in_string) {
|
||||
if (c == '{') {
|
||||
depth++;
|
||||
} else if (c == '}') {
|
||||
depth--;
|
||||
}
|
||||
}
|
||||
}
|
||||
return depth;
|
||||
}
|
||||
|
||||
// JSON-escape a string and return the inner content (without surrounding quotes).
|
||||
static std::string escape_json_string_inner(const std::string & s) {
|
||||
std::string escaped = json(s).dump();
|
||||
if (escaped.size() >= 2 && escaped.front() == '"' && escaped.back() == '"') {
|
||||
return escaped.substr(1, escaped.size() - 2);
|
||||
}
|
||||
return escaped;
|
||||
}
|
||||
|
||||
// Convert Python-style single-quoted strings to JSON double-quoted strings
|
||||
// Only converts outer string delimiters, properly handling escape sequences:
|
||||
// - {'key': 'value'} -> {"key": "value"}
|
||||
// - {'code': 'print(\'hello\')'} -> {"code": "print('hello')"}
|
||||
// - {'msg': 'He said "hi"'} -> {"msg": "He said \"hi\""}
|
||||
static std::string normalize_quotes_to_json(const std::string & input) {
|
||||
std::string result;
|
||||
result.reserve(input.size() + 16); // May need extra space for escaping
|
||||
|
||||
bool in_single_quoted = false;
|
||||
bool in_double_quoted = false;
|
||||
|
||||
for (size_t i = 0; i < input.size(); ++i) {
|
||||
char c = input[i];
|
||||
|
||||
// Handle escape sequences
|
||||
if (c == '\\' && i + 1 < input.size()) {
|
||||
char next = input[i + 1];
|
||||
|
||||
if (in_single_quoted) {
|
||||
// Inside a single-quoted string being converted to double quotes
|
||||
if (next == '\'') {
|
||||
// \' -> ' (escaped single quote becomes unescaped in double-quoted string)
|
||||
result += '\'';
|
||||
++i;
|
||||
continue;
|
||||
}
|
||||
if (next == '"') {
|
||||
// \" stays as \" (already escaped, works in double-quoted string)
|
||||
result += "\\\"";
|
||||
++i;
|
||||
continue;
|
||||
}
|
||||
// Other escapes (\n, \\, etc.): pass through both characters
|
||||
result += c;
|
||||
result += next;
|
||||
++i;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (in_double_quoted) {
|
||||
// Inside a double-quoted string - pass through escape sequences as-is
|
||||
result += c;
|
||||
result += next;
|
||||
++i;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Outside any string - just pass through the backslash
|
||||
result += c;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle quote characters
|
||||
if (c == '"') {
|
||||
if (in_single_quoted) {
|
||||
// Unescaped double quote inside single-quoted string -> must escape for JSON
|
||||
result += "\\\"";
|
||||
} else {
|
||||
// Double quote as string delimiter or outside strings
|
||||
in_double_quoted = !in_double_quoted;
|
||||
result += c;
|
||||
}
|
||||
} else if (c == '\'') {
|
||||
if (in_double_quoted) {
|
||||
// Single quote inside double-quoted string -> pass through
|
||||
result += c;
|
||||
} else if (in_single_quoted) {
|
||||
// Closing single quote -> convert to double quote
|
||||
in_single_quoted = false;
|
||||
result += '"';
|
||||
} else {
|
||||
// Opening single quote -> convert to double quote
|
||||
in_single_quoted = true;
|
||||
result += '"';
|
||||
}
|
||||
} else {
|
||||
result += c;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void tag_based_peg_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result) {
|
||||
arena.visit(result, [this](const common_peg_ast_node & node) {
|
||||
map(node);
|
||||
if (!node.tag.empty()) {
|
||||
tags[node.tag] = std::string(node.text);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
|
||||
bool is_reasoning = node.tag == common_chat_peg_builder::REASONING;
|
||||
bool is_content = node.tag == common_chat_peg_builder::CONTENT;
|
||||
tagged_parse_result tagged_peg_parser::parse_and_extract(const std::string & input, bool is_partial) const {
|
||||
common_peg_parse_context ctx(input, is_partial);
|
||||
auto parse_result = arena.parse(ctx);
|
||||
|
||||
if (is_reasoning) {
|
||||
result.reasoning_content = std::string(trim_trailing_space(node.text));
|
||||
tag_based_peg_mapper mapper;
|
||||
mapper.from_ast(ctx.ast, parse_result);
|
||||
|
||||
return { std::move(parse_result), std::move(mapper.tags) };
|
||||
}
|
||||
|
||||
tagged_parse_result tagged_peg_parser::parse_anywhere_and_extract(const std::string & input) const {
|
||||
if (input.empty()) {
|
||||
return parse_and_extract(input, false);
|
||||
}
|
||||
for (size_t i = 0; i < input.size(); i++) {
|
||||
common_peg_parse_context ctx(input, false);
|
||||
ctx.debug = debug;
|
||||
auto parse_result = arena.parse(ctx, i);
|
||||
if (parse_result.success() || i == input.size() - 1) {
|
||||
tag_based_peg_mapper mapper;
|
||||
mapper.from_ast(ctx.ast, parse_result);
|
||||
return { std::move(parse_result), std::move(mapper.tags) };
|
||||
}
|
||||
}
|
||||
GGML_ABORT("Should not happen");
|
||||
}
|
||||
|
||||
if (is_content) {
|
||||
result.content = std::string(trim_trailing_space(node.text));
|
||||
tagged_peg_parser build_tagged_peg_parser(
|
||||
const std::function<common_peg_parser(common_peg_parser_builder & builder)> & fn) {
|
||||
common_peg_parser_builder builder;
|
||||
builder.set_root(fn(builder));
|
||||
return { builder.build() };
|
||||
}
|
||||
|
||||
common_peg_parser common_chat_peg_builder::tag_with_safe_content(const std::string & tag_name,
|
||||
const std::string & marker,
|
||||
const common_peg_parser & p) {
|
||||
if (marker.empty()) {
|
||||
return zero_or_more(choice({ p, rule(tag_name, content(any())) }));
|
||||
}
|
||||
auto content_chunk = rule(tag_name, content(negate(literal(marker)) + any() + until(marker)));
|
||||
return zero_or_more(choice({ p, content_chunk }));
|
||||
}
|
||||
|
||||
std::string & common_chat_peg_mapper::args_target() {
|
||||
return (current_tool && !current_tool->name.empty()) ? current_tool->arguments : args_buffer;
|
||||
}
|
||||
|
||||
void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena,
|
||||
const common_peg_parse_result & parse_result_arg) {
|
||||
arena.visit(parse_result_arg, [this](const common_peg_ast_node & node) { map(node); });
|
||||
// Flush any pending tool call that was started but never got a name
|
||||
// This happens during partial parsing when the tool call is incomplete
|
||||
if (pending_tool_call.has_value() && !pending_tool_call->name.empty()) {
|
||||
if (!args_buffer.empty()) {
|
||||
pending_tool_call->arguments = args_buffer;
|
||||
}
|
||||
if (closing_quote_pending && !pending_tool_call->arguments.empty()) {
|
||||
pending_tool_call->arguments += "\"";
|
||||
}
|
||||
result.tool_calls.push_back(pending_tool_call.value());
|
||||
pending_tool_call.reset();
|
||||
}
|
||||
}
|
||||
|
||||
void common_chat_peg_native_mapper::map(const common_peg_ast_node & node) {
|
||||
common_chat_peg_mapper::map(node);
|
||||
void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
|
||||
// Handle reasoning/content tags
|
||||
bool is_reasoning = node.tag == common_chat_peg_builder::REASONING;
|
||||
bool is_content = node.tag == common_chat_peg_builder::CONTENT;
|
||||
|
||||
bool is_tool_open = node.tag == common_chat_peg_native_builder::TOOL_OPEN;
|
||||
bool is_tool_name = node.tag == common_chat_peg_native_builder::TOOL_NAME;
|
||||
bool is_tool_id = node.tag == common_chat_peg_native_builder::TOOL_ID;
|
||||
bool is_tool_args = node.tag == common_chat_peg_native_builder::TOOL_ARGS;
|
||||
if (is_reasoning) { // GPT OSS can have more than 1 reasoning block, so concatenate here
|
||||
result.reasoning_content += std::string(node.text);
|
||||
}
|
||||
|
||||
if (is_content) {
|
||||
// Concatenate content from multiple content nodes (e.g., when reasoning markers
|
||||
// are preserved before content markers in reasoning_format=NONE mode)
|
||||
result.content += std::string(node.text);
|
||||
}
|
||||
|
||||
// Handle tool-related tags (supporting both JSON and tagged formats)
|
||||
bool is_tool_open = node.tag == common_chat_peg_builder::TOOL_OPEN;
|
||||
bool is_tool_close = node.tag == common_chat_peg_builder::TOOL_CLOSE;
|
||||
bool is_tool_name = node.tag == common_chat_peg_builder::TOOL_NAME;
|
||||
bool is_tool_id = node.tag == common_chat_peg_builder::TOOL_ID;
|
||||
bool is_tool_args = node.tag == common_chat_peg_builder::TOOL_ARGS;
|
||||
bool is_arg_open = node.tag == common_chat_peg_builder::TOOL_ARG_OPEN;
|
||||
bool is_arg_close = node.tag == common_chat_peg_builder::TOOL_ARG_CLOSE;
|
||||
bool is_arg_name = node.tag == common_chat_peg_builder::TOOL_ARG_NAME;
|
||||
bool is_arg_value = node.tag == common_chat_peg_builder::TOOL_ARG_VALUE;
|
||||
bool is_arg_string_value = node.tag == common_chat_peg_builder::TOOL_ARG_STRING_VALUE;
|
||||
|
||||
if (is_tool_open) {
|
||||
result.tool_calls.emplace_back();
|
||||
current_tool = &result.tool_calls.back();
|
||||
pending_tool_call = common_chat_tool_call();
|
||||
current_tool = &pending_tool_call.value();
|
||||
arg_count = 0;
|
||||
args_buffer.clear();
|
||||
closing_quote_pending = false;
|
||||
}
|
||||
|
||||
if (is_tool_id && current_tool) {
|
||||
current_tool->id = std::string(trim_trailing_space(node.text));
|
||||
auto text = trim_trailing_space(node.text);
|
||||
if (text.size() >= 2 && text.front() == '"' && text.back() == '"') {
|
||||
text = text.substr(1, text.size() - 2);
|
||||
}
|
||||
current_tool->id = std::string(text);
|
||||
}
|
||||
|
||||
if (is_tool_name && current_tool) {
|
||||
current_tool->name = std::string(trim_trailing_space(node.text));
|
||||
// Now that we have the name, populate the arguments from the buffer
|
||||
if (!args_buffer.empty()) {
|
||||
current_tool->arguments = args_buffer;
|
||||
args_buffer.clear();
|
||||
} else if (current_tool->arguments.empty()) {
|
||||
current_tool->arguments = "{";
|
||||
}
|
||||
// Add the tool call to results so streaming can see it
|
||||
if (pending_tool_call.has_value()) {
|
||||
result.tool_calls.push_back(pending_tool_call.value());
|
||||
pending_tool_call.reset();
|
||||
current_tool = &result.tool_calls.back();
|
||||
}
|
||||
}
|
||||
|
||||
if (is_tool_args && current_tool) {
|
||||
current_tool->arguments = std::string(trim_trailing_space(node.text));
|
||||
}
|
||||
}
|
||||
|
||||
void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
|
||||
common_chat_peg_mapper::map(node);
|
||||
|
||||
bool is_tool_open = node.tag == common_chat_peg_constructed_builder::TOOL_OPEN;
|
||||
bool is_tool_name = node.tag == common_chat_peg_constructed_builder::TOOL_NAME;
|
||||
bool is_tool_close = node.tag == common_chat_peg_constructed_builder::TOOL_CLOSE;
|
||||
bool is_arg_open = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_OPEN;
|
||||
bool is_arg_close = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_CLOSE;
|
||||
bool is_arg_name = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_NAME;
|
||||
bool is_arg_string = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_STRING_VALUE;
|
||||
bool is_arg_json = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_JSON_VALUE;
|
||||
|
||||
if (is_tool_open) {
|
||||
result.tool_calls.emplace_back();
|
||||
current_tool = &result.tool_calls.back();
|
||||
arg_count = 0;
|
||||
}
|
||||
|
||||
if (is_tool_name) {
|
||||
current_tool->name = std::string(node.text);
|
||||
current_tool->arguments = "{";
|
||||
// For JSON format: arguments come as a complete JSON object
|
||||
// For tagged format: built up from individual arg_name/arg_value nodes
|
||||
auto text = trim_trailing_space(node.text);
|
||||
if (!text.empty() && text.front() == '{') {
|
||||
args_target() = std::string(text);
|
||||
}
|
||||
}
|
||||
|
||||
if (is_arg_open) {
|
||||
needs_closing_quote = false;
|
||||
closing_quote_pending = false;
|
||||
}
|
||||
|
||||
if (is_arg_name && current_tool) {
|
||||
std::string arg_entry;
|
||||
if (arg_count > 0) {
|
||||
current_tool->arguments += ",";
|
||||
arg_entry = ",";
|
||||
}
|
||||
current_tool->arguments += json(trim_trailing_space(node.text)).dump() + ":";
|
||||
arg_entry += json(trim(node.text)).dump() + ":";
|
||||
++arg_count;
|
||||
|
||||
auto & target = args_target();
|
||||
if (target.empty()) {
|
||||
target = "{";
|
||||
}
|
||||
target += arg_entry;
|
||||
}
|
||||
|
||||
if (is_arg_string && current_tool) {
|
||||
// Serialize to JSON, but exclude the end quote
|
||||
std::string dumped = json(trim_trailing_space(node.text)).dump();
|
||||
current_tool->arguments += dumped.substr(0, dumped.size() - 1);
|
||||
needs_closing_quote = true;
|
||||
if ((is_arg_value || is_arg_string_value) && current_tool) {
|
||||
std::string value_content = std::string(trim_trailing_space(trim_leading_space(node.text, 1), 1));
|
||||
|
||||
std::string value_to_add;
|
||||
if (value_content.empty() && is_arg_string_value) {
|
||||
// Empty string value - arg_close will add the closing quote
|
||||
value_to_add = "\"";
|
||||
closing_quote_pending = true;
|
||||
} else if (!value_content.empty() && is_arg_string_value) {
|
||||
// Schema declares this as string type - always treat as literal string value
|
||||
if (!closing_quote_pending) {
|
||||
value_to_add = "\"";
|
||||
closing_quote_pending = true;
|
||||
}
|
||||
value_to_add += escape_json_string_inner(value_content);
|
||||
} else if (!value_content.empty()) {
|
||||
// For potential containers, normalize Python-style single quotes to JSON double quotes
|
||||
bool is_potential_container = value_content[0] == '[' || value_content[0] == '{';
|
||||
if (is_potential_container) {
|
||||
value_content = normalize_quotes_to_json(value_content);
|
||||
}
|
||||
|
||||
// Try to parse as JSON value (number, bool, null, object, array)
|
||||
try {
|
||||
json parsed = json::parse(value_content);
|
||||
if (parsed.is_string()) {
|
||||
// Don't add closing quote yet (added by arg_close) for monotonic streaming
|
||||
std::string escaped = parsed.dump();
|
||||
if (!escaped.empty() && escaped.back() == '"') {
|
||||
escaped.pop_back();
|
||||
}
|
||||
value_to_add = escaped;
|
||||
closing_quote_pending = true;
|
||||
} else {
|
||||
// Non-string values: use raw content to preserve whitespace for monotonicity
|
||||
value_to_add = value_content;
|
||||
}
|
||||
} catch (...) {
|
||||
if (node.is_partial && is_potential_container) {
|
||||
// Partial container: pass through the already-normalized content
|
||||
value_to_add = value_content;
|
||||
} else {
|
||||
// Not valid JSON - treat as string value
|
||||
if (!closing_quote_pending) {
|
||||
value_to_add = "\"";
|
||||
closing_quote_pending = true;
|
||||
}
|
||||
value_to_add += escape_json_string_inner(value_content);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
args_target() += value_to_add;
|
||||
}
|
||||
|
||||
if (is_arg_close && current_tool) {
|
||||
if (needs_closing_quote) {
|
||||
current_tool->arguments += "\"";
|
||||
needs_closing_quote = false;
|
||||
if (closing_quote_pending) {
|
||||
args_target() += "\"";
|
||||
closing_quote_pending = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_arg_json && current_tool) {
|
||||
current_tool->arguments += std::string(trim_trailing_space(node.text));
|
||||
}
|
||||
|
||||
if (is_tool_close && current_tool) {
|
||||
if (needs_closing_quote) {
|
||||
current_tool->arguments += "\"";
|
||||
needs_closing_quote = false;
|
||||
// Flush buffer to arguments if tool name was never seen
|
||||
if (current_tool->name.empty() && !args_buffer.empty()) {
|
||||
current_tool->arguments = args_buffer;
|
||||
args_buffer.clear();
|
||||
}
|
||||
// Close any pending string quote
|
||||
if (closing_quote_pending) {
|
||||
current_tool->arguments += "\"";
|
||||
closing_quote_pending = false;
|
||||
}
|
||||
// Close any unclosed braces (accounts for nested objects)
|
||||
for (int d = json_brace_depth(current_tool->arguments); d > 0; d--) {
|
||||
current_tool->arguments += "}";
|
||||
}
|
||||
// Add tool call to results if named; otherwise discard
|
||||
if (pending_tool_call.has_value()) {
|
||||
if (!current_tool->name.empty()) {
|
||||
result.tool_calls.push_back(pending_tool_call.value());
|
||||
}
|
||||
pending_tool_call.reset();
|
||||
}
|
||||
current_tool->arguments += "}";
|
||||
}
|
||||
}
|
||||
|
||||
common_peg_parser common_chat_peg_builder::standard_constructed_tools(
|
||||
const std::map<std::string, std::string> & markers,
|
||||
const nlohmann::json & tools,
|
||||
bool parallel_tool_calls,
|
||||
bool force_tool_calls) {
|
||||
if (!tools.is_array() || tools.empty()) {
|
||||
return eps();
|
||||
}
|
||||
|
||||
// Extract markers with defaults
|
||||
auto get_marker = [&markers](const std::string & key, const std::string & default_val = "") -> std::string {
|
||||
auto it = markers.find(key);
|
||||
return it != markers.end() ? it->second : default_val;
|
||||
};
|
||||
|
||||
std::string section_start = get_marker("tool_call_start_marker", "<tool_call>");
|
||||
std::string section_end = get_marker("tool_call_end_marker", "</tool_call>");
|
||||
std::string func_opener = get_marker("function_opener", "<function=");
|
||||
std::string func_name_suffix = get_marker("function_name_suffix", ">");
|
||||
std::string func_closer = get_marker("function_closer", "</function>");
|
||||
std::string param_key_prefix = get_marker("parameter_key_prefix", "<param=");
|
||||
std::string param_key_suffix = get_marker("parameter_key_suffix", ">");
|
||||
std::string param_closer = get_marker("parameter_closer", "</param>");
|
||||
|
||||
// Build tool choices for tagged format
|
||||
auto tool_choices = choice();
|
||||
|
||||
for (const auto & tool_def : tools) {
|
||||
if (!tool_def.contains("function")) {
|
||||
continue;
|
||||
}
|
||||
const auto & function = tool_def.at("function");
|
||||
std::string name = function.at("name");
|
||||
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
|
||||
|
||||
// Build argument parsers
|
||||
auto args = eps();
|
||||
if (params.contains("properties") && !params["properties"].empty()) {
|
||||
auto arg_choice = choice();
|
||||
for (const auto & el : params["properties"].items()) {
|
||||
const std::string & prop_name = el.key();
|
||||
|
||||
auto arg_name_parser =
|
||||
choice({ literal(prop_name), literal("\"" + prop_name + "\""), literal("'" + prop_name + "'") });
|
||||
|
||||
auto arg_rule = tool_arg(tool_arg_open(literal(param_key_prefix)) + tool_arg_name(arg_name_parser) +
|
||||
literal(param_key_suffix) + tool_arg_value(until(param_closer)) +
|
||||
tool_arg_close(literal(param_closer)));
|
||||
arg_choice |= arg_rule;
|
||||
}
|
||||
args = zero_or_more(arg_choice + space());
|
||||
}
|
||||
|
||||
// Build function parser: <function=name>args</function>
|
||||
auto tool_parser = tool(tool_open(literal(func_opener) + tool_name(literal(name)) + literal(func_name_suffix)) +
|
||||
space() + tool_args(args) + space() + tool_close(literal(func_closer)));
|
||||
|
||||
tool_choices |= rule("tool-" + name, tool_parser);
|
||||
}
|
||||
|
||||
// Build the section with markers
|
||||
auto section =
|
||||
parallel_tool_calls ?
|
||||
trigger_rule("tool-call", literal(section_start) + space() + one_or_more(tool_choices + space()) +
|
||||
literal(section_end)) :
|
||||
trigger_rule("tool-call", literal(section_start) + space() + tool_choices + space() + literal(section_end));
|
||||
|
||||
return force_tool_calls ? section : optional(section);
|
||||
}
|
||||
|
||||
// Helper: Parse dot notation key into prefix and field name
|
||||
static std::pair<std::string, std::string> parse_key_spec(const std::string & key) {
|
||||
auto dot_pos = key.find('.');
|
||||
if (dot_pos == std::string::npos) {
|
||||
return {"", key}; // Top-level field
|
||||
}
|
||||
return {key.substr(0, dot_pos), key.substr(dot_pos + 1)};
|
||||
}
|
||||
|
||||
// Mode 1: function_is_key — parse {"function_name": {...}}
|
||||
common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key(
|
||||
const nlohmann::json & tools,
|
||||
const std::string & args_key,
|
||||
const std::string & effective_args_key,
|
||||
const std::string & call_id_key,
|
||||
const std::string & gen_call_id_key) {
|
||||
|
||||
auto tool_choices = choice();
|
||||
|
||||
for (const auto & tool_def : tools) {
|
||||
if (!tool_def.contains("function")) {
|
||||
continue;
|
||||
}
|
||||
const auto & function = tool_def.at("function");
|
||||
std::string name = function.at("name");
|
||||
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
|
||||
|
||||
// Build inner object fields
|
||||
std::vector<common_peg_parser> inner_fields;
|
||||
|
||||
if (!call_id_key.empty()) {
|
||||
auto id_parser = atomic(
|
||||
literal("\"" + call_id_key + "\"") + space() + literal(":") + space() +
|
||||
literal("\"") + tool_id(json_string_content()) + literal("\"")
|
||||
);
|
||||
inner_fields.push_back(optional(id_parser + space() + optional(literal(",") + space())));
|
||||
}
|
||||
|
||||
if (!gen_call_id_key.empty()) {
|
||||
auto gen_id_parser = atomic(
|
||||
literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() +
|
||||
choice({
|
||||
literal("\"") + tool_id(json_string_content()) + literal("\""),
|
||||
tool_id(json_number())
|
||||
})
|
||||
);
|
||||
inner_fields.push_back(optional(gen_id_parser + space() + optional(literal(",") + space())));
|
||||
}
|
||||
|
||||
// Arguments — either wrapped in args_key or parsed directly
|
||||
common_peg_parser args_parser = eps();
|
||||
if (args_key.empty()) {
|
||||
args_parser = tool_args(schema(json(), "tool-" + name + "-schema", params));
|
||||
} else {
|
||||
args_parser = literal("\"" + effective_args_key + "\"") + space() + literal(":") + space() +
|
||||
tool_args(schema(json(), "tool-" + name + "-schema", params));
|
||||
}
|
||||
inner_fields.push_back(args_parser);
|
||||
|
||||
// Build inner object parser
|
||||
common_peg_parser inner_object = eps();
|
||||
if (args_key.empty() && inner_fields.size() == 1) {
|
||||
inner_object = inner_fields[0];
|
||||
} else {
|
||||
inner_object = literal("{") + space();
|
||||
for (size_t i = 0; i < inner_fields.size(); i++) {
|
||||
inner_object = inner_object + inner_fields[i];
|
||||
if (i < inner_fields.size() - 1) {
|
||||
inner_object = inner_object + space();
|
||||
}
|
||||
}
|
||||
inner_object = inner_object + space() + literal("}");
|
||||
}
|
||||
|
||||
auto tool_parser = tool(
|
||||
tool_open(literal("{")) + space() +
|
||||
literal("\"") + tool_name(literal(name)) + literal("\"") +
|
||||
space() + literal(":") + space() +
|
||||
inner_object +
|
||||
space() + tool_close(literal("}"))
|
||||
);
|
||||
|
||||
tool_choices |= rule("tool-" + name, tool_parser);
|
||||
}
|
||||
|
||||
return tool_choices;
|
||||
}
|
||||
|
||||
// Mode 2: Nested keys (dot notation like "function.name")
|
||||
common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
|
||||
const nlohmann::json & tools,
|
||||
const std::string & effective_name_key,
|
||||
const std::string & effective_args_key,
|
||||
const std::string & call_id_key,
|
||||
const std::string & gen_call_id_key) {
|
||||
|
||||
auto tool_choices = choice();
|
||||
|
||||
auto name_spec = parse_key_spec(effective_name_key);
|
||||
auto args_spec = parse_key_spec(effective_args_key);
|
||||
|
||||
std::string nested_prefix = !name_spec.first.empty() ? name_spec.first : args_spec.first;
|
||||
std::string nested_name_field = !name_spec.first.empty() ? name_spec.second : effective_name_key;
|
||||
std::string nested_args_field = !args_spec.first.empty() ? args_spec.second : effective_args_key;
|
||||
|
||||
for (const auto & tool_def : tools) {
|
||||
if (!tool_def.contains("function")) {
|
||||
continue;
|
||||
}
|
||||
const auto & function = tool_def.at("function");
|
||||
std::string name = function.at("name");
|
||||
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
|
||||
|
||||
auto nested_name = literal("\"" + nested_name_field + "\"") + space() + literal(":") + space() +
|
||||
literal("\"") + tool_name(literal(name)) + literal("\"");
|
||||
auto nested_args = literal("\"" + nested_args_field + "\"") + space() + literal(":") + space() +
|
||||
tool_args(schema(json(), "tool-" + name + "-schema", params));
|
||||
|
||||
auto nested_object = literal("{") + space() +
|
||||
nested_name + space() + literal(",") + space() +
|
||||
nested_args +
|
||||
space() + literal("}");
|
||||
|
||||
// Format: { id?, "function": {...} }
|
||||
auto tool_parser_body = tool_open(literal("{")) + space();
|
||||
|
||||
if (!call_id_key.empty()) {
|
||||
auto id_spec = parse_key_spec(call_id_key);
|
||||
if (id_spec.first.empty()) {
|
||||
auto id_parser = atomic(
|
||||
literal("\"" + call_id_key + "\"") + space() + literal(":") + space() +
|
||||
literal("\"") + tool_id(json_string_content()) + literal("\"")
|
||||
);
|
||||
tool_parser_body = tool_parser_body + optional(id_parser + space() + literal(",") + space());
|
||||
}
|
||||
}
|
||||
|
||||
if (!gen_call_id_key.empty()) {
|
||||
auto gen_id_spec = parse_key_spec(gen_call_id_key);
|
||||
if (gen_id_spec.first.empty()) {
|
||||
auto gen_id_parser = atomic(
|
||||
literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() +
|
||||
choice({
|
||||
literal("\"") + tool_id(json_string_content()) + literal("\""),
|
||||
tool_id(json_number())
|
||||
})
|
||||
);
|
||||
tool_parser_body = tool_parser_body + optional(gen_id_parser + space() + literal(",") + space());
|
||||
}
|
||||
}
|
||||
|
||||
auto nested_field = literal("\"" + nested_prefix + "\"") + space() + literal(":") + space() + nested_object;
|
||||
tool_parser_body = tool_parser_body + nested_field + space() + tool_close(literal("}"));
|
||||
|
||||
tool_choices |= rule("tool-" + name, tool(tool_parser_body));
|
||||
}
|
||||
|
||||
return tool_choices;
|
||||
}
|
||||
|
||||
// Mode 3: Flat keys with optional ID fields and parameter ordering
|
||||
common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
|
||||
const nlohmann::json & tools,
|
||||
const std::string & effective_name_key,
|
||||
const std::string & effective_args_key,
|
||||
const std::string & call_id_key,
|
||||
const std::string & gen_call_id_key,
|
||||
const std::vector<std::string> & parameters_order) {
|
||||
|
||||
auto tool_choices = choice();
|
||||
auto name_key_parser = literal("\"" + effective_name_key + "\"");
|
||||
auto args_key_parser = literal("\"" + effective_args_key + "\"");
|
||||
|
||||
for (const auto & tool_def : tools) {
|
||||
if (!tool_def.contains("function")) {
|
||||
continue;
|
||||
}
|
||||
const auto & function = tool_def.at("function");
|
||||
std::string name = function.at("name");
|
||||
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
|
||||
|
||||
auto tool_name_ = name_key_parser + space() + literal(":") + space() +
|
||||
literal("\"") + tool_name(literal(name)) + literal("\"");
|
||||
auto tool_args_ = args_key_parser + space() + literal(":") + space() +
|
||||
tool_args(schema(json(), "tool-" + name + "-schema", params));
|
||||
|
||||
// Build ID parsers if keys are provided
|
||||
common_peg_parser id_parser = eps();
|
||||
if (!call_id_key.empty()) {
|
||||
id_parser = atomic(
|
||||
literal("\"" + call_id_key + "\"") + space() + literal(":") + space() +
|
||||
choice({
|
||||
literal("\"") + tool_id(json_string_content()) + literal("\""),
|
||||
tool_id(json_number())
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
common_peg_parser gen_id_parser = eps();
|
||||
if (!gen_call_id_key.empty()) {
|
||||
gen_id_parser = atomic(
|
||||
literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() +
|
||||
choice({
|
||||
literal("\"") + tool_id(json_string_content()) + literal("\""),
|
||||
tool_id(json_number())
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// Create (parser, key) pairs for all fields, then sort by parameters_order
|
||||
std::vector<std::pair<common_peg_parser, std::string>> parser_pairs;
|
||||
parser_pairs.emplace_back(tool_name_, effective_name_key);
|
||||
parser_pairs.emplace_back(tool_args_, effective_args_key);
|
||||
if (!call_id_key.empty()) {
|
||||
parser_pairs.emplace_back(optional(id_parser), call_id_key);
|
||||
}
|
||||
if (!gen_call_id_key.empty()) {
|
||||
parser_pairs.emplace_back(optional(gen_id_parser), gen_call_id_key);
|
||||
}
|
||||
|
||||
std::sort(parser_pairs.begin(), parser_pairs.end(),
|
||||
[¶meters_order](const auto & a, const auto & b) {
|
||||
auto pos_a = std::find(parameters_order.begin(), parameters_order.end(), a.second);
|
||||
auto pos_b = std::find(parameters_order.begin(), parameters_order.end(), b.second);
|
||||
size_t idx_a = (pos_a == parameters_order.end()) ? parameters_order.size() : std::distance(parameters_order.begin(), pos_a);
|
||||
size_t idx_b = (pos_b == parameters_order.end()) ? parameters_order.size() : std::distance(parameters_order.begin(), pos_b);
|
||||
return idx_a < idx_b;
|
||||
});
|
||||
|
||||
auto ordered_body = tool_open(literal("{")) + space();
|
||||
for (size_t i = 0; i < parser_pairs.size(); i++) {
|
||||
ordered_body = ordered_body + parser_pairs[i].first;
|
||||
if (i < parser_pairs.size() - 1) {
|
||||
ordered_body = ordered_body + space() + literal(",") + space();
|
||||
}
|
||||
}
|
||||
ordered_body = ordered_body + space() + tool_close(literal("}"));
|
||||
|
||||
tool_choices |= rule("tool-" + name, tool(ordered_body));
|
||||
}
|
||||
|
||||
return tool_choices;
|
||||
}
|
||||
|
||||
common_peg_parser common_chat_peg_builder::standard_json_tools(
|
||||
const std::string & section_start,
|
||||
const std::string & section_end,
|
||||
const nlohmann::json & tools,
|
||||
bool parallel_tool_calls,
|
||||
bool force_tool_calls,
|
||||
const std::string & name_key,
|
||||
const std::string & args_key,
|
||||
bool array_wrapped,
|
||||
bool function_is_key,
|
||||
const std::string & call_id_key,
|
||||
const std::string & gen_call_id_key,
|
||||
const std::vector<std::string> & parameters_order) {
|
||||
if (!tools.is_array() || tools.empty()) {
|
||||
return eps();
|
||||
}
|
||||
|
||||
std::string effective_name_key = name_key.empty() ? "name" : name_key;
|
||||
std::string effective_args_key = args_key.empty() ? "arguments" : args_key;
|
||||
|
||||
// Dispatch to the appropriate builder based on the JSON layout mode
|
||||
common_peg_parser tool_choices = eps();
|
||||
if (function_is_key) {
|
||||
tool_choices = build_json_tools_function_is_key(tools, args_key, effective_args_key, call_id_key, gen_call_id_key);
|
||||
} else {
|
||||
auto name_spec = parse_key_spec(effective_name_key);
|
||||
auto args_spec = parse_key_spec(effective_args_key);
|
||||
if (!name_spec.first.empty() || !args_spec.first.empty()) {
|
||||
tool_choices = build_json_tools_nested_keys(tools, effective_name_key, effective_args_key, call_id_key, gen_call_id_key);
|
||||
} else {
|
||||
tool_choices = build_json_tools_flat_keys(tools, effective_name_key, effective_args_key, call_id_key, gen_call_id_key, parameters_order);
|
||||
}
|
||||
}
|
||||
|
||||
// Build the section with markers
|
||||
auto tool_calls = tool_choices;
|
||||
if (parallel_tool_calls) {
|
||||
tool_calls = tool_calls + zero_or_more(space() + literal(",") + space() + tool_choices);
|
||||
}
|
||||
|
||||
if (array_wrapped) {
|
||||
tool_calls = literal("[") + space() + tool_calls + space() + literal("]");
|
||||
}
|
||||
|
||||
auto section =
|
||||
trigger_rule("tool-call", literal(section_start) + space() + tool_calls + space() + literal(section_end));
|
||||
|
||||
return force_tool_calls ? section : optional(section);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,22 +3,9 @@
|
|||
#include "chat.h"
|
||||
#include "peg-parser.h"
|
||||
|
||||
class common_chat_peg_builder : public common_peg_parser_builder {
|
||||
public:
|
||||
static constexpr const char * REASONING_BLOCK = "reasoning-block";
|
||||
static constexpr const char * REASONING = "reasoning";
|
||||
static constexpr const char * CONTENT = "content";
|
||||
|
||||
common_peg_parser reasoning_block(const common_peg_parser & p) { return tag(REASONING_BLOCK, p); }
|
||||
common_peg_parser reasoning(const common_peg_parser & p) { return tag(REASONING, p); }
|
||||
common_peg_parser content(const common_peg_parser & p) { return tag(CONTENT, p); }
|
||||
};
|
||||
|
||||
inline common_peg_arena build_chat_peg_parser(const std::function<common_peg_parser(common_chat_peg_builder & builder)> & fn) {
|
||||
common_chat_peg_builder builder;
|
||||
builder.set_root(fn(builder));
|
||||
return builder.build();
|
||||
}
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
class common_chat_peg_mapper {
|
||||
public:
|
||||
|
|
@ -26,80 +13,164 @@ class common_chat_peg_mapper {
|
|||
|
||||
common_chat_peg_mapper(common_chat_msg & msg) : result(msg) {}
|
||||
|
||||
virtual ~common_chat_peg_mapper() = default;
|
||||
|
||||
virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result);
|
||||
virtual void map(const common_peg_ast_node & node);
|
||||
private:
|
||||
// Tool call handling state
|
||||
std::optional<common_chat_tool_call> pending_tool_call; // Tool call waiting for name
|
||||
common_chat_tool_call * current_tool = nullptr;
|
||||
int arg_count = 0;
|
||||
bool closing_quote_pending = false;
|
||||
std::string args_buffer; // Buffer to delay arguments until tool name is known
|
||||
|
||||
// Returns a reference to the active argument destination string.
|
||||
// Before tool_name is known, writes go to args_buffer; after, to current_tool->arguments.
|
||||
std::string & args_target();
|
||||
};
|
||||
|
||||
class common_chat_peg_native_builder : public common_chat_peg_builder {
|
||||
public:
|
||||
static constexpr const char * TOOL = "tool";
|
||||
static constexpr const char * TOOL_OPEN = "tool-open";
|
||||
static constexpr const char * TOOL_CLOSE = "tool-close";
|
||||
static constexpr const char * TOOL_ID = "tool-id";
|
||||
static constexpr const char * TOOL_NAME = "tool-name";
|
||||
static constexpr const char * TOOL_ARGS = "tool-args";
|
||||
struct content_structure;
|
||||
struct tool_call_structure;
|
||||
|
||||
class common_chat_peg_builder : public common_peg_parser_builder {
|
||||
public:
|
||||
// Tag constants (from former common_chat_peg_base_builder)
|
||||
static constexpr const char * REASONING_BLOCK = "reasoning-block";
|
||||
static constexpr const char * REASONING = "reasoning";
|
||||
static constexpr const char * CONTENT = "content";
|
||||
|
||||
// Tag constants
|
||||
static constexpr const char * TOOL = "tool";
|
||||
static constexpr const char * TOOL_OPEN = "tool-open";
|
||||
static constexpr const char * TOOL_CLOSE = "tool-close";
|
||||
static constexpr const char * TOOL_ID = "tool-id";
|
||||
static constexpr const char * TOOL_NAME = "tool-name";
|
||||
static constexpr const char * TOOL_ARGS = "tool-args";
|
||||
static constexpr const char * TOOL_ARG = "tool-arg";
|
||||
static constexpr const char * TOOL_ARG_OPEN = "tool-arg-open";
|
||||
static constexpr const char * TOOL_ARG_CLOSE = "tool-arg-close";
|
||||
static constexpr const char * TOOL_ARG_NAME = "tool-arg-name";
|
||||
static constexpr const char * TOOL_ARG_VALUE = "tool-arg-value";
|
||||
static constexpr const char * TOOL_ARG_STRING_VALUE = "tool-arg-string-value"; // For schema-declared string types
|
||||
|
||||
// Low-level tag methods (from former common_chat_peg_base_builder)
|
||||
common_peg_parser reasoning_block(const common_peg_parser & p) { return tag(REASONING_BLOCK, p); }
|
||||
|
||||
common_peg_parser reasoning(const common_peg_parser & p) { return tag(REASONING, p); }
|
||||
|
||||
common_peg_parser content(const common_peg_parser & p) { return tag(CONTENT, p); }
|
||||
|
||||
common_peg_parser tag_with_safe_content(const std::string & tag_name,
|
||||
const std::string & marker,
|
||||
const common_peg_parser & p);
|
||||
|
||||
// Low-level tag methods
|
||||
common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); }
|
||||
common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); }
|
||||
common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); }
|
||||
common_peg_parser tool_id(const common_peg_parser & p) { return atomic(tag(TOOL_ID, p)); }
|
||||
common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); }
|
||||
common_peg_parser tool_args(const common_peg_parser & p) { return tag(TOOL_ARGS, p); }
|
||||
};
|
||||
|
||||
class common_chat_peg_native_mapper : public common_chat_peg_mapper {
|
||||
common_chat_tool_call * current_tool;
|
||||
|
||||
public:
|
||||
common_chat_peg_native_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
|
||||
|
||||
void map(const common_peg_ast_node & node) override;
|
||||
};
|
||||
|
||||
inline common_peg_arena build_chat_peg_native_parser(const std::function<common_peg_parser(common_chat_peg_native_builder & builder)> & fn) {
|
||||
common_chat_peg_native_builder builder;
|
||||
builder.set_root(fn(builder));
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
class common_chat_peg_constructed_builder : public common_chat_peg_builder {
|
||||
public:
|
||||
static constexpr const char * TOOL = "tool";
|
||||
static constexpr const char * TOOL_OPEN = "tool-open";
|
||||
static constexpr const char * TOOL_CLOSE = "tool-close";
|
||||
static constexpr const char * TOOL_NAME = "tool-name";
|
||||
static constexpr const char * TOOL_ARG = "tool-arg";
|
||||
static constexpr const char * TOOL_ARG_OPEN = "tool-arg-open";
|
||||
static constexpr const char * TOOL_ARG_CLOSE = "tool-arg-close";
|
||||
static constexpr const char * TOOL_ARG_NAME = "tool-arg-name";
|
||||
static constexpr const char * TOOL_ARG_STRING_VALUE = "tool-arg-string-value";
|
||||
static constexpr const char * TOOL_ARG_JSON_VALUE = "tool-arg-json-value";
|
||||
|
||||
common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); }
|
||||
common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); }
|
||||
common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); }
|
||||
common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); }
|
||||
common_peg_parser tool_arg(const common_peg_parser & p) { return tag(TOOL_ARG, p); }
|
||||
common_peg_parser tool_arg_open(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_OPEN, p)); }
|
||||
common_peg_parser tool_arg_close(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_CLOSE, p)); }
|
||||
common_peg_parser tool_arg_name(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_NAME, p)); }
|
||||
common_peg_parser tool_arg_value(const common_peg_parser & p) { return tag(TOOL_ARG_VALUE, p); }
|
||||
|
||||
// Use for schema-declared string types - won't be treated as potential JSON container
|
||||
common_peg_parser tool_arg_string_value(const common_peg_parser & p) { return tag(TOOL_ARG_STRING_VALUE, p); }
|
||||
common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return tag(TOOL_ARG_JSON_VALUE, p); }
|
||||
common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_VALUE, p)); }
|
||||
|
||||
// Legacy-compatible helper for building standard JSON tool calls
|
||||
// Used by tests and manual parsers
|
||||
// name_key/args_key: JSON key names for function name and arguments
|
||||
// Empty or "name"/"arguments" will accept both common variations
|
||||
// Supports dot notation for nested objects (e.g., "function.name")
|
||||
// array_wrapped: if true, tool calls are wrapped in JSON array [...]
|
||||
// function_is_key: if true, function name is the JSON key (e.g., {"func_name": {...}})
|
||||
// call_id_key: JSON key for string call ID (e.g., "id")
|
||||
// gen_call_id_key: JSON key for generated integer call ID (e.g., "tool_call_id")
|
||||
// parameters_order: order in which JSON fields should be parsed
|
||||
common_peg_parser standard_json_tools(const std::string & section_start,
|
||||
const std::string & section_end,
|
||||
const nlohmann::json & tools,
|
||||
bool parallel_tool_calls,
|
||||
bool force_tool_calls,
|
||||
const std::string & name_key = "",
|
||||
const std::string & args_key = "",
|
||||
bool array_wrapped = false,
|
||||
bool function_is_key = false,
|
||||
const std::string & call_id_key = "",
|
||||
const std::string & gen_call_id_key = "",
|
||||
const std::vector<std::string> & parameters_order = {});
|
||||
|
||||
// Legacy-compatible helper for building XML/tagged style tool calls
|
||||
// Used by tests and manual parsers
|
||||
common_peg_parser standard_constructed_tools(const std::map<std::string, std::string> & markers,
|
||||
const nlohmann::json & tools,
|
||||
bool parallel_tool_calls,
|
||||
bool force_tool_calls);
|
||||
|
||||
private:
|
||||
// Implementation helpers for standard_json_tools — one per JSON tool call layout mode
|
||||
common_peg_parser build_json_tools_function_is_key(const nlohmann::json & tools,
|
||||
const std::string & args_key,
|
||||
const std::string & effective_args_key,
|
||||
const std::string & call_id_key,
|
||||
const std::string & gen_call_id_key);
|
||||
|
||||
common_peg_parser build_json_tools_nested_keys(const nlohmann::json & tools,
|
||||
const std::string & effective_name_key,
|
||||
const std::string & effective_args_key,
|
||||
const std::string & call_id_key,
|
||||
const std::string & gen_call_id_key);
|
||||
|
||||
common_peg_parser build_json_tools_flat_keys(const nlohmann::json & tools,
|
||||
const std::string & effective_name_key,
|
||||
const std::string & effective_args_key,
|
||||
const std::string & call_id_key,
|
||||
const std::string & gen_call_id_key,
|
||||
const std::vector<std::string> & parameters_order);
|
||||
};
|
||||
|
||||
class common_chat_peg_constructed_mapper : public common_chat_peg_mapper {
|
||||
common_chat_tool_call * current_tool;
|
||||
int arg_count = 0;
|
||||
bool needs_closing_quote = false;
|
||||
|
||||
public:
|
||||
common_chat_peg_constructed_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
|
||||
|
||||
void map(const common_peg_ast_node & node) override;
|
||||
};
|
||||
|
||||
inline common_peg_arena build_chat_peg_constructed_parser(const std::function<common_peg_parser(common_chat_peg_constructed_builder & builder)> & fn) {
|
||||
common_chat_peg_constructed_builder builder;
|
||||
builder.set_root(fn(builder));
|
||||
return builder.build();
|
||||
inline common_peg_arena build_chat_peg_parser(
|
||||
const std::function<common_peg_parser(common_chat_peg_builder & builder)> & fn) {
|
||||
common_chat_peg_builder builder;
|
||||
builder.set_root(fn(builder));
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
class tag_based_peg_mapper {
|
||||
public:
|
||||
std::map<std::string, std::string> tags;
|
||||
|
||||
void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result);
|
||||
};
|
||||
|
||||
struct tagged_parse_result {
|
||||
common_peg_parse_result result;
|
||||
std::map<std::string, std::string> tags;
|
||||
};
|
||||
|
||||
struct tagged_peg_parser {
|
||||
common_peg_arena arena;
|
||||
bool debug = false;
|
||||
|
||||
tagged_peg_parser & withDebug() {
|
||||
debug = true;
|
||||
return *this;
|
||||
}
|
||||
|
||||
tagged_peg_parser & withoutDebug() {
|
||||
debug = false;
|
||||
return *this;
|
||||
}
|
||||
|
||||
tagged_parse_result parse_and_extract(const std::string & input, bool is_partial = false) const;
|
||||
tagged_parse_result parse_anywhere_and_extract(const std::string & input) const;
|
||||
};
|
||||
|
||||
tagged_peg_parser build_tagged_peg_parser(
|
||||
const std::function<common_peg_parser(common_peg_parser_builder & builder)> & fn);
|
||||
|
||||
|
|
|
|||
2993
common/chat.cpp
2993
common/chat.cpp
File diff suppressed because it is too large
Load Diff
253
common/chat.h
253
common/chat.h
|
|
@ -3,17 +3,30 @@
|
|||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "jinja/parser.h"
|
||||
#include "nlohmann/json_fwd.hpp"
|
||||
#include "peg-parser.h"
|
||||
#include <functional>
|
||||
#include "jinja/runtime.h"
|
||||
#include "jinja/caps.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
using chat_template_caps = jinja::caps;
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
struct common_chat_templates;
|
||||
|
||||
namespace autoparser {
|
||||
struct templates_params;
|
||||
} // namespace autoparser
|
||||
|
||||
struct common_chat_tool_call {
|
||||
std::string name;
|
||||
std::string arguments;
|
||||
|
|
@ -38,21 +51,85 @@ struct common_chat_msg_content_part {
|
|||
}
|
||||
};
|
||||
|
||||
struct common_chat_template {
|
||||
jinja::program prog;
|
||||
std::string bos_tok;
|
||||
std::string eos_tok;
|
||||
std::string src;
|
||||
chat_template_caps caps;
|
||||
|
||||
common_chat_template(const std::string & src, const std::string & bos_token, const std::string & eos_token) {
|
||||
jinja::lexer lexer;
|
||||
auto lexer_res = lexer.tokenize(src);
|
||||
this->prog = jinja::parse_from_tokens(lexer_res);
|
||||
|
||||
this->src = lexer_res.source;
|
||||
this->bos_tok = bos_token;
|
||||
this->eos_tok = eos_token;
|
||||
|
||||
this->caps = jinja::caps_get(prog);
|
||||
// LOG_INF("%s: caps:\n%s\n", __func__, this->caps.to_string().c_str());
|
||||
}
|
||||
|
||||
const std::string & source() const { return src; }
|
||||
const std::string & bos_token() const { return bos_tok; }
|
||||
const std::string & eos_token() const { return eos_tok; }
|
||||
|
||||
// TODO: this is ugly, refactor it somehow
|
||||
json add_system(const json & messages, const std::string & system_prompt) const {
|
||||
GGML_ASSERT(messages.is_array());
|
||||
auto msgs_copy = messages;
|
||||
if (!caps.supports_system_role) {
|
||||
if (msgs_copy.empty()) {
|
||||
msgs_copy.insert(msgs_copy.begin(), json{
|
||||
{"role", "user"},
|
||||
{"content", system_prompt}
|
||||
});
|
||||
} else {
|
||||
auto & first_msg = msgs_copy[0];
|
||||
if (!first_msg.contains("content")) {
|
||||
first_msg["content"] = "";
|
||||
}
|
||||
first_msg["content"] = system_prompt + "\n\n"
|
||||
+ first_msg["content"].get<std::string>();
|
||||
}
|
||||
} else {
|
||||
if (msgs_copy.empty() || msgs_copy[0].at("role") != "system") {
|
||||
msgs_copy.insert(msgs_copy.begin(), json{
|
||||
{"role", "system"},
|
||||
{"content", system_prompt}
|
||||
});
|
||||
} else if (msgs_copy[0].at("role") == "system") {
|
||||
msgs_copy[0]["content"] = system_prompt;
|
||||
}
|
||||
}
|
||||
return msgs_copy;
|
||||
}
|
||||
|
||||
chat_template_caps original_caps() const {
|
||||
return caps;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct common_chat_msg {
|
||||
std::string role;
|
||||
std::string content;
|
||||
std::string role;
|
||||
std::string content;
|
||||
std::vector<common_chat_msg_content_part> content_parts;
|
||||
std::vector<common_chat_tool_call> tool_calls;
|
||||
std::string reasoning_content;
|
||||
std::string tool_name;
|
||||
std::string tool_call_id;
|
||||
std::vector<common_chat_tool_call> tool_calls;
|
||||
std::string reasoning_content;
|
||||
std::string tool_name;
|
||||
std::string tool_call_id;
|
||||
|
||||
nlohmann::ordered_json to_json_oaicompat(bool concat_typed_text = false) const;
|
||||
|
||||
bool empty() const {
|
||||
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
|
||||
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() &&
|
||||
tool_name.empty() && tool_call_id.empty();
|
||||
}
|
||||
void set_tool_call_ids(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
|
||||
|
||||
void set_tool_call_ids(std::vector<std::string> & ids_cache,
|
||||
const std::function<std::string()> & gen_tool_call_id) {
|
||||
for (auto i = 0u; i < tool_calls.size(); i++) {
|
||||
if (ids_cache.size() <= i) {
|
||||
auto id = tool_calls[i].id;
|
||||
|
|
@ -64,32 +141,28 @@ struct common_chat_msg {
|
|||
tool_calls[i].id = ids_cache[i];
|
||||
}
|
||||
}
|
||||
|
||||
bool operator==(const common_chat_msg & other) const {
|
||||
return role == other.role
|
||||
&& content == other.content
|
||||
&& content_parts == other.content_parts
|
||||
&& tool_calls == other.tool_calls
|
||||
&& reasoning_content == other.reasoning_content
|
||||
&& tool_name == other.tool_name
|
||||
&& tool_call_id == other.tool_call_id;
|
||||
}
|
||||
bool operator!=(const common_chat_msg & other) const {
|
||||
return !(*this == other);
|
||||
return role == other.role && content == other.content && content_parts == other.content_parts &&
|
||||
tool_calls == other.tool_calls && reasoning_content == other.reasoning_content &&
|
||||
tool_name == other.tool_name && tool_call_id == other.tool_call_id;
|
||||
}
|
||||
|
||||
bool operator!=(const common_chat_msg & other) const { return !(*this == other); }
|
||||
};
|
||||
|
||||
struct common_chat_msg_diff {
|
||||
std::string reasoning_content_delta;
|
||||
std::string content_delta;
|
||||
size_t tool_call_index = std::string::npos;
|
||||
std::string reasoning_content_delta;
|
||||
std::string content_delta;
|
||||
size_t tool_call_index = std::string::npos;
|
||||
common_chat_tool_call tool_call_delta;
|
||||
|
||||
static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new);
|
||||
static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & msg_prv,
|
||||
const common_chat_msg & msg_new);
|
||||
|
||||
bool operator==(const common_chat_msg_diff & other) const {
|
||||
return content_delta == other.content_delta
|
||||
&& tool_call_index == other.tool_call_index
|
||||
&& tool_call_delta == other.tool_call_delta;
|
||||
return content_delta == other.content_delta && tool_call_index == other.tool_call_index &&
|
||||
tool_call_delta == other.tool_call_delta;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -107,64 +180,39 @@ enum common_chat_tool_choice {
|
|||
|
||||
enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
COMMON_CHAT_FORMAT_GENERIC,
|
||||
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
|
||||
COMMON_CHAT_FORMAT_MAGISTRAL,
|
||||
COMMON_CHAT_FORMAT_LLAMA_3_X,
|
||||
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
|
||||
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
|
||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
|
||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO,
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B,
|
||||
COMMON_CHAT_FORMAT_GRANITE,
|
||||
COMMON_CHAT_FORMAT_GPT_OSS,
|
||||
COMMON_CHAT_FORMAT_SEED_OSS,
|
||||
COMMON_CHAT_FORMAT_NEMOTRON_V2,
|
||||
COMMON_CHAT_FORMAT_APERTUS,
|
||||
COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,
|
||||
COMMON_CHAT_FORMAT_GLM_4_5,
|
||||
COMMON_CHAT_FORMAT_MINIMAX_M2,
|
||||
COMMON_CHAT_FORMAT_KIMI_K2,
|
||||
COMMON_CHAT_FORMAT_APRIEL_1_5,
|
||||
COMMON_CHAT_FORMAT_XIAOMI_MIMO,
|
||||
COMMON_CHAT_FORMAT_SOLAR_OPEN,
|
||||
COMMON_CHAT_FORMAT_EXAONE_MOE,
|
||||
|
||||
// These are intended to be parsed by the PEG parser
|
||||
COMMON_CHAT_FORMAT_PEG_SIMPLE,
|
||||
COMMON_CHAT_FORMAT_PEG_NATIVE,
|
||||
COMMON_CHAT_FORMAT_PEG_CONSTRUCTED,
|
||||
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
|
||||
struct common_chat_templates_inputs {
|
||||
std::vector<common_chat_msg> messages;
|
||||
std::string grammar;
|
||||
std::string json_schema;
|
||||
bool add_generation_prompt = true;
|
||||
bool use_jinja = true;
|
||||
std::vector<common_chat_msg> messages;
|
||||
std::string grammar;
|
||||
std::string json_schema;
|
||||
bool add_generation_prompt = true;
|
||||
bool use_jinja = true;
|
||||
// Parameters below only supported when use_jinja is true
|
||||
std::vector<common_chat_tool> tools;
|
||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
bool parallel_tool_calls = false;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool enable_thinking"
|
||||
bool enable_thinking = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
std::map<std::string, std::string> chat_template_kwargs;
|
||||
bool add_bos = false;
|
||||
bool add_eos = false;
|
||||
std::vector<common_chat_tool> tools;
|
||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
bool parallel_tool_calls = false;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool enable_thinking"
|
||||
bool enable_thinking = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
std::map<std::string, std::string> chat_template_kwargs;
|
||||
bool add_bos = false;
|
||||
bool add_eos = false;
|
||||
};
|
||||
|
||||
struct common_chat_params {
|
||||
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
std::string prompt;
|
||||
std::string grammar;
|
||||
bool grammar_lazy = false;
|
||||
bool grammar_lazy = false;
|
||||
bool thinking_forced_open = false;
|
||||
bool supports_thinking = false;
|
||||
std::vector<common_grammar_trigger> grammar_triggers;
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
|
|
@ -174,13 +222,14 @@ struct common_chat_params {
|
|||
// per-message parsing syntax
|
||||
// should be derived from common_chat_params
|
||||
struct common_chat_parser_params {
|
||||
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
|
||||
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
|
||||
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
|
||||
bool reasoning_in_content = false;
|
||||
bool thinking_forced_open = false;
|
||||
bool parse_tool_calls = true;
|
||||
common_peg_arena parser = {};
|
||||
bool reasoning_in_content = false;
|
||||
bool thinking_forced_open = false;
|
||||
bool parse_tool_calls = true;
|
||||
bool debug = false; // Enable debug output for PEG parser
|
||||
common_peg_arena parser = {};
|
||||
common_chat_parser_params() = default;
|
||||
common_chat_parser_params(const common_chat_params & chat_params) {
|
||||
format = chat_params.format;
|
||||
|
|
@ -193,45 +242,42 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
|
|||
|
||||
void common_chat_templates_free(struct common_chat_templates * tmpls);
|
||||
|
||||
struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } };
|
||||
struct common_chat_templates_deleter {
|
||||
void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); }
|
||||
};
|
||||
|
||||
typedef std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter> common_chat_templates_ptr;
|
||||
|
||||
common_chat_templates_ptr common_chat_templates_init(
|
||||
const struct llama_model * model,
|
||||
const std::string & chat_template_override,
|
||||
const std::string & bos_token_override = "",
|
||||
const std::string & eos_token_override = "");
|
||||
common_chat_templates_ptr common_chat_templates_init(const struct llama_model * model,
|
||||
const std::string & chat_template_override,
|
||||
const std::string & bos_token_override = "",
|
||||
const std::string & eos_token_override = "");
|
||||
|
||||
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
|
||||
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
|
||||
|
||||
|
||||
struct common_chat_params common_chat_templates_apply(
|
||||
const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs);
|
||||
struct common_chat_params common_chat_templates_apply(const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs);
|
||||
|
||||
// Format single message, while taking into account the position of that message in chat history
|
||||
std::string common_chat_format_single(
|
||||
const struct common_chat_templates * tmpls,
|
||||
const std::vector<common_chat_msg> & past_msg,
|
||||
const common_chat_msg & new_msg,
|
||||
bool add_ass,
|
||||
bool use_jinja);
|
||||
std::string common_chat_format_single(const struct common_chat_templates * tmpls,
|
||||
const std::vector<common_chat_msg> & past_msg,
|
||||
const common_chat_msg & new_msg,
|
||||
bool add_ass,
|
||||
bool use_jinja);
|
||||
|
||||
// Returns an example of formatted chat
|
||||
std::string common_chat_format_example(
|
||||
const struct common_chat_templates * tmpls,
|
||||
bool use_jinja,
|
||||
const std::map<std::string, std::string> & chat_template_kwargs);
|
||||
std::string common_chat_format_example(const struct common_chat_templates * tmpls,
|
||||
bool use_jinja,
|
||||
const std::map<std::string, std::string> & chat_template_kwargs);
|
||||
|
||||
const char* common_chat_format_name(common_chat_format format);
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
|
||||
const char * common_chat_format_name(common_chat_format format);
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & params);
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, const std::string & input, bool is_partial, const common_chat_parser_params & params);
|
||||
|
||||
// used by arg and server
|
||||
const char * common_reasoning_format_name(common_reasoning_format format);
|
||||
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
|
||||
const char * common_reasoning_format_name(common_reasoning_format format);
|
||||
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
|
||||
|
||||
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
|
||||
|
||||
|
|
@ -250,3 +296,10 @@ nlohmann::ordered_json common_chat_msg_diff_to_json_oaicompat(const common_chat_
|
|||
|
||||
// get template caps, useful for reporting to server /props endpoint
|
||||
std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates);
|
||||
|
||||
std::string common_chat_template_direct_apply(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs,
|
||||
const std::optional<json> & messages_override = std::nullopt,
|
||||
const std::optional<json> & tools_override = std::nullopt,
|
||||
const std::optional<json> & additional_context = std::nullopt);
|
||||
|
|
|
|||
|
|
@ -676,7 +676,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|
|||
|
||||
size_t offset = 0;
|
||||
while (offset < filename.size()) {
|
||||
utf8_parse_result result = parse_utf8_codepoint(filename, offset);
|
||||
utf8_parse_result result = common_parse_utf8_codepoint(filename, offset);
|
||||
|
||||
if (result.status != utf8_parse_result::SUCCESS) {
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -516,14 +516,15 @@ struct common_params {
|
|||
std::string cls_sep = "\t"; // separator of classification sequences
|
||||
|
||||
// server params
|
||||
int32_t port = 8080; // server listens on this network port
|
||||
int32_t timeout_read = 600; // http read timeout in seconds
|
||||
int32_t timeout_write = timeout_read; // http write timeout in seconds
|
||||
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
|
||||
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
|
||||
bool cache_prompt = true; // whether to enable prompt caching
|
||||
int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
int32_t port = 8080; // server listens on this network port
|
||||
int32_t timeout_read = 600; // http read timeout in seconds
|
||||
int32_t timeout_write = timeout_read; // http write timeout in seconds
|
||||
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
|
||||
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
|
||||
bool cache_prompt = true; // whether to enable prompt caching
|
||||
int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
|
||||
int32_t checkpoint_every_nt = 8192; // make a checkpoint every n tokens during prefill
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
|
||||
std::string hostname = "127.0.0.1";
|
||||
std::string public_path = ""; // NOLINT
|
||||
|
|
@ -545,6 +546,7 @@ struct common_params {
|
|||
|
||||
// webui configs
|
||||
bool webui = true;
|
||||
bool webui_mcp_proxy = false;
|
||||
std::string webui_config_json;
|
||||
|
||||
// "advanced" endpoints are disabled by default for better security
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
#include "log.h"
|
||||
#include "value.h"
|
||||
#include "runtime.h"
|
||||
#include "caps.h"
|
||||
|
|
@ -36,12 +37,16 @@ static void caps_try_execute(jinja::program & prog,
|
|||
auto tools = ctx.get_val("tools");
|
||||
|
||||
bool success = false;
|
||||
std::string result;
|
||||
try {
|
||||
jinja::runtime runtime(ctx);
|
||||
runtime.execute(prog);
|
||||
auto results = runtime.execute(prog);
|
||||
auto parts = jinja::runtime::gather_string_parts(results);
|
||||
result = parts->as_string().str();
|
||||
success = true;
|
||||
} catch (const std::exception & e) {
|
||||
JJ_DEBUG("Exception during execution: %s", e.what());
|
||||
result = "";
|
||||
// ignore exceptions during capability analysis
|
||||
}
|
||||
|
||||
|
|
@ -90,6 +95,8 @@ caps caps_get(jinja::program & prog) {
|
|||
return v->stats.ops.find(op_name) != v->stats.ops.end();
|
||||
};
|
||||
|
||||
JJ_DEBUG("%s\n", ">>> Running capability check: typed content");
|
||||
|
||||
// case: typed content support
|
||||
caps_try_execute(
|
||||
prog,
|
||||
|
|
@ -120,6 +127,7 @@ caps caps_get(jinja::program & prog) {
|
|||
}
|
||||
);
|
||||
|
||||
JJ_DEBUG("%s\n", ">>> Running capability check: system prompt");
|
||||
|
||||
// case: system prompt support
|
||||
caps_try_execute(
|
||||
|
|
@ -150,7 +158,9 @@ caps caps_get(jinja::program & prog) {
|
|||
}
|
||||
);
|
||||
|
||||
// case: tools support
|
||||
JJ_DEBUG("%s\n", ">>> Running capability check: single tool support");
|
||||
|
||||
// case: tools support: single call
|
||||
caps_try_execute(
|
||||
prog,
|
||||
[&]() {
|
||||
|
|
@ -162,10 +172,10 @@ caps caps_get(jinja::program & prog) {
|
|||
},
|
||||
{
|
||||
{"role", "assistant"},
|
||||
{"content", "Assistant message"},
|
||||
{"content", ""}, // Some templates expect content to be empty with tool calls
|
||||
{"tool_calls", json::array({
|
||||
{
|
||||
{"id", "call1"},
|
||||
{"id", "call00001"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "tool1"},
|
||||
|
|
@ -173,19 +183,18 @@ caps caps_get(jinja::program & prog) {
|
|||
{"arg", "value"}
|
||||
}}
|
||||
}}
|
||||
},
|
||||
{
|
||||
{"id", "call2"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "tool2"},
|
||||
{"arguments", {
|
||||
{"arg", "value"}
|
||||
}}
|
||||
}}
|
||||
}
|
||||
})}
|
||||
},
|
||||
{
|
||||
{"role", "tool"},
|
||||
{"content", "Tool response"},
|
||||
{"tool_call_id", "call00001"}
|
||||
},
|
||||
{
|
||||
{"role", "assistant"},
|
||||
{"content", "The tool response was 'tool response'"}
|
||||
},
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "User message"},
|
||||
|
|
@ -199,7 +208,7 @@ caps caps_get(jinja::program & prog) {
|
|||
{"name", "tool"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "tool"},
|
||||
{"name", "tool1"},
|
||||
{"description", "Tool description"},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
|
|
@ -224,6 +233,7 @@ caps caps_get(jinja::program & prog) {
|
|||
|
||||
auto & tool_name = tools->at(0)->at("function")->at("name");
|
||||
caps_print_stats(tool_name, "tools[0].function.name");
|
||||
caps_print_stats(tools, "tools");
|
||||
if (!tool_name->stats.used) {
|
||||
result.supports_tools = false;
|
||||
}
|
||||
|
|
@ -233,6 +243,93 @@ caps caps_get(jinja::program & prog) {
|
|||
if (!tool_calls->stats.used) {
|
||||
result.supports_tool_calls = false;
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
JJ_DEBUG("%s\n", ">>> Running capability check: parallel tool support");
|
||||
|
||||
// case: tools support: parallel calls
|
||||
caps_try_execute(
|
||||
prog,
|
||||
[&]() {
|
||||
// messages
|
||||
return json::array({
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "User message"},
|
||||
},
|
||||
{
|
||||
{"role", "assistant"},
|
||||
{"content", ""}, // Some templates expect content to be empty with tool calls
|
||||
{"tool_calls", json::array({
|
||||
{
|
||||
{"id", "call00001"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "tool1"},
|
||||
{"arguments", {
|
||||
{"arg", "value"}
|
||||
}}
|
||||
}}
|
||||
},
|
||||
{
|
||||
{"id", "call00002"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "tool1"},
|
||||
{"arguments", {
|
||||
{"arg", "value"}
|
||||
}}
|
||||
}}
|
||||
}
|
||||
})}
|
||||
},
|
||||
{
|
||||
{"role", "tool"},
|
||||
{"content", "Tool response"},
|
||||
{"tool_call_id", "call00001"}
|
||||
},
|
||||
{
|
||||
{"role", "assistant"},
|
||||
{"content", "The tool response was 'tool response'"}
|
||||
},
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "User message"},
|
||||
},
|
||||
});
|
||||
},
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array({
|
||||
{
|
||||
{"name", "tool"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "tool1"},
|
||||
{"description", "Tool description"},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"arg", {
|
||||
{"type", "string"},
|
||||
{"description", "Arg description"},
|
||||
}},
|
||||
}},
|
||||
{"required", json::array({ "arg" })},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
});
|
||||
},
|
||||
[&](bool success, value & messages, value & /*tools*/) {
|
||||
if (!success) {
|
||||
result.supports_parallel_tool_calls = false;
|
||||
return;
|
||||
}
|
||||
|
||||
auto & tool_calls = messages->at(1)->at("tool_calls");;
|
||||
caps_print_stats(tool_calls, "messages[1].tool_calls");
|
||||
|
||||
// check for second tool call usage
|
||||
auto & tool_call_1 = tool_calls->at(1)->at("function");
|
||||
|
|
@ -243,6 +340,8 @@ caps caps_get(jinja::program & prog) {
|
|||
}
|
||||
);
|
||||
|
||||
JJ_DEBUG("%s\n", ">>> Running capability check: preserve reasoning");
|
||||
|
||||
// case: preserve reasoning content in chat history
|
||||
caps_try_execute(
|
||||
prog,
|
||||
|
|
|
|||
|
|
@ -114,8 +114,10 @@ value binary_expression::execute_impl(context & ctx) {
|
|||
|
||||
// Logical operators
|
||||
if (op.value == "and") {
|
||||
JJ_DEBUG("Executing logical test: %s AND %s", left->type().c_str(), right->type().c_str());
|
||||
return left_val->as_bool() ? right->execute(ctx) : std::move(left_val);
|
||||
} else if (op.value == "or") {
|
||||
JJ_DEBUG("Executing logical test: %s OR %s", left->type().c_str(), right->type().c_str());
|
||||
return left_val->as_bool() ? std::move(left_val) : right->execute(ctx);
|
||||
}
|
||||
|
||||
|
|
@ -838,7 +840,7 @@ value call_expression::execute_impl(context & ctx) {
|
|||
for (auto & arg_stmt : this->args) {
|
||||
auto arg_val = arg_stmt->execute(ctx);
|
||||
JJ_DEBUG(" Argument type: %s", arg_val->type().c_str());
|
||||
args.push_back(std::move(arg_val));
|
||||
args.push_back(arg_val);
|
||||
}
|
||||
// execute callee
|
||||
value callee_val = callee->execute(ctx);
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@
|
|||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace jinja {
|
||||
|
||||
|
|
|
|||
|
|
@ -27,11 +27,11 @@ static std::string build_repetition(const std::string & item_rule, int min_items
|
|||
if (separator_rule.empty()) {
|
||||
if (min_items == 1 && !has_max) {
|
||||
return item_rule + "+";
|
||||
} else if (min_items == 0 && !has_max) {
|
||||
return item_rule + "*";
|
||||
} else {
|
||||
return item_rule + "{" + std::to_string(min_items) + "," + (has_max ? std::to_string(max_items) : "") + "}";
|
||||
}
|
||||
if (min_items == 0 && !has_max) {
|
||||
return item_rule + "*";
|
||||
}
|
||||
return item_rule + "{" + std::to_string(min_items) + "," + (has_max ? std::to_string(max_items) : "") + "}";
|
||||
}
|
||||
|
||||
auto result = item_rule + " " + build_repetition("(" + separator_rule + " " + item_rule + ")", min_items == 0 ? 0 : min_items - 1, has_max ? max_items - 1 : max_items);
|
||||
|
|
@ -41,7 +41,7 @@ static std::string build_repetition(const std::string & item_rule, int min_items
|
|||
return result;
|
||||
}
|
||||
|
||||
static void _build_min_max_int(int64_t min_value, int64_t max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
|
||||
static void build_min_max_int(int64_t min_value, int64_t max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
|
||||
auto has_min = min_value != std::numeric_limits<int64_t>::min();
|
||||
auto has_max = max_value != std::numeric_limits<int64_t>::max();
|
||||
|
||||
|
|
@ -128,14 +128,14 @@ static void _build_min_max_int(int64_t min_value, int64_t max_value, std::string
|
|||
if (has_min && has_max) {
|
||||
if (min_value < 0 && max_value < 0) {
|
||||
out << "\"-\" (";
|
||||
_build_min_max_int(-max_value, -min_value, out, decimals_left, /* top_level= */ true);
|
||||
build_min_max_int(-max_value, -min_value, out, decimals_left, /* top_level= */ true);
|
||||
out << ")";
|
||||
return;
|
||||
}
|
||||
|
||||
if (min_value < 0) {
|
||||
out << "\"-\" (";
|
||||
_build_min_max_int(0, -min_value, out, decimals_left, /* top_level= */ true);
|
||||
build_min_max_int(0, -min_value, out, decimals_left, /* top_level= */ true);
|
||||
out << ") | ";
|
||||
min_value = 0;
|
||||
}
|
||||
|
|
@ -159,7 +159,7 @@ static void _build_min_max_int(int64_t min_value, int64_t max_value, std::string
|
|||
if (has_min) {
|
||||
if (min_value < 0) {
|
||||
out << "\"-\" (";
|
||||
_build_min_max_int(std::numeric_limits<int64_t>::min(), -min_value, out, decimals_left, /* top_level= */ false);
|
||||
build_min_max_int(std::numeric_limits<int64_t>::min(), -min_value, out, decimals_left, /* top_level= */ false);
|
||||
out << ") | [0] | [1-9] ";
|
||||
more_digits(0, decimals_left - 1);
|
||||
} else if (min_value == 0) {
|
||||
|
|
@ -194,7 +194,7 @@ static void _build_min_max_int(int64_t min_value, int64_t max_value, std::string
|
|||
}
|
||||
digit_range(c, c);
|
||||
out << " (";
|
||||
_build_min_max_int(std::stoll(min_s.substr(1)), std::numeric_limits<int64_t>::max(), out, less_decimals, /* top_level= */ false);
|
||||
build_min_max_int(std::stoll(min_s.substr(1)), std::numeric_limits<int64_t>::max(), out, less_decimals, /* top_level= */ false);
|
||||
out << ")";
|
||||
if (c < '9') {
|
||||
out << " | ";
|
||||
|
|
@ -213,10 +213,10 @@ static void _build_min_max_int(int64_t min_value, int64_t max_value, std::string
|
|||
more_digits(0, less_decimals);
|
||||
out << " | ";
|
||||
}
|
||||
_build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
|
||||
build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
|
||||
} else {
|
||||
out << "\"-\" (";
|
||||
_build_min_max_int(-max_value, std::numeric_limits<int64_t>::max(), out, decimals_left, /* top_level= */ false);
|
||||
build_min_max_int(-max_value, std::numeric_limits<int64_t>::max(), out, decimals_left, /* top_level= */ false);
|
||||
out << ")";
|
||||
}
|
||||
return;
|
||||
|
|
@ -232,7 +232,7 @@ struct BuiltinRule {
|
|||
std::vector<std::string> deps;
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
|
||||
static std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
|
||||
{"boolean", {"(\"true\" | \"false\") space", {}}},
|
||||
{"decimal-part", {"[0-9]{1,16}", {}}},
|
||||
{"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}},
|
||||
|
|
@ -247,7 +247,7 @@ std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
|
|||
{"null", {"\"null\" space", {}}},
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, BuiltinRule> STRING_FORMAT_RULES = {
|
||||
static std::unordered_map<std::string, BuiltinRule> STRING_FORMAT_RULES = {
|
||||
{"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}},
|
||||
{"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}},
|
||||
{"date-time", {"date \"T\" time", {"date", "time"}}},
|
||||
|
|
@ -260,22 +260,26 @@ static bool is_reserved_name(const std::string & name) {
|
|||
static const std::unordered_set<std::string> RESERVED_NAMES = [] {
|
||||
std::unordered_set<std::string> s;
|
||||
s.insert("root");
|
||||
for (const auto & p : PRIMITIVE_RULES) s.insert(p.first);
|
||||
for (const auto & p : STRING_FORMAT_RULES) s.insert(p.first);
|
||||
for (const auto & p : PRIMITIVE_RULES) {
|
||||
s.insert(p.first);
|
||||
}
|
||||
for (const auto & p : STRING_FORMAT_RULES) {
|
||||
s.insert(p.first);
|
||||
}
|
||||
return s;
|
||||
}();
|
||||
return RESERVED_NAMES.find(name) != RESERVED_NAMES.end();
|
||||
}
|
||||
|
||||
std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
|
||||
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]");
|
||||
std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
|
||||
std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
|
||||
static std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
|
||||
static std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]");
|
||||
static std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
|
||||
static std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
|
||||
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}, {'\\', "\\\\"}
|
||||
};
|
||||
|
||||
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
|
||||
std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
|
||||
static std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
|
||||
static std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
|
||||
|
||||
static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
|
||||
std::smatch match;
|
||||
|
|
@ -322,19 +326,19 @@ private:
|
|||
if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) {
|
||||
_rules[esc_name] = rule;
|
||||
return esc_name;
|
||||
} else {
|
||||
int i = 0;
|
||||
while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) {
|
||||
i++;
|
||||
}
|
||||
std::string key = esc_name + std::to_string(i);
|
||||
_rules[key] = rule;
|
||||
return key;
|
||||
}
|
||||
int i = 0;
|
||||
while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) {
|
||||
i++;
|
||||
}
|
||||
std::string key = esc_name + std::to_string(i);
|
||||
_rules[key] = rule;
|
||||
return key;
|
||||
}
|
||||
|
||||
std::string _generate_union_rule(const std::string & name, const std::vector<json> & alt_schemas) {
|
||||
std::vector<std::string> rules;
|
||||
rules.reserve(alt_schemas.size());
|
||||
for (size_t i = 0; i < alt_schemas.size(); i++) {
|
||||
rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
|
||||
}
|
||||
|
|
@ -398,6 +402,7 @@ private:
|
|||
flush_literal();
|
||||
|
||||
std::vector<std::string> results;
|
||||
results.reserve(ret.size());
|
||||
for (const auto & item : ret) {
|
||||
results.push_back(to_rule(item));
|
||||
}
|
||||
|
|
@ -551,7 +556,7 @@ private:
|
|||
TrieNode() : is_end_of_string(false) {}
|
||||
|
||||
void insert(const std::string & string) {
|
||||
auto node = this;
|
||||
auto *node = this;
|
||||
for (char c : string) {
|
||||
node = &node->children[c];
|
||||
}
|
||||
|
|
@ -676,7 +681,7 @@ private:
|
|||
if (ks.empty()) {
|
||||
return res;
|
||||
}
|
||||
std::string k = ks[0];
|
||||
const std::string& k = ks[0];
|
||||
std::string kv_rule_name = prop_kv_rule_names[k];
|
||||
std::string comma_ref = "( \",\" space " + kv_rule_name + " )";
|
||||
if (first_is_optional) {
|
||||
|
|
@ -779,7 +784,7 @@ public:
|
|||
std::string pointer = ref.substr(ref.find('#') + 1);
|
||||
std::vector<std::string> tokens = string_split(pointer, "/");
|
||||
for (size_t i = 1; i < tokens.size(); ++i) {
|
||||
std::string sel = tokens[i];
|
||||
const std::string& sel = tokens[i];
|
||||
if (target.is_object() && target.contains(sel)) {
|
||||
target = target[sel];
|
||||
} else if (target.is_array()) {
|
||||
|
|
@ -802,7 +807,7 @@ public:
|
|||
_refs[ref] = target;
|
||||
}
|
||||
} else {
|
||||
for (auto & kv : n.items()) {
|
||||
for (const auto & kv : n.items()) {
|
||||
visit_refs(kv.value());
|
||||
}
|
||||
}
|
||||
|
|
@ -812,7 +817,7 @@ public:
|
|||
visit_refs(schema);
|
||||
}
|
||||
|
||||
std::string _generate_constant_rule(const json & value) {
|
||||
static std::string _generate_constant_rule(const json & value) {
|
||||
return format_literal(value.dump());
|
||||
}
|
||||
|
||||
|
|
@ -823,10 +828,12 @@ public:
|
|||
|
||||
if (schema.contains("$ref")) {
|
||||
return _add_rule(rule_name, _resolve_ref(schema["$ref"]));
|
||||
} else if (schema.contains("oneOf") || schema.contains("anyOf")) {
|
||||
}
|
||||
if (schema.contains("oneOf") || schema.contains("anyOf")) {
|
||||
std::vector<json> alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get<std::vector<json>>() : schema["anyOf"].get<std::vector<json>>();
|
||||
return _add_rule(rule_name, _generate_union_rule(name, alt_schemas));
|
||||
} else if (schema_type.is_array()) {
|
||||
}
|
||||
if (schema_type.is_array()) {
|
||||
std::vector<json> schema_types;
|
||||
for (const auto & t : schema_type) {
|
||||
json schema_copy(schema);
|
||||
|
|
@ -834,15 +841,18 @@ public:
|
|||
schema_types.push_back(schema_copy);
|
||||
}
|
||||
return _add_rule(rule_name, _generate_union_rule(name, schema_types));
|
||||
} else if (schema.contains("const")) {
|
||||
}
|
||||
if (schema.contains("const")) {
|
||||
return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space");
|
||||
} else if (schema.contains("enum")) {
|
||||
}
|
||||
if (schema.contains("enum")) {
|
||||
std::vector<std::string> enum_values;
|
||||
for (const auto & v : schema["enum"]) {
|
||||
enum_values.push_back(_generate_constant_rule(v));
|
||||
}
|
||||
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
|
||||
} else if ((schema_type.is_null() || schema_type == "object")
|
||||
}
|
||||
if ((schema_type.is_null() || schema_type == "object")
|
||||
&& (schema.contains("properties") ||
|
||||
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
|
||||
std::unordered_set<std::string> required;
|
||||
|
|
@ -863,11 +873,12 @@ public:
|
|||
_build_object_rule(
|
||||
properties, required, name,
|
||||
schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
|
||||
} else if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) {
|
||||
}
|
||||
if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) {
|
||||
std::unordered_set<std::string> required;
|
||||
std::vector<std::pair<std::string, json>> properties;
|
||||
std::map<std::string, size_t> enum_values;
|
||||
std::string hybrid_name = name;
|
||||
const std::string& hybrid_name = name;
|
||||
std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
|
||||
if (comp_schema.contains("$ref")) {
|
||||
add_component(_refs[comp_schema["$ref"]], is_required);
|
||||
|
|
@ -890,9 +901,9 @@ public:
|
|||
// todo warning
|
||||
}
|
||||
};
|
||||
for (auto & t : schema["allOf"]) {
|
||||
for (const auto & t : schema["allOf"]) {
|
||||
if (t.contains("anyOf")) {
|
||||
for (auto & tt : t["anyOf"]) {
|
||||
for (const auto & tt : t["anyOf"]) {
|
||||
add_component(tt, false);
|
||||
}
|
||||
} else {
|
||||
|
|
@ -911,7 +922,8 @@ public:
|
|||
}
|
||||
}
|
||||
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
|
||||
} else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
|
||||
}
|
||||
if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
|
||||
json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
|
||||
if (items.is_array()) {
|
||||
std::string rule = "\"[\" space ";
|
||||
|
|
@ -923,27 +935,31 @@ public:
|
|||
}
|
||||
rule += " \"]\" space";
|
||||
return _add_rule(rule_name, rule);
|
||||
} else {
|
||||
std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
|
||||
int min_items = schema.contains("minItems") ? schema["minItems"].get<int>() : 0;
|
||||
json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
|
||||
int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : std::numeric_limits<int>::max();
|
||||
|
||||
return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space");
|
||||
}
|
||||
} else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
|
||||
std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
|
||||
int min_items = schema.contains("minItems") ? schema["minItems"].get<int>() : 0;
|
||||
json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
|
||||
int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : std::numeric_limits<int>::max();
|
||||
|
||||
return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space");
|
||||
}
|
||||
if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
|
||||
return _visit_pattern(schema["pattern"], rule_name);
|
||||
} else if ((schema_type.is_null() || schema_type == "string") && std::regex_match(schema_format, std::regex("^uuid[1-5]?$"))) {
|
||||
}
|
||||
if ((schema_type.is_null() || schema_type == "string") && std::regex_match(schema_format, std::regex("^uuid[1-5]?$"))) {
|
||||
return _add_primitive(rule_name == "root" ? "root" : schema_format, PRIMITIVE_RULES.at("uuid"));
|
||||
} else if ((schema_type.is_null() || schema_type == "string") && STRING_FORMAT_RULES.find(schema_format + "-string") != STRING_FORMAT_RULES.end()) {
|
||||
}
|
||||
if ((schema_type.is_null() || schema_type == "string") && STRING_FORMAT_RULES.find(schema_format + "-string") != STRING_FORMAT_RULES.end()) {
|
||||
auto prim_name = schema_format + "-string";
|
||||
return _add_rule(rule_name, _add_primitive(prim_name, STRING_FORMAT_RULES.at(prim_name)));
|
||||
} else if (schema_type == "string" && (schema.contains("minLength") || schema.contains("maxLength"))) {
|
||||
}
|
||||
if (schema_type == "string" && (schema.contains("minLength") || schema.contains("maxLength"))) {
|
||||
std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
|
||||
int min_len = schema.contains("minLength") ? schema["minLength"].get<int>() : 0;
|
||||
int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
|
||||
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
|
||||
} else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
|
||||
}
|
||||
if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
|
||||
int64_t min_value = std::numeric_limits<int64_t>::min();
|
||||
int64_t max_value = std::numeric_limits<int64_t>::max();
|
||||
if (schema.contains("minimum")) {
|
||||
|
|
@ -958,19 +974,24 @@ public:
|
|||
}
|
||||
std::stringstream out;
|
||||
out << "(";
|
||||
_build_min_max_int(min_value, max_value, out);
|
||||
build_min_max_int(min_value, max_value, out);
|
||||
out << ") space";
|
||||
return _add_rule(rule_name, out.str());
|
||||
} else if (schema.empty() || schema_type == "object") {
|
||||
return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object")));
|
||||
} else {
|
||||
if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get<std::string>()) == PRIMITIVE_RULES.end()) {
|
||||
_errors.push_back("Unrecognized schema: " + schema.dump());
|
||||
return "";
|
||||
}
|
||||
// TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
|
||||
return _add_primitive(rule_name == "root" ? "root" : schema_type.get<std::string>(), PRIMITIVE_RULES.at(schema_type.get<std::string>()));
|
||||
}
|
||||
if (schema.empty() || schema_type == "object") {
|
||||
return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object")));
|
||||
}
|
||||
if (schema_type.is_null() && schema.is_object()) {
|
||||
// No type constraint and no recognized structural keywords (e.g. {"description": "..."}).
|
||||
// Per JSON Schema semantics this is equivalent to {} and accepts any value.
|
||||
return _add_rule(rule_name, _add_primitive("value", PRIMITIVE_RULES.at("value")));
|
||||
}
|
||||
if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get<std::string>()) == PRIMITIVE_RULES.end()) {
|
||||
_errors.push_back("Unrecognized schema: " + schema.dump());
|
||||
return "";
|
||||
}
|
||||
// TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
|
||||
return _add_primitive(rule_name == "root" ? "root" : schema_type.get<std::string>(), PRIMITIVE_RULES.at(schema_type.get<std::string>()));
|
||||
}
|
||||
|
||||
void check_errors() {
|
||||
|
|
@ -985,7 +1006,7 @@ public:
|
|||
std::string format_grammar() {
|
||||
std::stringstream ss;
|
||||
for (const auto & kv : _rules) {
|
||||
ss << kv.first << " ::= " << kv.second << std::endl;
|
||||
ss << kv.first << " ::= " << kv.second << '\n';
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,14 +1,15 @@
|
|||
#include "common.h"
|
||||
#include "peg-parser.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "unicode.h"
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
#include "common.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "log.h"
|
||||
#include "unicode.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <initializer_list>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <regex>
|
||||
#include <stdexcept>
|
||||
#include <unordered_set>
|
||||
|
|
@ -34,8 +35,7 @@ static bool is_hex_digit(const char c) {
|
|||
// This is used in common_peg_until_parser and to build a GBNF exclusion grammar
|
||||
struct trie {
|
||||
struct node {
|
||||
size_t depth = 0;
|
||||
std::map<unsigned char, size_t> children;
|
||||
std::map<uint32_t, size_t> children; // Use uint32_t to store Unicode codepoints
|
||||
bool is_word;
|
||||
};
|
||||
|
||||
|
|
@ -55,15 +55,22 @@ struct trie {
|
|||
size_t current = 0; // Start at root
|
||||
size_t pos = start_pos;
|
||||
|
||||
// LOG_DBG("%s: checking at pos %zu, sv='%s'\n", __func__, start_pos, std::string(sv).c_str());
|
||||
|
||||
while (pos < sv.size()) {
|
||||
auto it = nodes[current].children.find(sv[pos]);
|
||||
auto result = common_parse_utf8_codepoint(sv, pos);
|
||||
if (result.status != utf8_parse_result::SUCCESS) {
|
||||
break;
|
||||
}
|
||||
|
||||
auto it = nodes[current].children.find(result.codepoint);
|
||||
if (it == nodes[current].children.end()) {
|
||||
// Can't continue matching
|
||||
return match_result{match_result::NO_MATCH};
|
||||
}
|
||||
|
||||
current = it->second;
|
||||
pos++;
|
||||
pos += result.bytes_consumed;
|
||||
|
||||
// Check if we've matched a complete word
|
||||
if (nodes[current].is_word) {
|
||||
|
|
@ -82,22 +89,22 @@ struct trie {
|
|||
}
|
||||
|
||||
struct prefix_and_next {
|
||||
std::string prefix;
|
||||
std::string next_chars;
|
||||
std::vector<uint32_t> prefix;
|
||||
std::vector<uint32_t> next_chars;
|
||||
};
|
||||
|
||||
std::vector<prefix_and_next> collect_prefix_and_next() {
|
||||
std::string prefix;
|
||||
std::vector<uint32_t> prefix;
|
||||
std::vector<prefix_and_next> result;
|
||||
collect_prefix_and_next(0, prefix, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
void collect_prefix_and_next(size_t index, std::string & prefix, std::vector<prefix_and_next> & out) {
|
||||
void collect_prefix_and_next(size_t index, std::vector<uint32_t> & prefix, std::vector<prefix_and_next> & out) {
|
||||
if (!nodes[index].is_word) {
|
||||
if (!nodes[index].children.empty()) {
|
||||
std::string chars;
|
||||
std::vector<uint32_t> chars;
|
||||
chars.reserve(nodes[index].children.size());
|
||||
for (const auto & p : nodes[index].children) {
|
||||
chars.push_back(p.first);
|
||||
|
|
@ -107,7 +114,7 @@ struct trie {
|
|||
}
|
||||
|
||||
for (const auto & p : nodes[index].children) {
|
||||
unsigned char ch = p.first;
|
||||
uint32_t ch = p.first;
|
||||
auto child = p.second;
|
||||
prefix.push_back(ch);
|
||||
collect_prefix_and_next(child, prefix, out);
|
||||
|
|
@ -123,11 +130,19 @@ struct trie {
|
|||
|
||||
void insert(const std::string & word) {
|
||||
size_t current = 0;
|
||||
for (unsigned char ch : word) {
|
||||
size_t pos = 0;
|
||||
while (pos < word.length()) {
|
||||
auto result = common_parse_utf8_codepoint(word, pos);
|
||||
if (result.status != utf8_parse_result::SUCCESS) {
|
||||
break;
|
||||
}
|
||||
|
||||
uint32_t ch = result.codepoint;
|
||||
pos += result.bytes_consumed;
|
||||
|
||||
auto it = nodes[current].children.find(ch);
|
||||
if (it == nodes[current].children.end()) {
|
||||
size_t child = create_node();
|
||||
nodes[child].depth = nodes[current].depth + 1;
|
||||
nodes[current].children[ch] = child;
|
||||
current = child;
|
||||
} else {
|
||||
|
|
@ -286,6 +301,32 @@ struct parser_executor {
|
|||
parser_executor(const common_peg_arena & arena, common_peg_parse_context & ctx, size_t start)
|
||||
: arena(arena), ctx(ctx), start_pos(start) {}
|
||||
|
||||
std::string debug_indent() const { return std::string(ctx.parse_depth * 2, ' '); }
|
||||
|
||||
std::string debug_input_snippet(size_t pos, size_t len = 60) const {
|
||||
if (pos >= ctx.input.size()) {
|
||||
return "<EOF>";
|
||||
}
|
||||
auto snippet = ctx.input.substr(pos, len);
|
||||
// Escape newlines for display
|
||||
std::string result;
|
||||
for (char c : snippet) {
|
||||
if (c == '\n') {
|
||||
result += "\\n";
|
||||
} else if (c == '\r') {
|
||||
result += "\\r";
|
||||
} else if (c == '\t') {
|
||||
result += "\\t";
|
||||
} else {
|
||||
result += c;
|
||||
}
|
||||
}
|
||||
if (pos + len < ctx.input.size()) {
|
||||
result += "...";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_epsilon_parser & /* p */) const {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos);
|
||||
}
|
||||
|
|
@ -323,12 +364,39 @@ struct parser_executor {
|
|||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_sequence_parser & p) {
|
||||
if (ctx.debug) {
|
||||
LOG_DBG("%sSEQ start at %zu '%s' (%zu children)\n", debug_indent().c_str(), start_pos,
|
||||
debug_input_snippet(start_pos).c_str(), p.children.size());
|
||||
}
|
||||
ctx.parse_depth++;
|
||||
|
||||
auto pos = start_pos;
|
||||
std::vector<common_peg_ast_id> nodes;
|
||||
|
||||
for (const auto & child_id : p.children) {
|
||||
for (size_t i = 0; i < p.children.size(); i++) {
|
||||
const auto & child_id = p.children[i];
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sSEQ child %zu: %s\n", debug_indent().c_str(), i, arena.dump(child_id).c_str());
|
||||
}
|
||||
auto result = arena.parse(child_id, ctx, pos);
|
||||
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sSEQ child %zu: %s at %zu->%zu\n", debug_indent().c_str(), i,
|
||||
common_peg_parse_result_type_name(result.type), result.start, result.end);
|
||||
}
|
||||
|
||||
if (result.fail()) {
|
||||
ctx.parse_depth--;
|
||||
if (ctx.is_partial && result.end >= ctx.input.size()) {
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sSEQ -> NEED_MORE (child failed at end)\n", debug_indent().c_str());
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end,
|
||||
std::move(nodes));
|
||||
}
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sSEQ -> FAIL\n", debug_indent().c_str());
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, result.end);
|
||||
}
|
||||
|
||||
|
|
@ -337,28 +405,65 @@ struct parser_executor {
|
|||
}
|
||||
|
||||
if (result.need_more_input()) {
|
||||
ctx.parse_depth--;
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sSEQ -> NEED_MORE\n", debug_indent().c_str());
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end, std::move(nodes));
|
||||
}
|
||||
|
||||
pos = result.end;
|
||||
}
|
||||
|
||||
ctx.parse_depth--;
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sSEQ -> SUCCESS at %zu->%zu\n", debug_indent().c_str(), start_pos, pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos, std::move(nodes));
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_choice_parser & p) {
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sCHOICE start at %zu '%s' (%zu options)\n", debug_indent().c_str(), start_pos,
|
||||
debug_input_snippet(start_pos).c_str(), p.children.size());
|
||||
}
|
||||
ctx.parse_depth++;
|
||||
|
||||
auto pos = start_pos;
|
||||
for (const auto & child_id : p.children) {
|
||||
for (size_t i = 0; i < p.children.size(); i++) {
|
||||
const auto & child_id = p.children[i];
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sCHOICE option %zu: %s\n", debug_indent().c_str(), i, arena.dump(child_id).c_str());
|
||||
}
|
||||
auto result = arena.parse(child_id, ctx, pos);
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sCHOICE option %zu: %s\n", debug_indent().c_str(), i,
|
||||
common_peg_parse_result_type_name(result.type));
|
||||
}
|
||||
if (!result.fail()) {
|
||||
ctx.parse_depth--;
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sCHOICE -> %s (option %zu)\n", debug_indent().c_str(),
|
||||
common_peg_parse_result_type_name(result.type), i);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
ctx.parse_depth--;
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sCHOICE -> FAIL (no options matched)\n", debug_indent().c_str());
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_repetition_parser & p) {
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sREPEAT start at %zu '%s' (min=%d, max=%d)\n", debug_indent().c_str(), start_pos,
|
||||
debug_input_snippet(start_pos).c_str(), p.min_count, p.max_count);
|
||||
}
|
||||
ctx.parse_depth++;
|
||||
|
||||
auto pos = start_pos;
|
||||
int match_count = 0;
|
||||
std::vector<common_peg_ast_id> nodes;
|
||||
|
|
@ -366,14 +471,26 @@ struct parser_executor {
|
|||
// Try to match up to max_count times (or unlimited if max_count is -1)
|
||||
while (p.max_count == -1 || match_count < p.max_count) {
|
||||
if (pos >= ctx.input.size()) {
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sREPEAT: at end of input, count=%d\n", debug_indent().c_str(), match_count);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
auto result = arena.parse(p.child, ctx, pos);
|
||||
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sREPEAT iter %d: %s at %zu->%zu, nodes=%zu\n", debug_indent().c_str(), match_count,
|
||||
common_peg_parse_result_type_name(result.type), result.start, result.end, result.nodes.size());
|
||||
fprintf(stderr, "%sREPEAT CHILD: %s\n", debug_indent().c_str(), arena.dump(p.child).c_str());
|
||||
}
|
||||
|
||||
if (result.success()) {
|
||||
// Prevent infinite loop on empty matches
|
||||
if (result.end == pos) {
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%s REPEAT: empty match, stopping\n", debug_indent().c_str());
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
@ -391,21 +508,43 @@ struct parser_executor {
|
|||
nodes.insert(nodes.end(), result.nodes.begin(), result.nodes.end());
|
||||
}
|
||||
|
||||
ctx.parse_depth--;
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sREPEAT -> NEED_MORE (count=%d, nodes=%zu)\n", debug_indent().c_str(),
|
||||
match_count, nodes.size());
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end, std::move(nodes));
|
||||
}
|
||||
|
||||
// Child failed - stop trying
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sREPEAT: child failed, stopping\n", debug_indent().c_str());
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// Check if we got enough matches
|
||||
if (p.min_count > 0 && match_count < p.min_count) {
|
||||
ctx.parse_depth--;
|
||||
if (pos >= ctx.input.size() && ctx.is_partial) {
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sREPEAT -> NEED_MORE (not enough matches: %d < %d)\n", debug_indent().c_str(),
|
||||
match_count, p.min_count);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos, std::move(nodes));
|
||||
}
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sREPEAT -> FAIL (not enough matches: %d < %d)\n", debug_indent().c_str(), match_count,
|
||||
p.min_count);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos);
|
||||
}
|
||||
|
||||
ctx.parse_depth--;
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sREPEAT -> SUCCESS (count=%d, nodes=%zu)\n", debug_indent().c_str(), match_count,
|
||||
nodes.size());
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos, std::move(nodes));
|
||||
}
|
||||
|
||||
|
|
@ -434,7 +573,7 @@ struct parser_executor {
|
|||
|
||||
common_peg_parse_result operator()(const common_peg_any_parser & /* p */) const {
|
||||
// Parse a single UTF-8 codepoint (not just a single byte)
|
||||
auto result = parse_utf8_codepoint(ctx.input, start_pos);
|
||||
auto result = common_parse_utf8_codepoint(ctx.input, start_pos);
|
||||
|
||||
if (result.status == utf8_parse_result::INCOMPLETE) {
|
||||
if (!ctx.is_partial) {
|
||||
|
|
@ -468,7 +607,7 @@ struct parser_executor {
|
|||
|
||||
// Try to match up to max_count times (or unlimited if max_count is -1)
|
||||
while (p.max_count == -1 || match_count < p.max_count) {
|
||||
auto result = parse_utf8_codepoint(ctx.input, pos);
|
||||
auto result = common_parse_utf8_codepoint(ctx.input, pos);
|
||||
|
||||
if (result.status == utf8_parse_result::INCOMPLETE) {
|
||||
if (match_count >= p.min_count) {
|
||||
|
|
@ -537,6 +676,7 @@ struct parser_executor {
|
|||
|
||||
switch (ctx.input[pos]) {
|
||||
case '"':
|
||||
case '\'':
|
||||
case '\\':
|
||||
case '/':
|
||||
case 'b':
|
||||
|
|
@ -589,7 +729,49 @@ struct parser_executor {
|
|||
return result;
|
||||
}
|
||||
} else {
|
||||
auto utf8_result = parse_utf8_codepoint(ctx.input, pos);
|
||||
auto utf8_result = common_parse_utf8_codepoint(ctx.input, pos);
|
||||
|
||||
if (utf8_result.status == utf8_parse_result::INCOMPLETE) {
|
||||
if (!ctx.is_partial) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
|
||||
}
|
||||
|
||||
if (utf8_result.status == utf8_parse_result::INVALID) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
|
||||
}
|
||||
|
||||
pos += utf8_result.bytes_consumed;
|
||||
}
|
||||
}
|
||||
|
||||
// Reached end without finding closing quote
|
||||
if (!ctx.is_partial) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_python_dict_string_parser & /* p */) {
|
||||
auto pos = start_pos;
|
||||
|
||||
// Parse string content (without quotes)
|
||||
while (pos < ctx.input.size()) {
|
||||
char c = ctx.input[pos];
|
||||
|
||||
if (c == '\'') {
|
||||
// Found closing quote - success (don't consume it)
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos);
|
||||
}
|
||||
|
||||
if (c == '\\') {
|
||||
auto result = handle_escape_sequence(ctx, start_pos, pos);
|
||||
if (!result.success()) {
|
||||
return result;
|
||||
}
|
||||
} else {
|
||||
auto utf8_result = common_parse_utf8_codepoint(ctx.input, pos);
|
||||
|
||||
if (utf8_result.status == utf8_parse_result::INCOMPLETE) {
|
||||
if (!ctx.is_partial) {
|
||||
|
|
@ -621,7 +803,7 @@ struct parser_executor {
|
|||
size_t last_valid_pos = start_pos;
|
||||
|
||||
while (pos < ctx.input.size()) {
|
||||
auto utf8_result = parse_utf8_codepoint(ctx.input, pos);
|
||||
auto utf8_result = common_parse_utf8_codepoint(ctx.input, pos);
|
||||
|
||||
if (utf8_result.status == utf8_parse_result::INCOMPLETE) {
|
||||
// Incomplete UTF-8 sequence
|
||||
|
|
@ -694,6 +876,9 @@ struct parser_executor {
|
|||
|
||||
common_peg_parse_result operator()(const common_peg_tag_parser & p) {
|
||||
// Parse the child
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sTAG: %s\n", debug_indent().c_str(), p.tag.c_str());
|
||||
}
|
||||
auto result = arena.parse(p.child, ctx, start_pos);
|
||||
|
||||
if (!result.fail()) {
|
||||
|
|
@ -755,6 +940,31 @@ common_peg_parser_id common_peg_arena::resolve_ref(common_peg_parser_id id) {
|
|||
return id;
|
||||
}
|
||||
|
||||
static void bfs_node(common_peg_ast_arena &arena, std::ostringstream & oss, const common_peg_ast_node & node, int indent) {
|
||||
for (int i = 0; i < indent; i++) {
|
||||
oss << " ";
|
||||
}
|
||||
oss << "NODE " << node.id;
|
||||
if (!node.rule.empty()) {
|
||||
oss << " (rule " << node.rule << ")";
|
||||
}
|
||||
if (!node.tag.empty()) {
|
||||
oss << " (tag " << node.tag << ")";
|
||||
}
|
||||
oss << " ['" << node.text << "']\n";
|
||||
for (const auto child : node.children) {
|
||||
bfs_node(arena, oss, arena.get(child), indent + 1);
|
||||
}
|
||||
}
|
||||
|
||||
std::string common_peg_ast_arena::dump() {
|
||||
std::ostringstream oss;
|
||||
for (auto & node : nodes_) {
|
||||
bfs_node(*this, oss, node, 0);
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
void common_peg_arena::resolve_refs() {
|
||||
// Walk through all parsers and replace refs with their corresponding rule IDs
|
||||
for (auto & parser : parsers_) {
|
||||
|
|
@ -786,6 +996,7 @@ void common_peg_arena::resolve_refs() {
|
|||
std::is_same_v<T, common_peg_until_parser> ||
|
||||
std::is_same_v<T, common_peg_literal_parser> ||
|
||||
std::is_same_v<T, common_peg_json_string_parser> ||
|
||||
std::is_same_v<T, common_peg_python_dict_string_parser> ||
|
||||
std::is_same_v<T, common_peg_chars_parser> ||
|
||||
std::is_same_v<T, common_peg_any_parser> ||
|
||||
std::is_same_v<T, common_peg_space_parser>) {
|
||||
|
|
@ -803,9 +1014,21 @@ void common_peg_arena::resolve_refs() {
|
|||
}
|
||||
|
||||
std::string common_peg_arena::dump(common_peg_parser_id id) const {
|
||||
std::unordered_set<common_peg_parser_id> visited;
|
||||
return dump_impl(id, visited);
|
||||
}
|
||||
|
||||
std::string common_peg_arena::dump_impl(common_peg_parser_id id,
|
||||
std::unordered_set<common_peg_parser_id> & visited) const {
|
||||
// Check for cycles
|
||||
if (visited.count(id)) {
|
||||
return "[cycle]";
|
||||
}
|
||||
visited.insert(id);
|
||||
|
||||
const auto & parser = parsers_.at(id);
|
||||
|
||||
return std::visit([this](const auto & p) -> std::string {
|
||||
return std::visit([this, &visited](const auto & p) -> std::string {
|
||||
using T = std::decay_t<decltype(p)>;
|
||||
|
||||
if constexpr (std::is_same_v<T, common_peg_epsilon_parser>) {
|
||||
|
|
@ -819,24 +1042,27 @@ std::string common_peg_arena::dump(common_peg_parser_id id) const {
|
|||
} else if constexpr (std::is_same_v<T, common_peg_sequence_parser>) {
|
||||
std::vector<std::string> parts;
|
||||
for (const auto & child : p.children) {
|
||||
parts.push_back(dump(child));
|
||||
parts.push_back(dump_impl(child, visited));
|
||||
}
|
||||
return "Sequence(" + string_join(parts, ", ") + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_choice_parser>) {
|
||||
std::vector<std::string> parts;
|
||||
for (const auto & child : p.children) {
|
||||
parts.push_back(dump(child));
|
||||
parts.push_back(dump_impl(child, visited));
|
||||
}
|
||||
return "Choice(" + string_join(parts, ", ") + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_repetition_parser>) {
|
||||
if (p.max_count == -1) {
|
||||
return "Repetition(" + dump(p.child) + ", " + std::to_string(p.min_count) + ", unbounded)";
|
||||
return "Repetition(" + dump_impl(p.child, visited) + ", " + std::to_string(p.min_count) +
|
||||
", unbounded)";
|
||||
}
|
||||
return "Repetition(" + dump(p.child) + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")";
|
||||
return "Repetition(" + dump_impl(p.child, visited) + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_and_parser>) {
|
||||
return "And(" + dump(p.child) + ")";
|
||||
return "And(" + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_not_parser>) {
|
||||
return "Not(" + dump(p.child) + ")";
|
||||
return "Not(" + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_atomic_parser>) {
|
||||
return "Atomic(" + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_any_parser>) {
|
||||
return "Any";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_space_parser>) {
|
||||
|
|
@ -848,14 +1074,20 @@ std::string common_peg_arena::dump(common_peg_parser_id id) const {
|
|||
return "CharRepeat(" + p.pattern + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_json_string_parser>) {
|
||||
return "JsonString()";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_python_dict_string_parser>) {
|
||||
return "PythonDictString()";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_until_parser>) {
|
||||
return "Until(" + string_join(p.delimiters, " | ") + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
|
||||
return "Schema(" + dump(p.child) + ", " + (p.schema ? p.schema->dump() : "null") + ")";
|
||||
return "Schema(" + dump_impl(p.child, visited) + ", " + (p.schema ? p.schema->dump() : "null") + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
|
||||
return "Rule(" + p.name + ", " + dump(p.child) + ")";
|
||||
return "Rule(" + p.name + ", " + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_ref_parser>) {
|
||||
return "Ref(" + p.name + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_tag_parser>) {
|
||||
return "Tag(" + p.tag + ", " + dump(p.child) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_atomic_parser>) {
|
||||
return "Atomic(" + dump(p.child) + ")";
|
||||
} else {
|
||||
return "Unknown";
|
||||
}
|
||||
|
|
@ -1054,7 +1286,54 @@ common_peg_arena common_peg_parser_builder::build() {
|
|||
return std::move(arena_);
|
||||
}
|
||||
|
||||
// String primitives
|
||||
|
||||
common_peg_parser common_peg_parser_builder::json_string_content() {
|
||||
return wrap(arena_.add_parser(common_peg_json_string_parser{}));
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::single_quoted_string_content() {
|
||||
return wrap(arena_.add_parser(common_peg_python_dict_string_parser{}));
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::double_quoted_string() {
|
||||
return rule("dq-string",
|
||||
[this]() { return sequence({ literal("\""), json_string_content(), literal("\""), space() }); });
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::single_quoted_string() {
|
||||
return rule("sq-string",
|
||||
[this]() { return sequence({ literal("'"), single_quoted_string_content(), literal("'"), space() }); });
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::flexible_string() {
|
||||
return rule("flexible-string", [this]() { return choice({ double_quoted_string(), single_quoted_string() }); });
|
||||
}
|
||||
|
||||
// Generic helpers for object/array structure
|
||||
|
||||
common_peg_parser common_peg_parser_builder::generic_object(const std::string & name,
|
||||
const common_peg_parser & string_parser,
|
||||
const common_peg_parser & value_parser) {
|
||||
return rule(name, [this, string_parser, value_parser]() {
|
||||
auto ws = space();
|
||||
auto member = sequence({ string_parser, ws, literal(":"), ws, value_parser });
|
||||
auto members = sequence({ member, zero_or_more(sequence({ ws, literal(","), ws, member })) });
|
||||
return sequence({ literal("{"), ws, choice({ literal("}"), sequence({ members, ws, literal("}") }) }) });
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::generic_array(const std::string & name,
|
||||
const common_peg_parser & value_parser) {
|
||||
return rule(name, [this, value_parser]() {
|
||||
auto ws = space();
|
||||
auto elements = sequence({ value_parser, zero_or_more(sequence({ literal(","), ws, value_parser })) });
|
||||
return sequence({ literal("["), ws, choice({ literal("]"), sequence({ elements, ws, literal("]") }) }) });
|
||||
});
|
||||
}
|
||||
|
||||
// JSON parsers
|
||||
|
||||
common_peg_parser common_peg_parser_builder::json_number() {
|
||||
return rule("json-number", [this]() {
|
||||
auto digit1_9 = chars("[1-9]", 1, 1);
|
||||
|
|
@ -1062,7 +1341,11 @@ common_peg_parser common_peg_parser_builder::json_number() {
|
|||
auto int_part = choice({literal("0"), sequence({digit1_9, chars("[0-9]", 0, -1)})});
|
||||
auto frac = sequence({literal("."), digits});
|
||||
auto exp = sequence({choice({literal("e"), literal("E")}), optional(chars("[+-]", 1, 1)), digits});
|
||||
return sequence({optional(literal("-")), int_part, optional(frac), optional(exp), space()});
|
||||
// Negative lookahead: only commit the number when the next character can't extend it.
|
||||
// At EOF in partial mode, chars returns NEED_MORE → negate propagates NEED_MORE → number not committed.
|
||||
// This prevents premature commits of partial numbers (e.g. "3" when "3.14" is incoming).
|
||||
auto not_number_continuation = negate(chars("[0-9.eE+-]", 1, 1));
|
||||
return sequence({ optional(literal("-")), int_part, optional(frac), optional(exp), not_number_continuation, space() });
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -1085,36 +1368,11 @@ common_peg_parser common_peg_parser_builder::json_null() {
|
|||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::json_object() {
|
||||
return rule("json-object", [this]() {
|
||||
auto ws = space();
|
||||
auto member = sequence({json_string(), ws, literal(":"), ws, json()});
|
||||
auto members = sequence({member, zero_or_more(sequence({ws, literal(","), ws, member}))});
|
||||
return sequence({
|
||||
literal("{"),
|
||||
ws,
|
||||
choice({
|
||||
literal("}"),
|
||||
sequence({members, ws, literal("}")})
|
||||
}),
|
||||
ws
|
||||
});
|
||||
});
|
||||
return generic_object("json-object", json_string(), json());
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::json_array() {
|
||||
return rule("json-array", [this]() {
|
||||
auto ws = space();
|
||||
auto elements = sequence({json(), zero_or_more(sequence({literal(","), ws, json()}))});
|
||||
return sequence({
|
||||
literal("["),
|
||||
ws,
|
||||
choice({
|
||||
literal("]"),
|
||||
sequence({elements, ws, literal("]")})
|
||||
}),
|
||||
ws
|
||||
});
|
||||
});
|
||||
return generic_array("json-array", json());
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::json() {
|
||||
|
|
@ -1130,8 +1388,40 @@ common_peg_parser common_peg_parser_builder::json() {
|
|||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::json_string_content() {
|
||||
return wrap(arena_.add_parser(common_peg_json_string_parser{}));
|
||||
common_peg_parser common_peg_parser_builder::python_string() {
|
||||
return rule("python-string", [this]() { return choice({ double_quoted_string(), single_quoted_string() }); });
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::python_number() {
|
||||
return json_number();
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::python_bool() {
|
||||
return rule("python-bool", [this]() { return sequence({ choice({ literal("True"), literal("False") }), space() }); });
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::python_null() {
|
||||
return rule("python-none", [this]() { return sequence({ literal("None"), space() }); });
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::python_dict() {
|
||||
return generic_object("python-dict", python_string(), python_value());
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::python_array() {
|
||||
return generic_array("python-array", python_value());
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::python_value() {
|
||||
return rule("python-value", [this]() {
|
||||
return choice({ python_dict(), python_array(), python_string(), python_number(), python_bool(), python_null() });
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::marker() {
|
||||
auto sharp_bracket_parser = literal("<") + until(">") + literal(">");
|
||||
auto square_bracket_parser = literal("[") + until("]") + literal("]");
|
||||
return choice({ sharp_bracket_parser, square_bracket_parser });
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::json_member(const std::string & key, const common_peg_parser & p) {
|
||||
|
|
@ -1145,17 +1435,54 @@ common_peg_parser common_peg_parser_builder::json_member(const std::string & key
|
|||
});
|
||||
}
|
||||
|
||||
|
||||
static std::string gbnf_escape_char_class(char c) {
|
||||
switch (c) {
|
||||
case '\n': return "\\n";
|
||||
case '\t': return "\\t";
|
||||
case '\r': return "\\r";
|
||||
case '\\': return "\\\\";
|
||||
case ']': return "\\]";
|
||||
case '[': return "\\[";
|
||||
default: return std::string(1, c);
|
||||
static std::string gbnf_escape_char_class(uint32_t c) {
|
||||
if (c == '-' || c == ']' || c == '[' || c == '\\') {
|
||||
return "\\" + std::string(1, (char) c);
|
||||
}
|
||||
// Escape whitespace control characters
|
||||
if (c == '\n') {
|
||||
return "\\n";
|
||||
}
|
||||
if (c == '\t') {
|
||||
return "\\t";
|
||||
}
|
||||
if (c == '\r') {
|
||||
return "\\r";
|
||||
}
|
||||
|
||||
// Printable ASCII
|
||||
if (c >= 0x20 && c <= 0x7E) {
|
||||
return std::string(1, (char) c);
|
||||
}
|
||||
|
||||
// Hex escape
|
||||
char buf[16];
|
||||
const char * hex = "0123456789ABCDEF";
|
||||
|
||||
if (c <= 0xFF) {
|
||||
buf[0] = '\\';
|
||||
buf[1] = 'x';
|
||||
buf[2] = hex[(c >> 4) & 0xF];
|
||||
buf[3] = hex[c & 0xF];
|
||||
buf[4] = '\0';
|
||||
} else if (c <= 0xFFFF) {
|
||||
buf[0] = '\\';
|
||||
buf[1] = 'u';
|
||||
buf[2] = hex[(c >> 12) & 0xF];
|
||||
buf[3] = hex[(c >> 8) & 0xF];
|
||||
buf[4] = hex[(c >> 4) & 0xF];
|
||||
buf[5] = hex[c & 0xF];
|
||||
buf[6] = '\0';
|
||||
} else {
|
||||
buf[0] = '\\';
|
||||
buf[1] = 'U';
|
||||
for (int i = 0; i < 8; i++) {
|
||||
buf[2 + i] = hex[(c >> ((7 - i) * 4)) & 0xF];
|
||||
}
|
||||
buf[10] = '\0';
|
||||
}
|
||||
|
||||
return std::string(buf);
|
||||
}
|
||||
|
||||
static std::string gbnf_excluding_pattern(const std::vector<std::string> & strings) {
|
||||
|
|
@ -1173,12 +1500,12 @@ static std::string gbnf_excluding_pattern(const std::vector<std::string> & strin
|
|||
|
||||
std::string cls;
|
||||
cls.reserve(chars.size());
|
||||
for (const auto & ch : chars) {
|
||||
for (uint32_t ch : chars) {
|
||||
cls += gbnf_escape_char_class(ch);
|
||||
}
|
||||
|
||||
if (!pre.empty()) {
|
||||
pattern += gbnf_format_literal(pre) + " [^" + cls + "]";
|
||||
pattern += gbnf_format_literal(common_unicode_cpts_to_utf8(pre)) + " [^" + cls + "]";
|
||||
} else {
|
||||
pattern += "[^" + cls + "]";
|
||||
}
|
||||
|
|
@ -1208,7 +1535,8 @@ static std::unordered_set<std::string> collect_reachable_rules(
|
|||
std::is_same_v<T, common_peg_chars_parser> ||
|
||||
std::is_same_v<T, common_peg_space_parser> ||
|
||||
std::is_same_v<T, common_peg_any_parser> ||
|
||||
std::is_same_v<T, common_peg_json_string_parser>) {
|
||||
std::is_same_v<T, common_peg_json_string_parser> ||
|
||||
std::is_same_v<T, common_peg_python_dict_string_parser>) {
|
||||
// These parsers do not have any children
|
||||
} else if constexpr (std::is_same_v<T, common_peg_sequence_parser>) {
|
||||
for (auto child : p.children) {
|
||||
|
|
@ -1346,6 +1674,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
|||
return result + "{" + std::to_string(p.min_count) + "," + std::to_string(p.max_count) + "}";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_json_string_parser>) {
|
||||
return R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_python_dict_string_parser>) {
|
||||
return R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_until_parser>) {
|
||||
if (p.delimiters.empty()) {
|
||||
return ".*";
|
||||
|
|
@ -1477,6 +1807,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
|
|||
};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_json_string_parser>) {
|
||||
return json{{"type", "json_string"}};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_python_dict_string_parser>) {
|
||||
return json{{ "type", "python_dict_string" }};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_until_parser>) {
|
||||
return json{{"type", "until"}, {"delimiters", p.delimiters}};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
|
||||
|
|
@ -1606,6 +1938,9 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
|
|||
if (type == "json_string") {
|
||||
return common_peg_json_string_parser{};
|
||||
}
|
||||
if (type == "python_dict_string") {
|
||||
return common_peg_python_dict_string_parser{};
|
||||
}
|
||||
if (type == "until") {
|
||||
if (!j.contains("delimiters") || !j["delimiters"].is_array()) {
|
||||
throw std::runtime_error("until parser missing or invalid 'delimiters' field");
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <functional>
|
||||
|
|
@ -111,6 +112,8 @@ class common_peg_ast_arena {
|
|||
|
||||
void visit(common_peg_ast_id id, const common_peg_ast_visitor & visitor) const;
|
||||
void visit(const common_peg_parse_result & result, const common_peg_ast_visitor & visitor) const;
|
||||
|
||||
std::string dump();
|
||||
};
|
||||
|
||||
struct common_peg_parse_result {
|
||||
|
|
@ -139,6 +142,7 @@ struct common_peg_parse_result {
|
|||
struct common_peg_parse_context {
|
||||
std::string input;
|
||||
bool is_partial;
|
||||
bool debug = false; // Enable debug output for parser tracing
|
||||
common_peg_ast_arena ast;
|
||||
|
||||
int parse_depth;
|
||||
|
|
@ -207,6 +211,7 @@ struct common_peg_chars_parser {
|
|||
};
|
||||
|
||||
struct common_peg_json_string_parser {};
|
||||
struct common_peg_python_dict_string_parser {};
|
||||
|
||||
struct common_peg_until_parser {
|
||||
std::vector<std::string> delimiters;
|
||||
|
|
@ -255,6 +260,7 @@ using common_peg_parser_variant = std::variant<
|
|||
common_peg_space_parser,
|
||||
common_peg_chars_parser,
|
||||
common_peg_json_string_parser,
|
||||
common_peg_python_dict_string_parser,
|
||||
common_peg_until_parser,
|
||||
common_peg_schema_parser,
|
||||
common_peg_rule_parser,
|
||||
|
|
@ -299,6 +305,8 @@ class common_peg_arena {
|
|||
friend class common_peg_parser_builder;
|
||||
|
||||
private:
|
||||
std::string dump_impl(common_peg_parser_id id, std::unordered_set<common_peg_parser_id> & visited) const;
|
||||
|
||||
common_peg_parser_id add_parser(common_peg_parser_variant parser);
|
||||
void add_rule(const std::string & name, common_peg_parser_id id);
|
||||
|
||||
|
|
@ -311,6 +319,10 @@ class common_peg_parser_builder {
|
|||
common_peg_parser wrap(common_peg_parser_id id) { return common_peg_parser(id, *this); }
|
||||
common_peg_parser add(const common_peg_parser_variant & p) { return wrap(arena_.add_parser(p)); }
|
||||
|
||||
// Generic helpers for building object/array structures with configurable string/value parsers.
|
||||
common_peg_parser generic_object(const std::string & name, const common_peg_parser & string_parser, const common_peg_parser & value_parser);
|
||||
common_peg_parser generic_array(const std::string & name, const common_peg_parser & value_parser);
|
||||
|
||||
public:
|
||||
common_peg_parser_builder();
|
||||
|
||||
|
|
@ -404,6 +416,21 @@ class common_peg_parser_builder {
|
|||
// S -> A{n}
|
||||
common_peg_parser repeat(const common_peg_parser & p, int n) { return repeat(p, n, n); }
|
||||
|
||||
// Matches a double-quoted string: '"' content '"' space
|
||||
common_peg_parser double_quoted_string();
|
||||
|
||||
// Matches a single-quoted string: "'" content "'" space
|
||||
common_peg_parser single_quoted_string();
|
||||
|
||||
// Matches a string that accepts both double-quoted and single-quoted styles.
|
||||
common_peg_parser flexible_string();
|
||||
|
||||
// Matches double-quoted string content without the surrounding quotes.
|
||||
common_peg_parser json_string_content();
|
||||
|
||||
// Matches single-quoted string content without the surrounding quotes.
|
||||
common_peg_parser single_quoted_string_content();
|
||||
|
||||
// Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null.
|
||||
// value -> object | array | string | number | true | false | null
|
||||
common_peg_parser json();
|
||||
|
|
@ -414,14 +441,24 @@ class common_peg_parser_builder {
|
|||
common_peg_parser json_bool();
|
||||
common_peg_parser json_null();
|
||||
|
||||
// Matches JSON string content without the surrounding quotes.
|
||||
// Useful for extracting content within a JSON string.
|
||||
common_peg_parser json_string_content();
|
||||
|
||||
// Matches a JSON object member with a key and associated parser as the
|
||||
// value.
|
||||
common_peg_parser json_member(const std::string & key, const common_peg_parser & p);
|
||||
|
||||
// Creates a complete Python format parser supporting dicts, arrays, strings, numbers, booleans, and None.
|
||||
// Differs from JSON: uses True/False/None, accepts both single and double-quoted strings.
|
||||
// value -> dict | array | string | number | True | False | None
|
||||
common_peg_parser python_value();
|
||||
common_peg_parser python_dict();
|
||||
common_peg_parser python_string();
|
||||
common_peg_parser python_array();
|
||||
common_peg_parser python_number();
|
||||
common_peg_parser python_bool();
|
||||
common_peg_parser python_null();
|
||||
|
||||
// A marker, i.e. text delimited by a pair of <> or []
|
||||
common_peg_parser marker();
|
||||
|
||||
// Wraps a parser with JSON schema metadata for grammar generation.
|
||||
// Used internally to convert JSON schemas to GBNF grammar rules.
|
||||
common_peg_parser schema(const common_peg_parser & p, const std::string & name, const nlohmann::ordered_json & schema, bool raw = false);
|
||||
|
|
|
|||
|
|
@ -1,14 +1,18 @@
|
|||
#include "unicode.h"
|
||||
#include <cassert>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
// implementation adopted from src/unicode.cpp
|
||||
|
||||
size_t utf8_sequence_length(unsigned char first_byte) {
|
||||
size_t common_utf8_sequence_length(unsigned char first_byte) {
|
||||
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||
uint8_t highbits = static_cast<uint8_t>(first_byte) >> 4;
|
||||
return lookup[highbits];
|
||||
}
|
||||
|
||||
utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset) {
|
||||
utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t offset) {
|
||||
if (offset >= input.size()) {
|
||||
return utf8_parse_result(utf8_parse_result::INCOMPLETE);
|
||||
}
|
||||
|
|
@ -62,3 +66,43 @@ utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset) {
|
|||
// Invalid first byte
|
||||
return utf8_parse_result(utf8_parse_result::INVALID);
|
||||
}
|
||||
|
||||
std::string common_unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
|
||||
std::string result;
|
||||
for (size_t i = 0; i < cps.size(); ++i) {
|
||||
result.append(common_unicode_cpt_to_utf8(cps[i]));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string common_unicode_cpt_to_utf8(uint32_t cpt) {
|
||||
std::string result;
|
||||
|
||||
if (/* 0x00 <= cpt && */ cpt <= 0x7f) {
|
||||
result.push_back(cpt);
|
||||
return result;
|
||||
}
|
||||
if (0x80 <= cpt && cpt <= 0x7ff) {
|
||||
result.push_back(0xc0 | ((cpt >> 6) & 0x1f));
|
||||
result.push_back(0x80 | (cpt & 0x3f));
|
||||
return result;
|
||||
}
|
||||
if (0x800 <= cpt && cpt <= 0xffff) {
|
||||
result.push_back(0xe0 | ((cpt >> 12) & 0x0f));
|
||||
result.push_back(0x80 | ((cpt >> 6) & 0x3f));
|
||||
result.push_back(0x80 | (cpt & 0x3f));
|
||||
return result;
|
||||
}
|
||||
if (0x10000 <= cpt && cpt <= 0x10ffff) {
|
||||
result.push_back(0xf0 | ((cpt >> 18) & 0x07));
|
||||
result.push_back(0x80 | ((cpt >> 12) & 0x3f));
|
||||
result.push_back(0x80 | ((cpt >> 6) & 0x3f));
|
||||
result.push_back(0x80 | (cpt & 0x3f));
|
||||
return result;
|
||||
}
|
||||
|
||||
throw std::invalid_argument("invalid codepoint");
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
|
||||
#include <cstdint>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
// UTF-8 parsing utilities for streaming-aware unicode support
|
||||
|
||||
|
|
@ -16,7 +18,10 @@ struct utf8_parse_result {
|
|||
|
||||
// Determine the expected length of a UTF-8 sequence from its first byte
|
||||
// Returns 0 for invalid first bytes
|
||||
size_t utf8_sequence_length(unsigned char first_byte);
|
||||
size_t common_utf8_sequence_length(unsigned char first_byte);
|
||||
|
||||
// Parse a single UTF-8 codepoint from input
|
||||
utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset);
|
||||
utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t offset);
|
||||
|
||||
std::string common_unicode_cpts_to_utf8(const std::vector<uint32_t> & cps);
|
||||
std::string common_unicode_cpt_to_utf8(uint32_t cpt);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,525 @@
|
|||
# Auto-Parser Architecture
|
||||
|
||||
The auto-parser automatically analyzes chat templates to determine how to parse model outputs, including content, reasoning, and tool calls.
|
||||
|
||||
## Overview
|
||||
|
||||
The unified auto-parser uses a pure differential, compositional approach (inspired by the `git diff` algorithm) to analyze chat templates:
|
||||
|
||||
**Core Philosophy**:
|
||||
|
||||
- **Minimize Hardcoded Patterns**: All markers extracted through template comparison (the only heuristic is JSON detection to distinguish `JSON_NATIVE` from tag-based formats)
|
||||
- **Compositional Architecture**: Separate analyzer structs for reasoning, content, and tools — each responsible for its own analysis and parser construction
|
||||
|
||||
**Analysis + Parser Building in Two Steps**:
|
||||
|
||||
1. `autoparser::autoparser tmpl_analysis(tmpl)` — runs all differential comparisons and populates the analysis structs
|
||||
2. `autoparser::peg_generator::generate_parser(tmpl, params, tmpl_analysis)` — uses the analysis to build a PEG parser and optional GBNF grammar
|
||||
|
||||
## Data Structures
|
||||
|
||||
All structs are defined in [common/chat-auto-parser.h](common/chat-auto-parser.h).
|
||||
|
||||
### Top-Level: `autoparser` (main analyzer and generator)
|
||||
|
||||
[common/chat-auto-parser.h:367-388](common/chat-auto-parser.h#L367-L388) — top-level analysis result aggregating `jinja_caps`, `reasoning`, `content`, and `tools` sub-analyses, plus `preserved_tokens` (union of all non-empty markers).
|
||||
|
||||
### `analyze_reasoning`
|
||||
|
||||
[common/chat-auto-parser.h:254-274](common/chat-auto-parser.h#L254-L274) — reasoning analysis result: `mode` enum, `start` marker (e.g. `<think>`), and `end` marker (e.g. `</think>`).
|
||||
|
||||
### `analyze_content`
|
||||
|
||||
[common/chat-auto-parser.h:280-295](common/chat-auto-parser.h#L280-L295) — content analysis result: `mode` enum, `start`/`end` markers, and `requires_nonnull_content` flag.
|
||||
|
||||
### `analyze_tools` and its sub-structs
|
||||
|
||||
- [common/chat-auto-parser.h:176-194](common/chat-auto-parser.h#L176-L194) — `tool_format_analysis`: `mode` enum, `section_start/end`, `per_call_start/end`, JSON field names (`function_field`, `name_field`, `args_field`, `id_field`, `gen_id_field`), and format flags (`fun_name_is_key`, `tools_array_wrapped`, `uses_python_dicts`)
|
||||
- [common/chat-auto-parser.h:196-200](common/chat-auto-parser.h#L196-L200) — `tool_function_analysis`: `name_prefix`, `name_suffix`, `close` markers around function names
|
||||
- [common/chat-auto-parser.h:202-210](common/chat-auto-parser.h#L202-L210) — `tool_arguments_analysis`: `start/end` container markers, `name_prefix/suffix`, `value_prefix/suffix`, `separator`
|
||||
- [common/chat-auto-parser.h:212-217](common/chat-auto-parser.h#L212-L217) — `tool_id_analysis`: `pos` enum, `prefix`/`suffix` markers around call ID values
|
||||
- [common/chat-auto-parser.h:301-361](common/chat-auto-parser.h#L301-L361) — `analyze_tools`: aggregates the four sub-structs above
|
||||
|
||||
### Enums
|
||||
|
||||
**`reasoning_mode`**: How the template handles reasoning/thinking blocks.
|
||||
|
||||
| Value | Description |
|
||||
|-----------------|-----------------------------------------------------------------------------------|
|
||||
| `NONE` | No reasoning markers detected |
|
||||
| `TAG_BASED` | Standard tag-based: `<think>...</think>` |
|
||||
| `DELIMITER` | Delimiter-based: reasoning ends at a delimiter (e.g., `[BEGIN FINAL RESPONSE]`) |
|
||||
| `FORCED_OPEN` | Template ends with open reasoning tag when `enable_thinking=true` |
|
||||
| `FORCED_CLOSED` | `enable_thinking=false` emits both tags; `enable_thinking=true` emits only start |
|
||||
| `TOOLS_ONLY` | Reasoning only appears in tool call responses, not plain content |
|
||||
|
||||
**`content_mode`**: How the template wraps assistant content.
|
||||
|
||||
| Value | Description |
|
||||
|--------------------------|----------------------------------------------------------------|
|
||||
| `PLAIN` | No content markers |
|
||||
| `ALWAYS_WRAPPED` | Content always wrapped: `<response>...</response>` |
|
||||
| `WRAPPED_WITH_REASONING` | Content wrapped only when reasoning is present |
|
||||
|
||||
**`tool_format`**: Classification of tool call structure.
|
||||
|
||||
| Value | Description |
|
||||
|------------------|------------------------------------------------------------------|
|
||||
| `NONE` | No tool support detected |
|
||||
| `JSON_NATIVE` | Pure JSON: `{"name": "X", "arguments": {...}}` |
|
||||
| `TAG_WITH_JSON` | Tag-based with JSON args: `<function=X>{...}</function>` |
|
||||
| `TAG_WITH_TAGGED`| Tag-based with tagged args: `<param=key>value</param>` |
|
||||
|
||||
**`call_id_position`**: Where call IDs appear in tag-based formats.
|
||||
|
||||
| Value | Description |
|
||||
|--------------------------|----------------------------------------------|
|
||||
| `NONE` | No call ID support detected |
|
||||
| `PRE_FUNC_NAME` | Before function name |
|
||||
| `BETWEEN_FUNC_AND_ARGS` | Between function name and arguments |
|
||||
| `POST_ARGS` | After arguments |
|
||||
|
||||
## Tool Calling Formats
|
||||
|
||||
### JSON_NATIVE
|
||||
|
||||
**Structure**: The entire tool call (function name, arguments, values) is in JSON format. Optional enclosing tags around the section.
|
||||
|
||||
**Detection**: Function name appears inside a JSON structure (quotes preceded by `{` or `:`).
|
||||
|
||||
**Examples**:
|
||||
|
||||
Standard OpenAI-style:
|
||||
|
||||
```json
|
||||
<tool_call>
|
||||
{"name": "get_weather", "arguments": {"location": "Paris", "unit": "celsius"}}
|
||||
</tool_call>
|
||||
```
|
||||
|
||||
Mistral Nemo with array wrapper:
|
||||
|
||||
```json
|
||||
[TOOL_CALLS]
|
||||
[{"name": "calculate", "arguments": {"expr": "2+2"}}]
|
||||
```
|
||||
|
||||
Function name as JSON key (Apertus style):
|
||||
|
||||
```json
|
||||
{"get_weather": {"location": "Paris"}}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### TAG_WITH_JSON
|
||||
|
||||
**Structure**: Function name is outside JSON, in tag attributes or XML-style tags. Arguments are a JSON object.
|
||||
|
||||
**Detection**: Function name not in JSON, but argument names appear in JSON context.
|
||||
|
||||
**Examples**:
|
||||
|
||||
Functionary v3.1:
|
||||
|
||||
```xml
|
||||
<function=get_weather>{"location": "Paris", "unit": "celsius"}</function>
|
||||
```
|
||||
|
||||
MiniMax:
|
||||
|
||||
```xml
|
||||
<minimax:tool_call>
|
||||
<tool_name>calculate</tool_name>
|
||||
<arguments>{"expr": "2+2"}</arguments>
|
||||
</minimax:tool_call>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### TAG_WITH_TAGGED
|
||||
|
||||
**Structure**: Both function name and argument names are in XML-style tags. String values are unquoted; non-string values are JSON-formatted.
|
||||
|
||||
**Detection**: Neither function name nor argument names appear in a JSON context.
|
||||
|
||||
**Examples**:
|
||||
|
||||
Qwen/Hermes XML format:
|
||||
|
||||
```xml
|
||||
<function=get_weather>
|
||||
<param=location>Paris</param>
|
||||
<param=unit>celsius</param>
|
||||
</function>
|
||||
```
|
||||
|
||||
Mixed types:
|
||||
|
||||
```xml
|
||||
<function=calculate>
|
||||
<param=expr>2+2</param>
|
||||
<param=precision>2</param>
|
||||
<param=options>{"round": true}</param>
|
||||
</function>
|
||||
```
|
||||
|
||||
String values (`Paris`, `celsius`, `2+2`) are unquoted; `options` (object type) is JSON-formatted.
|
||||
|
||||
---
|
||||
|
||||
## Analysis Flow
|
||||
|
||||
```text
|
||||
autoparser::autoparser(tmpl)
|
||||
|
|
||||
|-- Phase 1: analyze_reasoning(tmpl, jinja_caps.supports_tool_calls)
|
||||
| |-- R1: compare_reasoning_presence() — with/without reasoning_content field
|
||||
| |-- R2: compare_thinking_enabled() — enable_thinking=false vs true
|
||||
| '-- R3: compare_reasoning_scope() — reasoning+content vs reasoning+tools
|
||||
| (only if supports_tool_calls)
|
||||
|
|
||||
|-- Phase 2: analyze_content(tmpl, reasoning)
|
||||
| '-- C1: compares content-only vs tools output and content-only vs reasoning output
|
||||
|
|
||||
|-- Phase 3: analyze_tools(tmpl, jinja_caps, reasoning)
|
||||
| (skipped entirely if !jinja_caps.supports_tool_calls)
|
||||
| |
|
||||
| |-- T1: analyze_tool_calls() — no tools vs with tools; classifies format
|
||||
| | |-- JSON path → analyze_tool_call_format_json_native()
|
||||
| | '-- tag path → analyze_tool_call_format_non_json()
|
||||
| |
|
||||
| (if format != NONE and format != JSON_NATIVE:)
|
||||
| |
|
||||
| |-- T2: check_per_call_markers() — 1 call vs 2 calls; moves section→per-call if needed
|
||||
| | (only if supports_parallel_tool_calls)
|
||||
| |
|
||||
| |-- T3: extract_function_markers() — func_alpha vs func_beta; extracts name prefix/suffix/close
|
||||
| |
|
||||
| |-- T4: analyze_arguments() — (TAG_WITH_TAGGED only)
|
||||
| | |-- A1: extract_argument_name_markers() — arg_name_A vs arg_name_B
|
||||
| | '-- A2: extract_argument_value_markers() — value "XXXX" vs "YYYY"
|
||||
| |
|
||||
| |-- T5: extract_argument_separator() — 1 arg vs 2 args; finds separator between args
|
||||
| |
|
||||
| |-- T6: extract_args_markers() — 0 args vs 1 arg; finds args container markers
|
||||
| |
|
||||
| '-- T7: extract_call_id_markers() — call_id "call00001" vs "call99999"
|
||||
|
|
||||
'-- collect_preserved_tokens() — union of all non-empty markers
|
||||
|
|
||||
'-- apply workarounds() — post-hoc patches for edge-case templates
|
||||
|
|
||||
v
|
||||
autoparser (analysis result)
|
||||
|
|
||||
v
|
||||
autoparser::peg_generator::generate_parser(tmpl, inputs, analysis)
|
||||
|-- analysis.build_parser(inputs) — builds PEG parser arena
|
||||
| |-- reasoning.build_parser(ctx) — reasoning parser (mode-dependent)
|
||||
| |-- content.build_parser(ctx) — content parser (mode-dependent)
|
||||
| '-- tools.build_parser(ctx) — tool parser (dispatches by tool_format)
|
||||
| |-- build_tool_parser_json_native()
|
||||
| |-- build_tool_parser_tag_json()
|
||||
| '-- build_tool_parser_tag_tagged()
|
||||
|
|
||||
|-- Build GBNF grammar (if tools present and trigger_marker non-empty)
|
||||
'-- Set grammar_triggers from section_start or per_call_start
|
||||
|
|
||||
v
|
||||
common_chat_params (prompt, parser, grammar, triggers, preserved_tokens)
|
||||
```
|
||||
|
||||
## Entry Point
|
||||
|
||||
The auto-parser is invoked in [common/chat.cpp:1280-1310](common/chat.cpp#L1280-L1310) in `common_chat_templates_apply_jinja`. A few specialized templates are handled first (Ministral/Magistral Large 3, GPT-OSS with `<|channel|>`, Functionary v3.2 with `>>>all`), then the auto-parser handles everything else via `autoparser::autoparser` + `peg_generator::generate_parser`.
|
||||
|
||||
## Algorithm Details
|
||||
|
||||
### Core Mechanism: Differential Comparison
|
||||
|
||||
All analysis phases use the same factorized comparison function declared in [common/chat-auto-parser-helpers.h:68](common/chat-auto-parser-helpers.h#L68):
|
||||
|
||||
```cpp
|
||||
compare_variants(tmpl, params_A, params_modifier)
|
||||
```
|
||||
|
||||
This creates variant B by applying a modifier lambda to a copy of `params_A`, renders both through the template, and computes a `diff_split` ([common/chat-auto-parser.h:28-37](common/chat-auto-parser.h#L28-L37)):
|
||||
|
||||
- `prefix` — common prefix between A and B
|
||||
- `suffix` — common suffix between A and B
|
||||
- `left` — unique to variant A
|
||||
- `right` — unique to variant B
|
||||
|
||||
The diff is computed via `calculate_diff_split()`, which finds the longest-common-prefix and longest-common-suffix, then iteratively moves incomplete `<...>` or `[...]` markers from the prefix/suffix into left/right until stable (tag boundary fixing).
|
||||
|
||||
Text is segmentized into markers and non-marker fragments using `segmentize_markers()`, which splits on `<...>` and `[...]` boundaries.
|
||||
|
||||
### Phase 1: Reasoning Analysis
|
||||
|
||||
**R1 — `compare_reasoning_presence()`**: Compares assistant message with vs without a `reasoning_content` field.
|
||||
|
||||
- Searches `diff.right` (output with reasoning) for the reasoning content needle
|
||||
- Uses PEG parsers to find surrounding markers:
|
||||
- If both pre/post markers found in `diff.right` → `TAG_BASED` (both tags visible in diff = no forced close)
|
||||
- If both found but post marker only in the full output B → `FORCED_CLOSED`
|
||||
- If only post marker found → `DELIMITER`
|
||||
- Sets `reasoning.start` and `reasoning.end`
|
||||
|
||||
**R2 — `compare_thinking_enabled()`**: Compares `enable_thinking=false` vs `true` with a generation prompt.
|
||||
|
||||
- Detects `FORCED_OPEN`: `enable_thinking=true` adds a non-empty marker at the end of the prompt (where model will start generating) — sets `reasoning.start`, mode = `FORCED_OPEN`
|
||||
- Detects `FORCED_CLOSED`: `enable_thinking=false` produces both start+end markers; `enable_thinking=true` produces only start marker
|
||||
- Handles the reverse case: if both start and end are still empty, looks for a single-segment diff on each side to extract both markers
|
||||
|
||||
**R3 — `compare_reasoning_scope()`**: Compares assistant message with reasoning+text-content vs reasoning+tool-calls.
|
||||
|
||||
- Only runs if `jinja_caps.supports_tool_calls`
|
||||
- Detects `TOOLS_ONLY`: reasoning content present in B (with tools) but not in A (with text content)
|
||||
- Extracts reasoning markers from the tool call output using PEG parsers
|
||||
|
||||
### Phase 2: Content Analysis
|
||||
|
||||
**C1**: Two comparisons in the `analyze_content` constructor:
|
||||
|
||||
- Comparison 1: content-only output vs tool-call output → `diff_tools`
|
||||
- Comparison 2: content-only output vs reasoning+empty-content output → `diff_reasoning`
|
||||
|
||||
Classification logic:
|
||||
|
||||
- `PLAIN`: `diff_tools.left` equals the response string (content is the entire diff, no wrapper)
|
||||
- `ALWAYS_WRAPPED`: markers found surrounding the content text in `pure_content` → extracts `start`/`end`
|
||||
|
||||
### Phase 3: Tool Call Analysis
|
||||
|
||||
**T1 — `analyze_tool_calls()`**: Compares no-tools vs with-tools output.
|
||||
|
||||
- Extracts the tool call section as `diff.right`
|
||||
- Calls `analyze_tool_call_format()` which first strips reasoning markers from the haystack, then:
|
||||
- Calls `in_json_haystack()` for both function name and argument name needles
|
||||
- `in_json_haystack()` uses a PEG parser to check whether the needle appears in a JSON context (preceded by `{` or `:` with surrounding quotes)
|
||||
- If function name is in JSON → `JSON_NATIVE` → `analyze_tool_call_format_json_native()`
|
||||
- If function name not in JSON, arg name is in JSON → `TAG_WITH_JSON`
|
||||
- If neither in JSON → `TAG_WITH_TAGGED`
|
||||
- `analyze_tool_call_format_json_native()`: parses the JSON object, matches field values to needles to populate `name_field`, `args_field`, `id_field`, `gen_id_field`; detects `tools_array_wrapped`; extracts `section_start`/`section_end`
|
||||
- `analyze_tool_call_format_non_json()`: uses PEG parsers on the haystack to find up to two opening markers (section + per-call) then up to two closing markers
|
||||
|
||||
**T2 — `check_per_call_markers()`**: Compares 1 call vs 2 calls.
|
||||
|
||||
- Computes a secondary diff of the second call portion vs the common suffix
|
||||
- If the second call content starts with `section_start` → the section marker is actually per-call → moves `section_start/end` to `per_call_start/end` and clears the section markers
|
||||
|
||||
**T3 — `extract_function_markers()`**: Compares function name `FUN_FIRST` vs `FUN_SECOND` (two different named functions).
|
||||
|
||||
- Finds where the function name appears in `diff.left`
|
||||
- Extracts `function.name_prefix` from the common prefix up to the function marker, and `function.name_suffix` from after the name up to the next marker
|
||||
- Extends `name_suffix` into `diff.suffix` (to the first marker for TAG_WITH_TAGGED; to the first `{` or `[` for TAG_WITH_JSON)
|
||||
- Extracts `function.close` from after the last argument value up to the per-call/section end marker
|
||||
|
||||
**T4 — `analyze_arguments()`** (TAG_WITH_TAGGED only):
|
||||
|
||||
- **A1 `extract_argument_name_markers()`**: Compares `arg_name_A` vs `arg_name_B` (two different argument names).
|
||||
- Finds shared surrounding structure → `arguments.name_prefix`, `arguments.name_suffix`
|
||||
- **A2 `extract_argument_value_markers()`**: Compares argument value `"XXXX"` vs `"YYYY"` (same arg, different value).
|
||||
- Finds markers surrounding the value → `arguments.value_prefix`, `arguments.value_suffix`
|
||||
|
||||
**T5 — `extract_argument_separator()`**: Compares 1 argument vs 2 arguments (same function).
|
||||
|
||||
- Uses `until_common_prefix(diff.right, ARG_FIRST, ARG_SECOND)` to find what separates the two argument blocks
|
||||
|
||||
**T6 — `extract_args_markers()`**: Compares 0 arguments vs 1 argument.
|
||||
|
||||
- Uses `until_common_prefix()` and `after_common_suffix()` with the empty and single-arg JSON strings as anchors to find container markers (`arguments.start`, `arguments.end`)
|
||||
|
||||
**T7 — `extract_call_id_markers()`**: Compares call IDs `"call00001"` vs `"call99999"`.
|
||||
|
||||
- Determines whether function name appears in `diff.prefix` or `diff.suffix` to classify position:
|
||||
- Function name in prefix only → `BETWEEN_FUNC_AND_ARGS` or `POST_ARGS` (further distinguished by where `{` appears)
|
||||
- Function name in suffix only → `PRE_FUNC_NAME`
|
||||
- Extracts `call_id.prefix` and `call_id.suffix` markers around the call ID value
|
||||
- Clears `per_call_end` if it incorrectly incorporated the call ID suffix
|
||||
|
||||
### Workarounds
|
||||
|
||||
A workaround array in `common/chat-diff-analyzer.cpp` applies post-hoc patches after analysis. Each workaround is a lambda that inspects the template source and overrides analysis results. Current workarounds:
|
||||
|
||||
1. **Old Qwen/DeepSeek thinking templates** — source contains `content.split('</think>')`: sets `reasoning.mode = FORCED_OPEN` with `<think>`/`</think>` markers if no reasoning was detected
|
||||
2. **Granite 3.3** — source contains specific "Write your thoughts" text: forces `TAG_BASED` reasoning with `<think>`/`</think>` and `WRAPPED_WITH_REASONING` content with `<response>`/`</response>`
|
||||
3. **Cohere Command R+** — source contains `<|CHATBOT_TOKEN|>`: sets `ALWAYS_WRAPPED` content mode if no content start is already set
|
||||
4. **Functionary 3.1** — source contains `set has_code_interpreter`: forces `PLAIN` content, specific `per_call_start/end`, clears preserved tokens to only keep Functionary-specific markers
|
||||
5. **DeepSeek-R1-Distill-Qwen** — source contains `tool▁calls▁begin` markers: overrides tool section/per-call markers with the correct Unicode block characters
|
||||
|
||||
### Parser Building
|
||||
|
||||
Each analyzer struct (`analyze_reasoning`, `analyze_content`, `analyze_tools`) implements `build_parser(parser_build_context&)`. They share a `parser_build_context` that carries the PEG builder, inference inputs, the pre-built reasoning parser, and a pointer to the content analyzer.
|
||||
|
||||
#### Reasoning Parser (`analyze_reasoning::build_parser`)
|
||||
|
||||
| Mode | Parser |
|
||||
|-----------------------------------|---------------------------------------------------------------------|
|
||||
| Not extracting reasoning | `eps()` |
|
||||
| `FORCED_OPEN` or `FORCED_CLOSED` | `reasoning(until(end)) + end` — opening tag was in the prompt |
|
||||
| `TAG_BASED` or `TOOLS_ONLY` | `optional(start + reasoning(until(end)) + end)` |
|
||||
| `DELIMITER` | `optional(reasoning(until(end)) + end)` — no start marker |
|
||||
|
||||
#### Content Parser (`analyze_content::build_parser`)
|
||||
|
||||
| Condition | Parser |
|
||||
|----------------------------------------|---------------------------------------------------------------------------------|
|
||||
| `json_schema` present | `reasoning + space() + content(schema(json(), "response-format", ...)) + end()` |
|
||||
| Tools present | Dispatches to `analyze_tools::build_parser()` |
|
||||
| `ALWAYS_WRAPPED` with reasoning | `reasoning + start + content(until(end)) + end + end()` |
|
||||
| `ALWAYS_WRAPPED` without reasoning | `content(until(start)) + start + content(until(end)) + end + end()` |
|
||||
| Default (PLAIN) | `reasoning + content(rest()) + end()` |
|
||||
|
||||
#### Tool Parsers (`analyze_tools::build_parser`)
|
||||
|
||||
Dispatches by `format.mode`:
|
||||
|
||||
**`build_tool_parser_json_native()`**: Calls `p.standard_json_tools()` which internally dispatches to:
|
||||
|
||||
- `build_json_tools_function_is_key()` — function name is the JSON key: `{"get_weather": {...}}`
|
||||
- `build_json_tools_nested_keys()` — nested: `{"function": {"name": "X", "arguments": {...}}}`
|
||||
- `build_json_tools_flat_keys()` — flat: `{"name": "X", "arguments": {...}}`
|
||||
|
||||
Handles content wrappers, array wrapping (`tools_array_wrapped`), parallel calls, and `parameter_order`.
|
||||
|
||||
**`build_tool_parser_tag_json()`**: For each tool function:
|
||||
|
||||
```text
|
||||
tool_open(name_prefix + tool_name(literal(name)) + name_suffix) +
|
||||
call_id_section +
|
||||
tool_args(schema(json(), tool_schema))
|
||||
[+ function.close if non-empty]
|
||||
```
|
||||
|
||||
Wrapped in per-call markers (with optional parallel call repetition) then optionally in section markers.
|
||||
|
||||
**`build_tool_parser_tag_tagged()`**: For each tool function, builds one parser per argument:
|
||||
|
||||
- String types: `tool_arg_string_value(schema(until(value_suffix), ...))`
|
||||
- JSON types: `tool_arg_json_value(schema(json(), ...))`
|
||||
- Required args are plain; optional args wrapped in `optional()`
|
||||
- Arguments joined with `space()` between consecutive parsers
|
||||
|
||||
For closing: uses `function.close` if present; otherwise uses `peek(per_call_end)` to avoid premature close during partial streaming; falls back to `tool_close(space())` to trigger mapper callbacks.
|
||||
|
||||
All three tool parsers return:
|
||||
|
||||
```text
|
||||
reasoning + optional(content(until(trigger_marker))) + tool_calls + end()
|
||||
```
|
||||
|
||||
### Python Dict Format
|
||||
|
||||
When `format.uses_python_dicts` is true (detected when single-quoted strings appear in JSON argument context), `build_parser()` pre-registers a `json-string` rule that accepts both single-quoted and double-quoted strings. This is done before any `p.json()` call so all JSON parsing inherits the flexible rule.
|
||||
|
||||
## Mapper
|
||||
|
||||
`common_chat_peg_mapper` maps PEG parse results (AST nodes) into `common_chat_msg` structures. Key design:
|
||||
|
||||
- **Buffered arguments**: Before `tool_name` is known, argument text goes to `args_buffer`; once the name is set, the buffer is flushed to `current_tool->arguments`
|
||||
- **`args_target()`**: Returns a reference to whichever destination is currently active (buffer or tool args), eliminating branching
|
||||
- **`closing_quote_pending`**: Tracks whether a closing `"` needs to be appended when a string argument value is finalized (for schema-declared string types in tagged format)
|
||||
- **Quote normalization**: Python-style quotes (`'key': 'value'`) are converted to JSON (`"key": "value"`)
|
||||
- **Brace auto-closing**: At tool close, unclosed `{` braces are closed automatically
|
||||
|
||||
## Files
|
||||
|
||||
| File | Purpose |
|
||||
|-------------------------------------------|----------------------------------------------------------------------|
|
||||
| `common/chat-auto-parser.h` | All analysis structs, enums, `autoparser`, `peg_generator`, `templates_params` |
|
||||
| `common/chat-auto-parser-generator.cpp` | Parser generator: `generate_parser()` and `build_parser()` methods |
|
||||
| `common/chat-diff-analyzer.cpp` | Differential analysis implementation and workarounds |
|
||||
| `common/chat-auto-parser-helpers.h/cpp` | `calculate_diff_split()`, `segmentize_markers()`, |
|
||||
| | `compare_variants()`, string helpers |
|
||||
| `common/chat-peg-parser.h/cpp` | `common_chat_peg_builder`, `common_chat_peg_mapper`, and helpers |
|
||||
| `common/chat.cpp` | Entry point: `common_chat_templates_apply_jinja()` |
|
||||
| `tools/parser/debug-template-parser.cpp` | Debug tool for template analysis |
|
||||
| `tools/parser/template-analysis.cpp` | Template analysis tool |
|
||||
|
||||
## Testing & Debugging
|
||||
|
||||
### Debug Tools
|
||||
|
||||
**Template Debugger**: `tools/parser/debug-template-parser.cpp`
|
||||
|
||||
- Usage: `./bin/llama-debug-template-parser path/to/template.jinja`
|
||||
- Shows detected format, markers, generated parser, and GBNF grammar
|
||||
|
||||
**Template Analysis**: `tools/parser/template-analysis.cpp`
|
||||
|
||||
- Usage: `./bin/llama-template-analysis path/to/template.jinja`
|
||||
|
||||
**Debug Logging**: Enable with `LLAMA_LOG_VERBOSITY=2`
|
||||
|
||||
- Shows detailed analysis steps, pattern extraction results, and generated parser structure
|
||||
|
||||
**PEG Test Builder**: Fluent API for creating test cases — see [tests/test-chat.cpp:947-1043](tests/test-chat.cpp#L947-L1043). Example usage:
|
||||
|
||||
```cpp
|
||||
auto tst = peg_tester("models/templates/Template.jinja");
|
||||
tst.test("input text")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({tool_json})
|
||||
.parallel_tool_calls(true)
|
||||
.enable_thinking(true)
|
||||
.expect(expected_message)
|
||||
.run();
|
||||
```
|
||||
|
||||
### Tested Templates
|
||||
|
||||
The following templates have active tests in `tests/test-chat.cpp`:
|
||||
|
||||
| Template | Format | Notes |
|
||||
| -------- | ------ | ----- |
|
||||
| Ministral-3-14B-Reasoning | Reasoning | `[THINK]...[/THINK]` tags (specialized handler) |
|
||||
| NVIDIA-Nemotron-3-Nano-30B | TAG_WITH_TAGGED | Reasoning + tools |
|
||||
| CohereForAI Command-R7B | JSON_NATIVE | `<\|START_THINKING\|>`/`<\|START_RESPONSE\|>` markers |
|
||||
| Google Gemma 2 2B | Content only | No tool support |
|
||||
| Qwen-QwQ-32B | Reasoning | Forced-open thinking |
|
||||
| NousResearch Hermes 2 Pro | JSON_NATIVE | `<tool_call>` wrapper |
|
||||
| IBM Granite 3.3 | JSON_NATIVE | `<think></think>` + `<response></response>` |
|
||||
| ByteDance Seed-OSS | TAG_WITH_TAGGED | Custom `<seed:think>` and `<seed:tool_call>` tags |
|
||||
| Qwen3-Coder | TAG_WITH_TAGGED | XML-style tool format |
|
||||
| DeepSeek V3.1 | JSON_NATIVE | Forced thinking mode |
|
||||
| GLM-4.6 | TAG_WITH_TAGGED | `<tool_call>name\n<arg_key>...<arg_value>...` format |
|
||||
| GLM-4.7-Flash | TAG_WITH_TAGGED | Updated GLM format |
|
||||
| Kimi-K2-Thinking | JSON_NATIVE | Reasoning + JSON tools |
|
||||
| Apertus-8B-Instruct | JSON_NATIVE | Function name as JSON key |
|
||||
| MiniMax-M2 | TAG_WITH_JSON | XML invoke with JSON args |
|
||||
| NVIDIA-Nemotron-Nano-v2 | JSON_NATIVE | `<TOOLCALL>` wrapper (nested) |
|
||||
| CohereForAI Command-R Plus | JSON_NATIVE | Markdown code block format |
|
||||
| Mistral-Nemo-Instruct-2407 | JSON_NATIVE | `[TOOL_CALLS]` wrapper with ID field |
|
||||
| Functionary v3.1 | TAG_WITH_JSON | `<function=X>` format |
|
||||
| Functionary v3.2 | Specialized | `>>>` recipient delimiter (dedicated handler) |
|
||||
| Fireworks Firefunction v2 | TAG_WITH_JSON | Fireworks tool format |
|
||||
| DeepSeek R1 Distill (Llama/Qwen) | Reasoning | Forced-open thinking |
|
||||
| llama-cpp-deepseek-r1 | Reasoning | Forced-open thinking |
|
||||
| Kimi-K2 / Kimi-K2-Instruct | JSON_NATIVE | JSON tools with special markers |
|
||||
| Llama 3.1/3.2/3.3 | JSON_NATIVE | Standard Llama tool format |
|
||||
| OpenAI GPT-OSS | Specialized | Channel-based (dedicated handler) |
|
||||
| Apriel 1.5 | JSON_NATIVE | `<tool_calls>` wrapper with JSON array |
|
||||
| Apriel 1.6 Thinker | Reasoning | Implicit reasoning start |
|
||||
| Mistral Small 3.2 | JSON_NATIVE | `[TOOL_CALLS]func[ARGS]{...}` with call ID |
|
||||
| Devstral | JSON_NATIVE | `[TOOL_CALLS]func[ARGS]{...}` without call ID |
|
||||
| StepFun 3.5 Flash | TAG_WITH_TAGGED | `<function=X><parameter=Y>` format |
|
||||
|
||||
## Adding Support for New Templates
|
||||
|
||||
To support a new template format:
|
||||
|
||||
1. **If it follows standard patterns** — The auto-parser should detect it automatically. Run `llama-debug-template-parser` to verify markers are correctly extracted.
|
||||
2. **If differential analysis extracts incorrect markers** — Add a workaround lambda to the `workarounds` vector in `common/chat-diff-analyzer.cpp`. Inspect the template source for a unique identifying substring.
|
||||
3. **If it needs fundamentally different handling** — Add a dedicated handler function in `chat.cpp` before the auto-parser block (as done for GPT-OSS, Functionary v3.2, and Ministral).
|
||||
|
||||
## Edge Cases and Quirks
|
||||
|
||||
1. **Forced Thinking**: When `enable_thinking=true` and the model prompt ends with an open reasoning tag (e.g., `<think>`), the parser enters forced thinking mode and immediately expects reasoning content without waiting for a start marker.
|
||||
2. **Per-Call vs Per-Section Markers**: Some templates wrap each tool call individually (`per_call_start/end`); others wrap the entire section (`section_start/end`). T2 (`check_per_call_markers()`) disambiguates by checking if the second call in a two-call output starts with the section marker.
|
||||
3. **Python Dict Format**: The Seed template family uses single-quoted JSON (`'key': 'value'`). The `uses_python_dicts` flag causes the PEG builder to register a flexible `json-string` rule accepting both quote styles before any JSON rules are built.
|
||||
4. **Tag Boundary Fixing**: `calculate_diff_split()` iteratively adjusts prefix/suffix boundaries to avoid splitting `<tag>` or `[marker]` tokens, ensuring clean extraction.
|
||||
5. **Call ID Side Effects**: When a call ID is detected, `per_call_end` may have been incorrectly set to include the call ID suffix. T7 clears `per_call_end` in this case.
|
||||
6. **Tool Analysis Gating**: `analyze_tools` is only constructed (and all tool analysis phases run) when `jinja_caps.supports_tool_calls` is true. Within tool analysis, `check_per_call_markers()` (T2) only runs if `jinja_caps.supports_parallel_tool_calls`.
|
||||
7. **`analyze_arguments()` Gating**: Within tool analysis, A1 and A2 (argument name/value marker extraction) only run for `TAG_WITH_TAGGED` format. `extract_argument_separator()` and `extract_args_markers()` run for all non-`JSON_NATIVE` formats.
|
||||
|
|
@ -22,7 +22,7 @@ Below is a contrived example demonstrating how to use the PEG parser to parse
|
|||
output from a model that emits arguments as JSON.
|
||||
|
||||
```cpp
|
||||
auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
// Build a choice of all available tools
|
||||
auto tool_choice = p.choice();
|
||||
for (const auto & tool : tools) {
|
||||
|
|
@ -212,7 +212,7 @@ mapper.from_ast(ctx.ast, result);
|
|||
|
||||
### Native
|
||||
|
||||
The `common_chat_peg_native_builder` builds a `native` parser suitable for
|
||||
The `common_chat_peg_builder` builds a `native` parser suitable for
|
||||
models that emit tool arguments as a direct JSON object.
|
||||
|
||||
- **`reasoning(p)`** - Tag node for `reasoning_content`
|
||||
|
|
@ -225,7 +225,7 @@ models that emit tool arguments as a direct JSON object.
|
|||
- **`tool_args(p)`** - Tag the tool arguments
|
||||
|
||||
```cpp
|
||||
build_chat_peg_native_parser([&](common_chat_peg_native_parser & p) {
|
||||
build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto get_weather_tool = p.tool(p.sequence({
|
||||
p.tool_open(p.literal("{")),
|
||||
p.json_member("name", "\"" + p.tool_name(p.literal("get_weather")) + "\""),
|
||||
|
|
@ -246,7 +246,7 @@ build_chat_peg_native_parser([&](common_chat_peg_native_parser & p) {
|
|||
|
||||
### Constructed
|
||||
|
||||
The `common_chat_peg_constructed_builder` builds a `constructed` parser
|
||||
The `common_chat_peg_builder` builds a `constructed` parser
|
||||
suitable for models that emit tool arguments as separate entities, such as XML
|
||||
tags.
|
||||
|
||||
|
|
@ -264,7 +264,7 @@ tags.
|
|||
- **`tool_arg_json_value(p)`** - Tag JSON value for the argument
|
||||
|
||||
```cpp
|
||||
build_chat_peg_constructed_parser([&](common_chat_peg_constructed_builder & p) {
|
||||
build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto location_arg = p.tool_arg(
|
||||
p.tool_arg_open("<parameter name=\"" + p.tool_arg_name(p.literal("location")) + "\">"),
|
||||
p.tool_arg_string_value(p.until("</parameter>")),
|
||||
|
|
|
|||
|
|
@ -689,6 +689,11 @@ class SchemaConverter:
|
|||
elif (schema_type == 'object') or (len(schema) == 0):
|
||||
return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))
|
||||
|
||||
elif schema_type is None and isinstance(schema, dict):
|
||||
# No type constraint and no recognized structural keywords (e.g. {"description": "..."}).
|
||||
# Per JSON Schema semantics this is equivalent to {} and accepts any value.
|
||||
return self._add_rule(rule_name, self._add_primitive('value', PRIMITIVE_RULES['value']))
|
||||
|
||||
else:
|
||||
assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
|
||||
# TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
|
||||
|
|
|
|||
|
|
@ -339,8 +339,8 @@ static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t
|
|||
}
|
||||
|
||||
static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||
// TODO
|
||||
*free = 0;
|
||||
// no memory to report
|
||||
*free = 0;
|
||||
*total = 0;
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
|
|
|
|||
|
|
@ -2497,7 +2497,7 @@ class tinyBLAS_Q0_PPC {
|
|||
for (int r = 0; r < 8; r++) {
|
||||
const block_q4_0 * current_blk = rows_base[r] + blk;
|
||||
vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d));
|
||||
vector signed char v_qs = reinterpret_cast<vector signed char>(vec_xl(0, current_blk->qs));
|
||||
vector signed char v_qs = vec_xl(0, (const vector signed char *)current_blk->qs);
|
||||
vector signed char c1, c2;
|
||||
unpack_q4_to_q8(v_qs, c1, c2);
|
||||
convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]);
|
||||
|
|
@ -2611,14 +2611,14 @@ class tinyBLAS_Q0_PPC {
|
|||
i = (cols >> 2);
|
||||
if (i > 0) {
|
||||
do {
|
||||
c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
|
||||
c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
|
||||
c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
|
||||
c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
|
||||
c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset5->qs));
|
||||
c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset6->qs));
|
||||
c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
|
||||
c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));
|
||||
c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
|
||||
c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
|
||||
c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
|
||||
c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs);
|
||||
c5[1] = vec_xl(0, (const vector signed char *)aoffset5->qs);
|
||||
c6[1] = vec_xl(0, (const vector signed char *)aoffset6->qs);
|
||||
c7[1] = vec_xl(0, (const vector signed char *)aoffset7->qs);
|
||||
c8[1] = vec_xl(0, (const vector signed char *)aoffset8->qs);
|
||||
|
||||
process_q4_elements(c1, & comparray[0]);
|
||||
process_q4_elements(c2, & comparray[1]);
|
||||
|
|
@ -2657,10 +2657,10 @@ class tinyBLAS_Q0_PPC {
|
|||
i = (cols >> 2);
|
||||
if (i > 0) {
|
||||
do {
|
||||
c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
|
||||
c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
|
||||
c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
|
||||
c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
|
||||
c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
|
||||
c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
|
||||
c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
|
||||
c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs);
|
||||
|
||||
process_q4_elements(c1, & comparray[0]);
|
||||
process_q4_elements(c2, & comparray[1]);
|
||||
|
|
@ -2686,9 +2686,9 @@ class tinyBLAS_Q0_PPC {
|
|||
if (i > 0) {
|
||||
do {
|
||||
switch(rows) {
|
||||
case 3: c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
|
||||
case 2: c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
|
||||
case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
|
||||
case 3: c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
|
||||
case 2: c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
|
||||
case 1: c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
|
||||
break;
|
||||
}
|
||||
process_q4_elements(c1, & comparray[0]);
|
||||
|
|
|
|||
|
|
@ -2129,12 +2129,12 @@ static void ggml_compute_forward_gelu_f32(
|
|||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
||||
const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
|
||||
GGML_UNUSED(x);
|
||||
assert(!isnan(x));
|
||||
assert(!isinf(x));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2176,13 +2176,13 @@ static void ggml_compute_forward_gelu_f16(
|
|||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
||||
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
|
||||
const float v = GGML_CPU_FP16_TO_FP32(x);
|
||||
GGML_UNUSED(v);
|
||||
assert(!isnan(v));
|
||||
assert(!isinf(v));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2325,12 +2325,12 @@ static void ggml_compute_forward_gelu_erf_f32(
|
|||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
||||
const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
|
||||
GGML_UNUSED(x);
|
||||
assert(!isnan(x));
|
||||
assert(!isinf(x));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2372,13 +2372,13 @@ static void ggml_compute_forward_gelu_erf_f16(
|
|||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
||||
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
|
||||
const float v = GGML_CPU_FP16_TO_FP32(x);
|
||||
GGML_UNUSED(v);
|
||||
assert(!isnan(v));
|
||||
assert(!isinf(v));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2444,12 +2444,12 @@ static void ggml_compute_forward_gelu_quick_f32(
|
|||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
||||
const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
|
||||
GGML_UNUSED(x);
|
||||
assert(!isnan(x));
|
||||
assert(!isinf(x));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2491,13 +2491,13 @@ static void ggml_compute_forward_gelu_quick_f16(
|
|||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
||||
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
|
||||
const float v = GGML_CPU_FP16_TO_FP32(x);
|
||||
GGML_UNUSED(v);
|
||||
assert(!isnan(v));
|
||||
assert(!isinf(v));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2563,12 +2563,12 @@ static void ggml_compute_forward_silu_f32(
|
|||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
|
||||
const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
|
||||
GGML_UNUSED(x);
|
||||
assert(!isnan(x));
|
||||
assert(!isinf(x));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2610,13 +2610,13 @@ static void ggml_compute_forward_silu_f16(
|
|||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
|
||||
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
|
||||
const float v = GGML_CPU_FP16_TO_FP32(x);
|
||||
GGML_UNUSED(v);
|
||||
assert(!isnan(v));
|
||||
assert(!isinf(v));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2766,7 +2766,7 @@ static void ggml_compute_forward_silu_back_f32(
|
|||
assert(!isnan(x));
|
||||
assert(!isinf(x));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2802,7 +2802,7 @@ static void ggml_compute_forward_silu_back_f16(
|
|||
(ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
|
||||
(ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
|
||||
|
||||
#ifndef NDEBUG
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
||||
const float v = GGML_CPU_FP16_TO_FP32(x);
|
||||
|
|
@ -2810,7 +2810,7 @@ static void ggml_compute_forward_silu_back_f16(
|
|||
assert(!isnan(v));
|
||||
assert(!isinf(v));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2893,7 +2893,7 @@ static void ggml_compute_forward_reglu_f32(
|
|||
assert(!isnan(x));
|
||||
assert(!isinf(x));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2953,7 +2953,7 @@ static void ggml_compute_forward_reglu_f16(
|
|||
assert(!isnan(v));
|
||||
assert(!isinf(v));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -3036,7 +3036,7 @@ static void ggml_compute_forward_geglu_f32(
|
|||
assert(!isnan(x));
|
||||
assert(!isinf(x));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -3096,7 +3096,7 @@ static void ggml_compute_forward_geglu_f16(
|
|||
assert(!isnan(v));
|
||||
assert(!isinf(v));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -3179,7 +3179,7 @@ static void ggml_compute_forward_swiglu_f32(
|
|||
assert(!isnan(x));
|
||||
assert(!isinf(x));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -3239,7 +3239,7 @@ static void ggml_compute_forward_swiglu_f16(
|
|||
assert(!isnan(v));
|
||||
assert(!isinf(v));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -3330,7 +3330,7 @@ static void ggml_compute_forward_swiglu_oai_f32(
|
|||
assert(!isnan(x));
|
||||
assert(!isinf(x));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -3409,7 +3409,7 @@ static void ggml_compute_forward_geglu_erf_f32(
|
|||
assert(!isnan(x));
|
||||
assert(!isinf(x));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -3469,7 +3469,7 @@ static void ggml_compute_forward_geglu_erf_f16(
|
|||
assert(!isnan(v));
|
||||
assert(!isinf(v));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -3552,7 +3552,7 @@ static void ggml_compute_forward_geglu_quick_f32(
|
|||
assert(!isnan(x));
|
||||
assert(!isinf(x));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -3612,7 +3612,7 @@ static void ggml_compute_forward_geglu_quick_f16(
|
|||
assert(!isnan(v));
|
||||
assert(!isinf(v));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -5303,7 +5303,7 @@ static void ggml_compute_forward_soft_max_f32(
|
|||
//printf("p[%d] = %f\n", i, p[i]);
|
||||
assert(!isnan(wp[i]));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
|
||||
float max = -INFINITY;
|
||||
ggml_vec_max_f32(ne00, &max, wp);
|
||||
|
|
@ -5328,7 +5328,7 @@ static void ggml_compute_forward_soft_max_f32(
|
|||
assert(!isnan(dp[i]));
|
||||
assert(!isinf(dp[i]));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -5402,7 +5402,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32(
|
|||
assert(!isnan(dy[i]));
|
||||
assert(!isnan(y[i]));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
// Jii = yi - yi*yi
|
||||
// Jij = -yi*yj
|
||||
// J = diag(y)-y.T*y
|
||||
|
|
@ -5435,7 +5435,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32(
|
|||
assert(!isnan(dx[i]));
|
||||
assert(!isinf(dx[i]));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -5803,28 +5803,33 @@ static void ggml_compute_forward_rope_flt(
|
|||
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
|
||||
int64_t last_i2 = -1;
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
|
||||
for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
||||
|
||||
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
||||
if (!mrope_used) {
|
||||
const int64_t p = pos[i2];
|
||||
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
else {
|
||||
const int64_t p_t = pos[i2];
|
||||
const int64_t p_h = pos[i2 + ne2];
|
||||
const int64_t p_w = pos[i2 + ne2 * 2];
|
||||
const int64_t p_e = pos[i2 + ne2 * 3];
|
||||
ggml_mrope_cache_init(
|
||||
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
|
||||
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
|
||||
if (ir++ < ir0) continue;
|
||||
if (ir++ < ir0) continue; // skip rows mapped to other threads
|
||||
if (ir > ir1) break;
|
||||
|
||||
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
||||
if (last_i2 != i2) {
|
||||
if (!mrope_used) {
|
||||
const int64_t p = pos[i2];
|
||||
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
else {
|
||||
const int64_t p_t = pos[i2];
|
||||
const int64_t p_h = pos[i2 + ne2];
|
||||
const int64_t p_w = pos[i2 + ne2 * 2];
|
||||
const int64_t p_e = pos[i2 + ne2 * 3];
|
||||
ggml_mrope_cache_init(
|
||||
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
|
||||
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
|
||||
last_i2 = i2;
|
||||
}
|
||||
|
||||
T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
|
||||
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
|
||||
|
||||
|
|
@ -10700,7 +10705,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
|
|||
assert(!isnan(s0[i]));
|
||||
assert(!isnan(s1[i]));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
|
||||
float max = -INFINITY;
|
||||
ggml_vec_max_f32(nc, &max, s0);
|
||||
|
|
@ -10719,7 +10724,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
|
|||
assert(!isnan(st[i]));
|
||||
assert(!isinf(st[i]));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
sums[ith] = sum_thread;
|
||||
ggml_barrier(params->threadpool);
|
||||
|
|
@ -10792,7 +10797,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
|
|||
assert(!isnan(s0[i]));
|
||||
assert(!isnan(s1[i]));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
|
||||
// soft_max
|
||||
float max = -INFINITY;
|
||||
|
|
@ -10810,7 +10815,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
|
|||
assert(!isnan(ds0[i]));
|
||||
assert(!isinf(ds0[i]));
|
||||
}
|
||||
#endif
|
||||
#endif // NDEBUG
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3348,6 +3348,46 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
|
|||
return true;
|
||||
}
|
||||
|
||||
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_UNARY
|
||||
&& unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
|
||||
const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];
|
||||
const ggml_tensor * silu = cgraph->nodes[node_idx+1];
|
||||
|
||||
if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL
|
||||
&& unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) {
|
||||
const ggml_tensor * unary = cgraph->nodes[node_idx];
|
||||
const ggml_tensor * mul = cgraph->nodes[node_idx+1];
|
||||
|
||||
if (ggml_get_unary_op(unary) != unary_ops.begin()[0]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (unary->type != mul->type) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ggml_tensor * other = (mul->src[0] == unary) ? mul->src[1] : mul->src[0];
|
||||
if (other->type != unary->type) {
|
||||
return false;
|
||||
}
|
||||
if (!ggml_is_contiguous_1(other) || !ggml_is_contiguous_1(unary->src[0]) || !ggml_are_same_shape(other, unary)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
|
||||
&& unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
|
||||
const ggml_tensor *scale = cgraph->nodes[node_idx];
|
||||
|
|
@ -3372,6 +3412,69 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
|
|||
return false;
|
||||
}
|
||||
|
||||
// returns whether the write (out) nodes overwrite the read nodes in operation
|
||||
static bool ggml_cuda_check_fusion_memory_ranges(ggml_cgraph * cgraph,
|
||||
int node_idx,
|
||||
int node_count,
|
||||
int * out_nodes,
|
||||
int out_count) {
|
||||
auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) {
|
||||
const int64_t a_start = (int64_t) a->data;
|
||||
const int64_t a_end = a_start + ggml_nbytes(a);
|
||||
|
||||
const int64_t b_start = (int64_t) b->data;
|
||||
const int64_t b_end = b_start + ggml_nbytes(b);
|
||||
|
||||
if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
bool is_ok = true;
|
||||
// for nrows=1, all fusion operations correctly read the src before writing dst or do it elementwise, so we should be ok
|
||||
if (ggml_nrows(cgraph->nodes[node_idx]) == 1) {
|
||||
return true;
|
||||
}
|
||||
|
||||
for (int i = 0; i < out_count; ++i) {
|
||||
const ggml_tensor * dst = cgraph->nodes[out_nodes[i]];
|
||||
|
||||
for (int j = node_idx; j < node_idx + node_count; ++j) {
|
||||
// Loop over all srcs of all nodes in the fusion. If the src overlaps
|
||||
// the destination and the src is not an intermediate node that's being
|
||||
// elided, then disable fusion.
|
||||
|
||||
for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
|
||||
const ggml_tensor * src = cgraph->nodes[j]->src[src_idx];
|
||||
|
||||
if (!src || src->op == GGML_OP_NONE) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (nodes_overlap(dst, src)) {
|
||||
bool found = false;
|
||||
|
||||
for (int k = node_idx; k < j; ++k) {
|
||||
if (cgraph->nodes[k] == src) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!found) {
|
||||
is_ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return is_ok;
|
||||
}
|
||||
|
||||
static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
|
||||
bool graph_evaluated_or_captured = false;
|
||||
|
||||
|
|
@ -3568,7 +3671,8 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
|||
out_nodes[1] = i + ops.size() - 1;
|
||||
|
||||
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
|
||||
ggml_cuda_should_use_topk_moe(node, logits, weights, ids)) {
|
||||
ggml_cuda_should_use_topk_moe(node, logits, weights, ids) &&
|
||||
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
|
||||
i += ops.size() - 1;
|
||||
continue;
|
||||
|
|
@ -3583,7 +3687,8 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
|||
|
||||
int out_nodes[2] = { i + 1, i + 5 };
|
||||
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
|
||||
ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids)) {
|
||||
ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) &&
|
||||
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
|
||||
i += ops.size() - 1;
|
||||
continue;
|
||||
|
|
@ -3836,6 +3941,20 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
|||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
|
||||
ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i+1]);
|
||||
i++;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) ||
|
||||
ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) ||
|
||||
ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) {
|
||||
ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i+1]);
|
||||
i++;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
|
||||
i += 2;
|
||||
ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#include "ssm-conv.cuh"
|
||||
#include "unary.cuh"
|
||||
|
||||
template <size_t split_d_inner, size_t d_conv>
|
||||
template <bool apply_silu, size_t split_d_inner, size_t d_conv>
|
||||
static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
|
||||
const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
|
||||
float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
|
||||
|
|
@ -41,11 +42,11 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
|
|||
for (size_t j = 0; j < d_conv; j++) {
|
||||
sumf += x[(i + j) % d_conv] * w[j];
|
||||
}
|
||||
y_block[i * stride_y + tid] = sumf;
|
||||
y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
|
||||
}
|
||||
}
|
||||
|
||||
template <size_t split_d_inner, size_t d_conv, int64_t split_n_t>
|
||||
template <bool apply_silu, size_t split_d_inner, size_t d_conv, int64_t split_n_t>
|
||||
static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
|
||||
const int src0_nb0, const int src0_nb1, const int src0_nb2,
|
||||
const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
|
||||
|
|
@ -65,36 +66,46 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
|
|||
const int stride_w = src1_nb1 / sizeof(float);
|
||||
const int stride_y = dst_nb1 / sizeof(float);
|
||||
|
||||
float x[d_conv] = { 0.0f };
|
||||
float w[d_conv] = { 0.0f };
|
||||
const int64_t local_n_t = min(split_n_t, n_t - bidz * split_n_t);
|
||||
const int n_cols = d_conv - 1 + split_n_t;
|
||||
|
||||
extern __shared__ float smem[];
|
||||
|
||||
constexpr int load_cols = d_conv - 1 + split_n_t;
|
||||
constexpr int total_elems = split_d_inner * load_cols;
|
||||
int row = tid / load_cols;
|
||||
int col = tid % load_cols;
|
||||
#pragma unroll
|
||||
for (int idx = tid; idx < total_elems; idx += split_d_inner) {
|
||||
if (row < (int)split_d_inner) {
|
||||
smem[row * n_cols + col] = x_block[row * stride_x + col];
|
||||
}
|
||||
|
||||
col += split_d_inner;
|
||||
row += col / load_cols;
|
||||
col = col % load_cols;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Load weights into registers (done once, small)
|
||||
float w[d_conv] = { 0.0f };
|
||||
#pragma unroll
|
||||
for (size_t j = 0; j < d_conv; j++) {
|
||||
w[j] = w_block[tid * stride_w + j];
|
||||
}
|
||||
|
||||
// Compute from shared memory
|
||||
for (int64_t i = 0; i < local_n_t; i++) {
|
||||
float sumf = 0.0f;
|
||||
#pragma unroll
|
||||
for (int64_t i = 0; i < split_n_t; i++) {
|
||||
if (bidz * split_n_t + i < n_t) {
|
||||
float sumf = 0.0f;
|
||||
|
||||
if (i == 0) {
|
||||
for (size_t j = 0; j < d_conv; j++) {
|
||||
x[j] = x_block[tid * stride_x + j];
|
||||
}
|
||||
} else {
|
||||
x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (size_t j = 0; j < d_conv; j++) {
|
||||
sumf += x[(i + j) % d_conv] * w[j];
|
||||
}
|
||||
y_block[i * stride_y + tid] = sumf;
|
||||
for (size_t j = 0; j < d_conv; j++) {
|
||||
sumf += smem[tid * n_cols + i + j] * w[j];
|
||||
}
|
||||
y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
|
||||
}
|
||||
}
|
||||
|
||||
template <bool apply_silu>
|
||||
static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
|
||||
const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
|
||||
const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
|
||||
|
|
@ -106,12 +117,13 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
|
|||
constexpr int kNC = decltype(NC)::value;
|
||||
if (n_t <= 32) {
|
||||
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
|
||||
ssm_conv_f32<threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
|
||||
ssm_conv_f32<apply_silu, threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
|
||||
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
|
||||
} else {
|
||||
const int64_t split_n_t = 32;
|
||||
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
|
||||
ssm_conv_long_token_f32<threads, kNC, split_n_t><<<blocks, threads, 0, stream>>>(
|
||||
const size_t smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float);
|
||||
ssm_conv_long_token_f32<apply_silu, threads, kNC, split_n_t><<<blocks, threads, smem_size, stream>>>(
|
||||
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
|
||||
}
|
||||
};
|
||||
|
|
@ -124,27 +136,36 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) {
|
||||
const struct ggml_tensor * src0 = dst->src[0]; // conv_x
|
||||
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
|
||||
const bool fuse_silu = silu_dst != nullptr;
|
||||
|
||||
// When fusing, write to silu_dst (the node downstream references).
|
||||
const struct ggml_tensor * out = fuse_silu ? silu_dst : dst;
|
||||
|
||||
const int64_t nc = src1->ne[0]; // d_conv
|
||||
const int64_t nr = src0->ne[1]; // d_inner
|
||||
const int64_t n_t = dst->ne[1]; // tokens per sequence
|
||||
const int64_t n_s = dst->ne[2]; // number of sequences in the batch
|
||||
const int64_t n_t = out->ne[1]; // tokens per sequence
|
||||
const int64_t n_s = out->ne[2]; // number of sequences in the batch
|
||||
|
||||
GGML_ASSERT(dst->ne[0] == nr);
|
||||
GGML_ASSERT(out->ne[0] == nr);
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
const float * src1_d = (const float *) src1->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
float * dst_d = (float *) out->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
|
||||
dst->nb[2], nc, nr, n_t, n_s, stream);
|
||||
GGML_ASSERT(out->type == GGML_TYPE_F32);
|
||||
if (fuse_silu) {
|
||||
ssm_conv_f32_cuda<true>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
|
||||
out->nb[2], nc, nr, n_t, n_s, stream);
|
||||
} else {
|
||||
ssm_conv_f32_cuda<false>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
|
||||
out->nb[2], nc, nr, n_t, n_s, stream);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,3 @@
|
|||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst = nullptr);
|
||||
|
|
|
|||
|
|
@ -119,6 +119,18 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
|||
}
|
||||
}
|
||||
|
||||
// Sanitize NaN to -FLT_MAX so the iterative argmax produces unique expert IDs.
|
||||
// NaN comparisons always return false, which would cause the same expert to be
|
||||
// selected repeatedly. -FLT_MAX compares normally and is still excluded by the
|
||||
// -INFINITY sentinel used after each selection round.
|
||||
// More relevant for the cuBLAS path. See https://github.com/ggml-org/llama.cpp/issues/19659
|
||||
#pragma unroll
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
if (__isnanf(wt[i])) {
|
||||
wt[i] = -FLT_MAX;
|
||||
}
|
||||
}
|
||||
|
||||
// selection_wt is only needed when bias is present (selection uses wt + bias)
|
||||
// when no bias, we use wt directly for both selection and weight values
|
||||
float selection_wt[has_bias ? experts_per_thread : 1];
|
||||
|
|
|
|||
|
|
@ -560,3 +560,58 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
|||
leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream);
|
||||
}
|
||||
}
|
||||
|
||||
/* fused unary + mul */
|
||||
|
||||
template <float (*op)(float)>
|
||||
static void ggml_cuda_op_unary_mul_impl(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) {
|
||||
// unary_node: UNARY op applied to unary_node->src[0]
|
||||
// mul_node: MUL(a, b) where one of a/b is unary_node
|
||||
// Output goes to mul_node->data
|
||||
|
||||
const ggml_tensor * unary_src = unary_node->src[0]; // input to the unary op
|
||||
const ggml_tensor * other_src = (mul_node->src[0] == unary_node) ? mul_node->src[1] : mul_node->src[0];
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_1(unary_src));
|
||||
GGML_ASSERT(unary_src->nb[0] == ggml_element_size(unary_src));
|
||||
GGML_ASSERT(ggml_is_contiguous_1(other_src));
|
||||
GGML_ASSERT(other_src->nb[0] == ggml_element_size(other_src));
|
||||
GGML_ASSERT(ggml_are_same_shape(unary_src, other_src));
|
||||
|
||||
GGML_ASSERT(unary_src->type == GGML_TYPE_F32 || unary_src->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(unary_src->type == other_src->type);
|
||||
GGML_ASSERT(unary_src->type == mul_node->type);
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
const int64_t k = ggml_nelements(mul_node);
|
||||
const int64_t nc = unary_src->ne[0];
|
||||
const int64_t unary_stride = unary_src->nb[1];
|
||||
const int64_t other_stride = other_src->nb[1];
|
||||
|
||||
if (unary_src->type == GGML_TYPE_F16) {
|
||||
unary_gated_cuda<op>((const half *) unary_src->data, (const half *) other_src->data,
|
||||
(half *) mul_node->data, k, nc,
|
||||
unary_stride / sizeof(half), other_stride / sizeof(half), stream);
|
||||
} else {
|
||||
unary_gated_cuda<op>((const float *) unary_src->data, (const float *) other_src->data,
|
||||
(float *) mul_node->data, k, nc,
|
||||
unary_stride / sizeof(float), other_stride / sizeof(float), stream);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) {
|
||||
switch (ggml_get_unary_op(unary_node)) {
|
||||
case GGML_UNARY_OP_SILU:
|
||||
ggml_cuda_op_unary_mul_impl<op_silu>(ctx, unary_node, mul_node);
|
||||
break;
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
ggml_cuda_op_unary_mul_impl<op_sigmoid>(ctx, unary_node, mul_node);
|
||||
break;
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
ggml_cuda_op_unary_mul_impl<op_softplus>(ctx, unary_node, mul_node);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported unary op for fused unary+mul");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -89,6 +89,8 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||
|
||||
void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node);
|
||||
|
||||
__device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {
|
||||
return x / (1.0f + expf(-x));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2152,6 +2152,44 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
|
|||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * src1 = op->src[1];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
// Only support FP32 for now
|
||||
if (src0->type != GGML_TYPE_F32 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check IO tensor shapes and dims
|
||||
if (src0->ne[3] != 1 || src1->ne[2] != 1 || src1->ne[3] != 1 || dst->ne[3] != 1) {
|
||||
return false; // src0 should be effectively 3D
|
||||
}
|
||||
|
||||
const int d_conv = src1->ne[0];
|
||||
const int d_inner = src0->ne[1];
|
||||
const int n_t = dst->ne[1];
|
||||
const int n_s = dst->ne[2];
|
||||
|
||||
if (src0->ne[0] != d_conv - 1 + n_t || src0->ne[1] != d_inner || src0->ne[2] != n_s) {
|
||||
return false;
|
||||
}
|
||||
if (src1->ne[0] != d_conv || src1->ne[1] != d_inner) {
|
||||
return false;
|
||||
}
|
||||
if (dst->ne[0] != d_inner || dst->ne[1] != n_t || dst->ne[2] != n_s) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: add support for non-contiguous tensors
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
enum dspqbuf_type {
|
||||
DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0,
|
||||
DSPQBUF_TYPE_CPU_WRITE_DSP_READ,
|
||||
|
|
@ -2468,6 +2506,17 @@ static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buf
|
|||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_ssm_conv_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
req->op = HTP_OP_SSM_CONV;
|
||||
|
||||
size_t n_bufs = 0;
|
||||
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CONSTANT);
|
||||
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
|
||||
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
|
||||
auto sess = static_cast<ggml_hexagon_session *>(backend->context);
|
||||
return sess->name.c_str();
|
||||
|
|
@ -2606,6 +2655,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
|||
ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
case GGML_OP_SSM_CONV:
|
||||
ggml_hexagon_dispatch_op<init_ssm_conv_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
default:
|
||||
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
|
||||
}
|
||||
|
|
@ -3024,6 +3077,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
|||
supp = ggml_hexagon_supported_argsort(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_SSM_CONV:
|
||||
supp = ggml_hexagon_supported_ssm_conv(sess, op);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ add_library(${HTP_LIB} SHARED
|
|||
get-rows-ops.c
|
||||
cpy-ops.c
|
||||
argsort-ops.c
|
||||
ssm-conv.c
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
|
|
|
|||
|
|
@ -68,6 +68,7 @@ enum htp_op {
|
|||
HTP_OP_SQR,
|
||||
HTP_OP_SQRT,
|
||||
HTP_OP_SUM_ROWS,
|
||||
HTP_OP_SSM_CONV,
|
||||
INVALID
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -41,9 +41,6 @@ struct htp_ops_context {
|
|||
worker_pool_context_t * wpool; // worker pool
|
||||
uint32_t n_threads; // num threads
|
||||
|
||||
uint32_t src0_nrows_per_thread;
|
||||
uint32_t src1_nrows_per_thread;
|
||||
|
||||
uint32_t flags;
|
||||
};
|
||||
|
||||
|
|
@ -61,5 +58,6 @@ int op_set_rows(struct htp_ops_context * octx);
|
|||
int op_get_rows(struct htp_ops_context * octx);
|
||||
int op_cpy(struct htp_ops_context * octx);
|
||||
int op_argsort(struct htp_ops_context * octx);
|
||||
int op_ssm_conv(struct htp_ops_context * octx);
|
||||
|
||||
#endif /* HTP_OPS_H */
|
||||
|
|
|
|||
|
|
@ -15,4 +15,12 @@
|
|||
#include "hvx-div.h"
|
||||
#include "hvx-base.h"
|
||||
|
||||
#ifndef GATHER_TYPE
|
||||
# if defined(__hexagon__)
|
||||
# define GATHER_TYPE(_a) (intptr_t) _a
|
||||
# else
|
||||
# define GATHER_TYPE(_a) (HVX_Vector *) _a
|
||||
# endif
|
||||
#endif
|
||||
|
||||
#endif /* HVX_UTILS_H */
|
||||
|
|
|
|||
|
|
@ -757,6 +757,47 @@ static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req *
|
|||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_ssm_conv_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
|
||||
// We've written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[0].fd = bufs[2].fd;
|
||||
rsp_bufs[0].ptr = bufs[2].ptr;
|
||||
rsp_bufs[0].offset = bufs[2].offset;
|
||||
rsp_bufs[0].size = bufs[2].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup OP context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.src1 = req->src1;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.src1.data = (uint32_t) bufs[1].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[2].ptr;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_ssm_conv(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_activations_req(struct htp_context * ctx,
|
||||
struct htp_general_req * req,
|
||||
struct dspqueue_buffer * bufs,
|
||||
|
|
@ -1142,6 +1183,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
|||
proc_argsort_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_SSM_CONV:
|
||||
if (n_bufs != 3) {
|
||||
FARF(ERROR, "Bad ssm-conv-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_ssm_conv_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unknown Op %u", req.op);
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,339 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <HAP_ps.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <math.h>
|
||||
#include <qurt_thread.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "hex-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define htp_ssm_conv_tensors_preamble \
|
||||
struct htp_tensor * restrict src0 = &octx->src0; \
|
||||
struct htp_tensor * restrict src1 = &octx->src1; \
|
||||
struct htp_tensor * restrict dst = &octx->dst; \
|
||||
struct htp_spad * restrict src0_spad = &octx->src0_spad; \
|
||||
struct htp_spad * restrict src1_spad = &octx->src1_spad; \
|
||||
struct htp_spad * restrict dst_spad = &octx->dst_spad; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = src1->ne[0]; \
|
||||
const uint32_t ne11 = src1->ne[1]; \
|
||||
const uint32_t ne12 = src1->ne[2]; \
|
||||
const uint32_t ne13 = src1->ne[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = src1->nb[0]; \
|
||||
const uint32_t nb11 = src1->nb[1]; \
|
||||
const uint32_t nb12 = src1->nb[2]; \
|
||||
const uint32_t nb13 = src1->nb[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
struct htp_ssm_conv_context {
|
||||
struct htp_ops_context * octx;
|
||||
uint32_t nrows_per_thread;
|
||||
uint64_t t_start;
|
||||
};
|
||||
|
||||
#define htp_ssm_conv_preamble \
|
||||
struct htp_ssm_conv_context * scctx = (struct htp_ssm_conv_context *) data; \
|
||||
struct htp_ops_context * octx = scctx->octx; \
|
||||
htp_ssm_conv_tensors_preamble; \
|
||||
dma_queue * dma_queue = octx->ctx->dma[ith];
|
||||
|
||||
// Scalar FP32 SSM_CONV implementation
|
||||
static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
|
||||
htp_ssm_conv_preamble;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const uint32_t d_conv = src1->ne[0];
|
||||
const uint32_t d_inner = src0->ne[1];
|
||||
const uint32_t n_t = dst->ne[1];
|
||||
const uint32_t n_s = dst->ne[2];
|
||||
|
||||
const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float); // stride for inner dimension
|
||||
const uint32_t src0_stride_seq = src0->nb[2] / sizeof(float); // stride for sequence dimension
|
||||
const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float); // stride for inner dimension
|
||||
const uint32_t dst_stride_token = dst->nb[1] / sizeof(float); // stride for token dimension
|
||||
const uint32_t dst_stride_seq = dst->nb[2] / sizeof(float); // stride for sequence dimension
|
||||
|
||||
const float * src0_data = (const float *) src0->data;
|
||||
const float * src1_data = (const float *) src1->data;
|
||||
float * dst_data = (float *) dst->data;
|
||||
|
||||
// Calculate row range for this thread
|
||||
const uint32_t d_inner_per_thread = scctx->nrows_per_thread;
|
||||
const uint32_t d_inner_start = d_inner_per_thread * ith;
|
||||
const uint32_t d_inner_end = MIN(d_inner_start + d_inner_per_thread, d_inner);
|
||||
|
||||
// No work for this thread
|
||||
if (d_inner_start >= d_inner_end) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (uint32_t i3 = 0; i3 < n_s; ++i3) {
|
||||
for (uint32_t i2 = 0; i2 < n_t; ++i2) {
|
||||
for (uint32_t i1 = d_inner_start; i1 < d_inner_end; ++i1) {
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
|
||||
const uint32_t src0_idx = (i2 + i0) + i1 * src0_stride_inner + i3 * src0_stride_seq;
|
||||
const uint32_t src1_idx = i0 + i1 * src1_stride_inner;
|
||||
|
||||
sumf += src0_data[src0_idx] * src1_data[src1_idx];
|
||||
}
|
||||
|
||||
const uint32_t dst_idx = i1 + i2 * dst_stride_token + i3 * dst_stride_seq;
|
||||
dst_data[dst_idx] = sumf;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "ssm-conv-f32 %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n",
|
||||
ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], d_inner_start, d_inner_end,
|
||||
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1],
|
||||
dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
// HVX FP32 SSM_CONV implementation - vectorizes across d_inner dimension
|
||||
static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) {
|
||||
htp_ssm_conv_preamble;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const int nc = src1->ne[0]; // d_conv
|
||||
const int ncs = src0->ne[0]; // d_conv - 1 + n_t
|
||||
|
||||
const uint32_t d_conv = src1->ne[0];
|
||||
const uint32_t d_inner = src0->ne[1];
|
||||
const uint32_t n_t = dst->ne[1];
|
||||
const uint32_t n_s = dst->ne[2];
|
||||
|
||||
const float * src0_data = (const float *) src0->data;
|
||||
const float * src1_data = (const float *) src1->data;
|
||||
float * dst_data = (float *) dst->data;
|
||||
|
||||
// Calculate row range for this thread
|
||||
const int dr = scctx->nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = MIN(ir0 + dr, d_inner);
|
||||
const int ir = ir1 - ir0;
|
||||
|
||||
if (ir0 >= ir1) {
|
||||
return; // No work for this thread
|
||||
}
|
||||
|
||||
// src0 and src1 gather offsets
|
||||
uint32_t __attribute__((aligned(VLEN))) src0_offsets[VLEN_FP32] = { 0 };
|
||||
uint32_t __attribute__((aligned(VLEN))) src1_offsets[VLEN_FP32] = { 0 };
|
||||
|
||||
for (uint32_t i = 0; i < VLEN_FP32; ++i) {
|
||||
src0_offsets[i] = i * (ncs) * sizeof(float);
|
||||
src1_offsets[i] = i * (d_conv) * sizeof(float);
|
||||
}
|
||||
|
||||
const uint32_t src0_gather_len = VLEN * ncs;
|
||||
const uint32_t src1_gather_len = VLEN * d_conv;
|
||||
|
||||
// gather scratchpads
|
||||
HVX_Vector * src0_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + 0);
|
||||
HVX_Vector * src1_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + VLEN);
|
||||
|
||||
float * data_src0 = (float *) ((char *) src0->data + ir0 * src0->nb[1]);
|
||||
float * data_src1 = (float *) ((char *) src1->data + ir0 * src1->nb[1]);
|
||||
|
||||
uint8_t * spad_src0 = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread;
|
||||
uint8_t * spad_src1 = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread;
|
||||
|
||||
// copy src1 workload to VTCM
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src1, data_src1), nb11, nb11, ir);
|
||||
|
||||
// FARF(HIGH, "ssm-conv-src1-fetch %d: ir0 %u size %u\n", ith, ir0, nb11 * ir);
|
||||
|
||||
for (uint32_t i3 = 0; i3 < n_s; ++i3) {
|
||||
float * src0_data_ptr = (float *) ((char *) data_src0 + i3 * (src0->nb[2]));
|
||||
|
||||
// copy src0 workload to VTCM
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0, src0_data_ptr), nb01, nb01, ir);
|
||||
|
||||
// FARF(HIGH, "ssm-conv-src0-fetch %d: ir0 %u i3 %u size %u\n", ith, ir0, i3, nb01 * ir);
|
||||
|
||||
dma_queue_flush(dma_queue);
|
||||
|
||||
for (uint32_t i2 = 0; i2 < n_t; ++i2) {
|
||||
float * dst_ptr = (float *) ((char *) dst->data + ir0 * (dst->nb[0]) + i2 * (dst->nb[1]) + i3 * (dst->nb[2]));
|
||||
|
||||
const uint32_t nvec = ir / VLEN_FP32;
|
||||
const uint32_t nloe = ir % VLEN_FP32;
|
||||
uint32_t i1 = 0;
|
||||
|
||||
for (uint32_t vi1 = 0; vi1 < nvec; vi1++) {
|
||||
HVX_Vector acc_vec = Q6_V_vsplat_R(0);
|
||||
|
||||
for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
|
||||
Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])),
|
||||
src0_gather_len, (*(const HVX_Vector *) src0_offsets));
|
||||
Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)),
|
||||
src1_gather_len, (*(const HVX_Vector *) src1_offsets));
|
||||
|
||||
HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
|
||||
acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
|
||||
}
|
||||
|
||||
*(HVX_UVector *) (dst_ptr + i1) = Q6_Vsf_equals_Vqf32(acc_vec);
|
||||
i1 += VLEN_FP32;
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector acc_vec = Q6_V_vsplat_R(0);
|
||||
|
||||
for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
|
||||
Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])),
|
||||
src0_gather_len, (*(const HVX_Vector *) src0_offsets));
|
||||
Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)),
|
||||
src1_gather_len, (*(const HVX_Vector *) src1_offsets));
|
||||
|
||||
HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
|
||||
acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(dst_ptr + i1, (ir - i1) * 4, Q6_Vsf_equals_Vqf32(acc_vec));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n",
|
||||
ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1,
|
||||
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1],
|
||||
dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
int op_ssm_conv_f32(struct htp_ops_context * octx) {
|
||||
htp_ssm_conv_tensors_preamble;
|
||||
|
||||
if (src0->type != HTP_TYPE_F32 || src1->type != HTP_TYPE_F32 || dst->type != HTP_TYPE_F32) {
|
||||
FARF(ERROR, "ssm_conv: only (F32 x F32 -> F32) OPs supported");
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
struct htp_ssm_conv_context scctx = { 0 };
|
||||
scctx.octx = octx;
|
||||
|
||||
const uint32_t d_conv = src1->ne[0];
|
||||
const uint32_t d_inner = src0->ne[1];
|
||||
const uint32_t n_t = dst->ne[1]; // tokens per sequence
|
||||
const uint32_t n_s = dst->ne[2]; // number of sequences in the batch
|
||||
|
||||
const uint32_t n_threads = MIN(octx->n_threads, d_inner);
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t use_hvx = 0;
|
||||
if (d_inner >= VLEN_FP32 && d_inner % VLEN_FP32 == 0) {
|
||||
int is_aligned = hex_is_aligned((void *) src0->data, VLEN) &&
|
||||
hex_is_aligned((void *) src1->data, VLEN) &&
|
||||
hex_is_aligned((void *) dst->data, VLEN);
|
||||
|
||||
if (is_aligned) {
|
||||
use_hvx = 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (use_hvx) {
|
||||
scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads; // d_inner chunks per thread
|
||||
scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); // round up to even
|
||||
|
||||
octx->src0_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb01, 256);
|
||||
octx->src1_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb11, 256);
|
||||
octx->dst_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * sizeof(float), 256);
|
||||
|
||||
octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads;
|
||||
octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads;
|
||||
octx->dst_spad.size = octx->dst_spad.size_per_thread * n_threads;
|
||||
|
||||
// Compute gather scratchpad size for src0 and src1
|
||||
const size_t gather_spad_size = n_threads * VLEN * 2;
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
|
||||
FARF(HIGH, "ssm_conv-f32: gather-spad:%zu spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n",
|
||||
gather_spad_size, octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread,
|
||||
octx->dst_spad.size_per_thread, octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size,
|
||||
octx->src0_spad.data, octx->src1_spad.data, octx->dst_spad.data);
|
||||
|
||||
const size_t total_spad_size =
|
||||
gather_spad_size + octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||
|
||||
if (total_spad_size > octx->ctx->vtcm_size) {
|
||||
FARF(HIGH, "ssm_conv-f32: HVX scratchpad size %zu exceeds VTCM size %zu", total_spad_size,
|
||||
octx->ctx->vtcm_size);
|
||||
use_hvx = 0;
|
||||
}
|
||||
}
|
||||
|
||||
FARF(HIGH, "ssm-conv-f32: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : use_hvx %d\n", src0->ne[0],
|
||||
src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0],
|
||||
dst->ne[1], dst->ne[2], dst->ne[3], use_hvx);
|
||||
|
||||
if (use_hvx) {
|
||||
worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_thread_f32_f32_hvx, &scctx, n_threads);
|
||||
} else {
|
||||
worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_thread_f32_f32, &scctx, n_threads);
|
||||
}
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
int op_ssm_conv(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
switch (dst->type) {
|
||||
case HTP_TYPE_F32:
|
||||
err = op_ssm_conv_f32(octx);
|
||||
break;
|
||||
default:
|
||||
err = HTP_STATUS_NO_SUPPORT;
|
||||
break;
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
|
@ -5345,7 +5345,8 @@ static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_
|
|||
}
|
||||
|
||||
static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||
*free = 0;
|
||||
// no memory to report
|
||||
*free = 0;
|
||||
*total = 0;
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
|
|
|
|||
|
|
@ -97,20 +97,20 @@
|
|||
|
||||
{%- macro render_tools(tools) -%}
|
||||
{%- for tool in tools %}
|
||||
{{- "// " + tool.description + "\n" }}
|
||||
{{- "type "+ tool.name + " = " }}
|
||||
{%- if tool.parameters and tool.parameters.properties %}
|
||||
{{- "// " + tool.function.description + "\n" }}
|
||||
{{- "type "+ tool.function.name + " = " }}
|
||||
{%- if tool.function.parameters and tool.function.parameters.properties %}
|
||||
{{- "(_: {\n" }}
|
||||
{%- for param_name, param_spec in tool.parameters.properties.items() %}
|
||||
{%- for param_name, param_spec in tool.function.parameters.properties.items() %}
|
||||
{%- if param_spec.description %}
|
||||
{{- "// " + param_spec.description + "\n" }}
|
||||
{%- endif %}
|
||||
{{- param_name }}
|
||||
{%- if param_name not in (tool.parameters.required or []) -%}
|
||||
{%- if param_name not in (tool.function.parameters.required or []) -%}
|
||||
{{- "?" }}
|
||||
{%- endif -%}
|
||||
{{- ": " }}
|
||||
{{- render_typescript_type(param_spec, tool.parameters.required or []) }}
|
||||
{{- render_typescript_type(param_spec, tool.function.parameters.required or []) }}
|
||||
{%- if param_spec.default is defined -%}
|
||||
{%- if param_spec.enum %}
|
||||
{{- ", // default: " + param_spec.default }}
|
||||
|
|
@ -294,7 +294,7 @@
|
|||
{%- for tool_call in message.tool_calls -%}
|
||||
{%- if tool_call.type == 'function' -%}
|
||||
{%- set function = tool_call.function -%}
|
||||
{{- '{"' + function.name + '": ' + function.arguments + '}' }}
|
||||
{{- '{"' + function.name + '": ' + function.arguments|tojson + '}' }}
|
||||
{%- if not loop.last -%}
|
||||
{{- ", " }}
|
||||
{%- endif -%}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,172 @@
|
|||
{# ---------------------------------------------------------------------- #}
|
||||
{# ƛƬ Default setup and flags #}
|
||||
{# ---------------------------------------------------------------------- #}
|
||||
{%- set messages = messages or [] -%}
|
||||
{%- set tools = tools or [] -%}
|
||||
{%- set add_generation_prompt = add_generation_prompt or false -%}
|
||||
{%- set available_tool_string = '' -%}
|
||||
{%- set add_tool_id = true -%}
|
||||
{%- set add_thoughts = true -%} {# whether to include <thinking> reasoning blocks #}
|
||||
{%- set add_generation_prompt = true -%} {# whether to emit reasoning starter before assistant response #}
|
||||
{# Optional token placeholders (safe defaults) #}
|
||||
{%- set bos_token = bos_token or '' -%}
|
||||
{%- set eos_token = eos_token or '' -%}
|
||||
{# ---------------------------------------------------------------------- #}
|
||||
{# Core reasoning prompt and assistant reasoning prefix #}
|
||||
{# ---------------------------------------------------------------------- #}
|
||||
{%- set reasoning_prompt -%}
|
||||
You are a thoughtful, systematic AI assistant from ServiceNow Language Models (SLAM) lab.
|
||||
Analyze each question carefully, present your reasoning step-by-step, then provide the final
|
||||
response after the marker [BEGIN FINAL RESPONSE].
|
||||
{%- endset -%}
|
||||
{%- set reasoning_asst_turn_start = 'Here are my reasoning steps:\n' -%}
|
||||
{# ---------------------------------------------------------------------- #}
|
||||
{# Tool list and tool call output format #}
|
||||
{# ---------------------------------------------------------------------- #}
|
||||
{%- if tools|length > 0 -%}
|
||||
{%- set available_tool_string -%}
|
||||
You are provided with function signatures within <available_tools></available_tools> XML tags.
|
||||
You may call one or more functions to assist with the user query.
|
||||
Don't make assumptions about the arguments. You should infer the argument values from previous
|
||||
user responses and the system message.
|
||||
Here are the available tools:
|
||||
<available_tools>
|
||||
{% for tool in tools %}{{ tool|string }}{% endfor %}
|
||||
|
||||
</available_tools>.
|
||||
|
||||
Return all function calls as a list of JSON objects within <tool_calls></tool_calls> XML tags.
|
||||
Each JSON object should contain a function name and arguments as follows:
|
||||
<tool_calls>[
|
||||
{"name": <function-name-1>, "arguments": <args-dict-1>},
|
||||
{"name": <function-name-2>, "arguments": <args-dict-2>},
|
||||
...
|
||||
]</tool_calls>
|
||||
{%- endset -%}
|
||||
{%- endif -%}
|
||||
{# ---------------------------------------------------------------------- #}
|
||||
{# Start system block if first message is not system #}
|
||||
{# ---------------------------------------------------------------------- #}
|
||||
{%- if messages|length > 0 and messages[0]['role'] != 'system' -%}
|
||||
{%- if tools|length > 0 -%}
|
||||
{{ bos_token + '<|begin_system|>\n' + reasoning_prompt + '\n' + available_tool_string + '\n' }}
|
||||
{%- else -%}
|
||||
{{ bos_token + '<|begin_system|>\n' + reasoning_prompt + '\n' }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{# ---------------------------------------------------------------------- #}
|
||||
{# Iterate through messages #}
|
||||
{# ---------------------------------------------------------------------- #}
|
||||
{%- for message in messages -%}
|
||||
|
||||
{# ---------------- USER MESSAGE ---------------- #}
|
||||
{%- if message['role'] == 'user' -%}
|
||||
{{ '<|begin_user|>\n' }}
|
||||
{%- if message['content'] is not string -%}
|
||||
{%- for chunk in message['content'] -%}
|
||||
{%- if chunk['type'] == 'text' -%}
|
||||
{{ chunk['text'] }}
|
||||
{%- elif chunk['type'] in ['image', 'image_url'] -%}
|
||||
{{ '[IMG]' }}
|
||||
{%- else -%}
|
||||
{{ raise_exception('Unrecognized content type!') }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- else -%}
|
||||
{{ message['content'] }}
|
||||
{%- endif -%}
|
||||
|
||||
{# ---------------- SYSTEM MESSAGE ---------------- #}
|
||||
{%- elif message['role'] == 'system' -%}
|
||||
{%- set sys_content = message.get('content', '') -%}
|
||||
{%- if sys_content and sys_content|length > 0 -%}
|
||||
{%- if sys_content is string -%}
|
||||
{%- set system_message = sys_content -%}
|
||||
{%- else -%}
|
||||
{%- set system_message = sys_content[0]['text'] -%}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{%- set system_message = '' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if tools|length > 0 -%}
|
||||
{{ bos_token + '<|begin_system|>\n' + reasoning_prompt + '\n' + system_message + '\n' + available_tool_string + '\n' }}
|
||||
{%- else -%}
|
||||
{{ bos_token + '<|begin_system|>\n' + reasoning_prompt + '\n' + system_message + '\n' }}
|
||||
{%- endif -%}
|
||||
|
||||
{# ---------------- ASSISTANT MESSAGE ---------------- #}
|
||||
{%- elif message['role'] == 'assistant' -%}
|
||||
{%- if loop.last -%}
|
||||
{%- set add_tool_id = false -%}
|
||||
{%- endif -%}
|
||||
|
||||
{{ '\n<|begin_assistant|>\n' }}
|
||||
|
||||
{%- if add_thoughts and message.get('reasoning_content') and loop.last -%}
|
||||
{{ message['reasoning_content'] + '\n[BEGIN FINAL RESPONSE]\n' }}
|
||||
{%- endif -%}
|
||||
|
||||
{%- set asst_content = message.get('content', '') -%}
|
||||
{%- if asst_content and asst_content|length > 0 -%}
|
||||
{%- if asst_content is not string -%}
|
||||
{%- set asst_text = asst_content[0]['text'] -%}
|
||||
{%- else -%}
|
||||
{%- set asst_text = asst_content -%}
|
||||
{%- endif -%}
|
||||
{# For historical turns (not the last), strip reasoning and keep only final response #}
|
||||
{%- if not loop.last and '[BEGIN FINAL RESPONSE]' in asst_text -%}
|
||||
{{- asst_text.split('[BEGIN FINAL RESPONSE]')[-1] | trim -}}
|
||||
{%- else -%}
|
||||
{{- asst_text -}}
|
||||
{%- endif -%}
|
||||
{%- elif message.get('chosen') and message['chosen']|length > 0 -%}
|
||||
{{ message['chosen'][0] }}
|
||||
{%- endif -%}
|
||||
|
||||
{# Tool call output #}
|
||||
{%- set tool_calls = message.get('tool_calls', []) -%}
|
||||
{%- if tool_calls and tool_calls|length > 0 -%}
|
||||
{{ '\n<tool_calls>[' }}
|
||||
{%- for tool_call in tool_calls -%}
|
||||
{{ '{"name": "' + tool_call['function']['name'] + '", "arguments": ' + tool_call['function']['arguments']|tojson }}
|
||||
{%- if add_tool_id == true and 'id' in tool_call -%}
|
||||
{{ ', "id": "' + tool_call['id'] + '"' }}
|
||||
{%- endif -%}
|
||||
{{ '}' }}
|
||||
{%- if not loop.last -%}{{ ', ' }}{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{ ']</tool_calls>' }}
|
||||
{%- endif -%}
|
||||
|
||||
{%- set training_prompt = training_prompt if (training_prompt is defined) else false -%}
|
||||
{%- if not loop.last or training_prompt -%}
|
||||
{{ '\n<|end|>\n' }}
|
||||
{%- endif -%}
|
||||
|
||||
{# ---------------- TOOL RESULT MESSAGE ---------------- #}
|
||||
{%- elif message['role'] == 'tool' -%}
|
||||
{%- set tool_content = message.get('content', '') -%}
|
||||
{%- if tool_content is string -%}
|
||||
{%- set tool_message = tool_content -%}
|
||||
{%- else -%}
|
||||
{%- set tool_message = tool_content[0]['text'] if tool_content else '' -%}
|
||||
{%- endif -%}
|
||||
{{ '<|begin_tool_result|>\n' + tool_message|string + '\n' }}
|
||||
|
||||
{# ---------------- CONTENT MESSAGE ---------------- #}
|
||||
{%- elif message['role'] == 'content' -%}
|
||||
{%- set msg_content = message.get('content', '') -%}
|
||||
{%- if msg_content is not string -%}
|
||||
{{ '<|begin_content|>\n' + msg_content[0]['text'] + '\n' }}
|
||||
{%- else -%}
|
||||
{{ '<|begin_content|>\n' + msg_content + '\n' }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
|
||||
{# ---------------- REASONING PROMPT BEFORE NEXT ASSISTANT ---------------- #}
|
||||
{%- if loop.last and add_generation_prompt and message['role'] != 'assistant' -%}
|
||||
{{ '\n<|begin_assistant|>\n' + reasoning_asst_turn_start }}
|
||||
{%- endif -%}
|
||||
|
||||
{%- endfor -%}
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
{{ bos_token }}
|
||||
{%- if messages[0]['role'] == 'system' %}
|
||||
{%- set system_message = messages[0]['content'] %}
|
||||
{%- set loop_start_index = 1 %}
|
||||
{%- else %}
|
||||
{%- set system_message = "" %}
|
||||
{%- set loop_start_index = 0 %}
|
||||
{%- endif %}
|
||||
|
||||
{%- if system_message or tools %}
|
||||
{{- '<|im_start|>system\n' }}
|
||||
{%- if system_message %}
|
||||
{{- system_message }}
|
||||
{%- endif %}
|
||||
{%- if tools %}
|
||||
{{- '\n\nMasz dostęp do następujących narzędzi. Definicje narzędzi znajdują się poniżej wewnątrz znaczników <|function_list|>:\n<|function_list|>\n' }}
|
||||
{{- '[' }}
|
||||
{%- for tool in tools %}
|
||||
{{- tool | tojson }}
|
||||
{%- if not loop.last %}
|
||||
{{- ',\n' }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- ']\n<|function_list|>\n\nAby wywołać narzędzie, użyj formatu <tool_call>{"name": "nazwa_narzędzia", "arguments": {"argument": "wartość"}}</tool_call>. Wyniki działania narzędzi zostaną przekazane z markerem <|function_output|>.\n' }}
|
||||
{%- endif %}
|
||||
{%- if enable_thinking %}
|
||||
{{- '\n\nZanim odpowiesz na pytanie, najpierw przemyśl swoje kroki i umieść swoje myśli wewnątrz tagów <think>...</think>. Musisz najpierw pomyśleć, zanim udzielisz odpowiedzi. ' -}}
|
||||
{{- 'WAŻNE: Powinieneś myśleć w tym samym języku, co pytanie użytkownika. Jeśli pytanie jest zadane po polsku, powinieneś również myśleć po polsku. Jeśli pytanie jest po angielsku, myślisz również po angielsku itd. ' }}
|
||||
{{- '** PAMIĘTAJ! ** Pytanie po polsku -> myślenie po polsku -> odpowiedź po polsku!' -}}
|
||||
{{- '*** BARDZO WAŻNE!!! *** Jesteś Bielikiem, polskim modelem językowym. Twoją główną cechą jest umiejętność pisania po polsku. Jeśli użytkownik zadaje Ci pytania po polsku, ZAWSZE odpowiadaj po polsku. ' -}}
|
||||
{{- 'Nawet, jeśli korzystasz z narzędzia, którego większość instrukcji jest po angielsku, powinieneś przede wszystkim odpowiadać po polsku, jeśli użytkownik zadaje pytanie w tym języku. ' -}}
|
||||
{%- endif %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
|
||||
{%- for message in messages[loop_start_index:] %}
|
||||
{%- if message['role'] == 'user' %}
|
||||
{{- '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}
|
||||
{%- elif message['role'] == 'assistant' %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- set content = message.content | default('') %}
|
||||
{%- set reasoning_content = message.reasoning_content | default('') %}
|
||||
{%- if not reasoning_content and '<think>' in content and '</think>' in content %}
|
||||
{%- set reasoning_parts = content.split('</think>') %}
|
||||
{%- set reasoning_content = reasoning_parts[0].split('<think>')[-1] %}
|
||||
{%- set content = reasoning_parts[1:] | join('</think>') %}
|
||||
{%- endif %}
|
||||
{%- if reasoning_content %}
|
||||
{{- '<think>\n' + reasoning_content.strip() + '\n</think>\n' }}
|
||||
{%- endif %}
|
||||
{{- content.lstrip() }}
|
||||
{%- if message.tool_calls %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if tool_call.function %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '\n<tool_call>\n{"name": "' + tool_call.name + '", "arguments": ' + (tool_call.arguments if tool_call.arguments is string else tool_call.arguments | tojson) + '}\n</tool_call>' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif message['role'] == 'tool' %}
|
||||
{%- if loop.index0 == 0 or messages[loop.index0 - 1]['role'] != 'tool' %}
|
||||
{{- '<|im_start|>user\n' }}
|
||||
{%- endif %}
|
||||
{{- '<|function_output|>' + message['content'] }}
|
||||
{%- if loop.last or messages[loop.index0 + 1]['role'] != 'tool' %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- if enable_thinking %}
|
||||
{{- '<think>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
|
|
@ -132,7 +132,7 @@ The following instructions take precedence over instructions in the default prea
|
|||
{%- elif message.role|lower == 'user' %}
|
||||
<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}
|
||||
{%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %}
|
||||
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[
|
||||
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.reasoning_content}}<|END_THINKING|><|START_ACTION|>[
|
||||
{% for tc in message.tool_calls %}
|
||||
{"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc['function']['name'] }}", "parameters": {{ tc['function']['arguments']|tojson }}}{% if not loop.last %},{% endif %}
|
||||
|
||||
|
|
@ -153,4 +153,4 @@ The following instructions take precedence over instructions in the default prea
|
|||
|
||||
]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>
|
||||
{%- endif %}
|
||||
{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
||||
{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{%- if not enable_thinking -%}<|START_THINKING|><|END_THINKING|>{%- endif %}
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
[gMASK]<sop>
|
||||
{%- if tools -%}
|
||||
<|system|>
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{% for tool in tools %}
|
||||
{{ tool | tojson(ensure_ascii=False) }}
|
||||
{% endfor %}
|
||||
</tools>
|
||||
|
||||
For each function call, output the function name and arguments within the following XML format:
|
||||
<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>{%- endif -%}
|
||||
{%- macro visible_text(content) -%}
|
||||
{%- if content is string -%}
|
||||
{{- content }}
|
||||
{%- elif content is iterable and content is not mapping -%}
|
||||
{%- for item in content -%}
|
||||
{%- if item is mapping and item.type == 'text' -%}
|
||||
{{- item.text }}
|
||||
{%- elif item is string -%}
|
||||
{{- item }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- else -%}
|
||||
{{- content }}
|
||||
{%- endif -%}
|
||||
{%- endmacro -%}
|
||||
{%- set ns = namespace(last_user_index=-1) %}
|
||||
{%- for m in messages %}
|
||||
{%- if m.role == 'user' %}
|
||||
{% set ns.last_user_index = loop.index0 -%}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{% for m in messages %}
|
||||
{%- if m.role == 'user' -%}<|user|>{{ visible_text(m.content) }}
|
||||
{%- elif m.role == 'assistant' -%}
|
||||
<|assistant|>
|
||||
{%- set reasoning_content = '' %}
|
||||
{%- set content = visible_text(m.content) %}
|
||||
{%- if m.reasoning_content is string %}
|
||||
{%- set reasoning_content = m.reasoning_content %}
|
||||
{%- else %}
|
||||
{%- if '</think>' in content %}
|
||||
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
||||
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if ((clear_thinking is defined and not clear_thinking) or loop.index0 > ns.last_user_index) and reasoning_content -%}
|
||||
{{ '<think>' + reasoning_content.strip() + '</think>'}}
|
||||
{%- else -%}
|
||||
{{ '</think>' }}
|
||||
{%- endif -%}
|
||||
{%- if content.strip() -%}
|
||||
{{ content.strip() }}
|
||||
{%- endif -%}
|
||||
{% if m.tool_calls %}
|
||||
{% for tc in m.tool_calls %}
|
||||
{%- if tc.function %}
|
||||
{%- set tc = tc.function %}
|
||||
{%- endif %}
|
||||
{{- '<tool_call>' + tc.name -}}
|
||||
{% set _args = tc.arguments %}{% for k, v in _args.items() %}<arg_key>{{ k }}</arg_key><arg_value>{{ v | tojson(ensure_ascii=False) if v is not string else v }}</arg_value>{% endfor %}</tool_call>{% endfor %}
|
||||
{% endif %}
|
||||
{%- elif m.role == 'tool' -%}
|
||||
{%- if m.content is string -%}
|
||||
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
||||
{{- '<|observation|>' }}
|
||||
{%- endif %}
|
||||
{{- '<tool_response>' }}
|
||||
{{- m.content }}
|
||||
{{- '</tool_response>' }}
|
||||
{%- else -%}
|
||||
<|observation|>{% for tr in m.content %}
|
||||
<tool_response>{{ tr.output if tr.output is defined else tr }}</tool_response>{% endfor -%}
|
||||
{% endif -%}
|
||||
{%- elif m.role == 'system' -%}
|
||||
<|system|>{{ visible_text(m.content) }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
<|assistant|>{{- '</think>' if (enable_thinking is defined and not enable_thinking) else '<think>' -}}
|
||||
{%- endif -%}
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
{{- bos_token -}}
|
||||
{%- set system_prompt = "" -%}
|
||||
{%- set ns = namespace(system_prompt="") -%}
|
||||
{%- if messages[0]["role"] == "system" -%}
|
||||
{%- set ns.system_prompt = messages[0]["content"] -%}
|
||||
{%- set messages = messages[1:] -%}
|
||||
{%- endif -%}
|
||||
{%- if tools -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "You can use the following tools: <|tool_list_start|>[" -%}
|
||||
{%- for tool in tools -%}
|
||||
{%- if tool is not string -%}
|
||||
{%- set tool = tool | tojson -%}
|
||||
{%- endif -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + tool -%}
|
||||
{%- if not loop.last -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + ", " -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + "]<|tool_list_end|>" -%}
|
||||
{{- '**IMPORTANT**: The syntax for calling the tools is: <|tool_call_start|>JSON tool call goes here<|tool_call_end|>. Please only call tools in the specified manner.' -}}
|
||||
{%- endif -%}
|
||||
{%- if ns.system_prompt -%}
|
||||
{{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}}
|
||||
{%- endif -%}
|
||||
{%- for message in messages -%}
|
||||
{{- "<|im_start|>" + message["role"] + "\n" -}}
|
||||
{%- set content = message["content"] -%}
|
||||
{%- if content is not string -%}
|
||||
{%- set content = content | tojson -%}
|
||||
{%- endif -%}
|
||||
{%- if message["role"] == "tool" -%}
|
||||
{%- set content = "<|tool_response_start|>" + content + "<|tool_response_end|>" -%}
|
||||
{%- elif message["role"] == "assistant" -%}
|
||||
{%- if message.tool_calls %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if tool_call.function %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '\n<|tool_call_start|>\n{"name": "' + tool_call.name + '", "arguments": ' + (tool_call.arguments if tool_call.arguments is string else tool_call.arguments | tojson) + '}\n<|tool_call_end|>\n' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{%- endif -%}
|
||||
{{- content + "<|im_end|>\n" -}}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{- "<|im_start|>assistant\n" -}}
|
||||
{%- endif -%}
|
||||
|
|
@ -59,4 +59,5 @@
|
|||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n<think>\n' }}
|
||||
{%- if not enable_thinking -%}{{- '</think>' -}}{%- endif -%}
|
||||
{%- endif %}
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@
|
|||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if tools is iterable and tools | length > 0 %}
|
||||
{{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }}
|
||||
{{- "\n\n# Tools\n\nYou have access to the following tools:\n\n" }}
|
||||
{{- "<tools>" }}
|
||||
{%- for tool in tools %}
|
||||
{%- if tool.function is defined %}
|
||||
|
|
@ -63,7 +63,7 @@
|
|||
{{- '\n</function>' }}
|
||||
{%- endfor %}
|
||||
{{- "\n</tools>" }}
|
||||
{{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
|
||||
{{- '\n\nIf you choose to call a tool ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nvalue_2\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: the tool calling block MUST begin with an opening <tool_call> tag and end with a closing </tool_call> tag.\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
|
||||
{%- endif %}
|
||||
{%- if system_message is defined %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,80 @@
|
|||
{% macro render_content(content) %}{% if content is none %}{{- '' }}{% elif content is string %}{{- content }}{% elif content is mapping %}{{- content['value'] if 'value' in content else content['text'] }}{% elif content is iterable %}{% for item in content %}{% if item.type == 'text' %}{{- item['value'] if 'value' in item else item['text'] }}{% elif item.type == 'image' %}<im_patch>{% endif %}{% endfor %}{% endif %}{% endmacro %}
|
||||
{{bos_token}}{%- if tools %}
|
||||
{{- '<|im_start|>system\n' }}
|
||||
{%- if messages[0].role == 'system' %}
|
||||
{{- render_content(messages[0].content) + '\n\n' }}
|
||||
{%- endif %}
|
||||
{{- "# Tools\n\nYou have access to the following functions in JSONSchema format:\n\n<tools>" }}
|
||||
{%- for tool in tools %}
|
||||
{{- "\n" }}
|
||||
{{- tool | tojson(ensure_ascii=False) }}
|
||||
{%- endfor %}
|
||||
{{- "\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...>\n...\n</function> block must be nested within <tool_call>\n...\n</tool_call> XML tags\n- Required parameters MUST be specified\n</IMPORTANT><|im_end|>\n" }}
|
||||
{%- else %}
|
||||
{%- if messages[0].role == 'system' %}
|
||||
{{- '<|im_start|>system\n' + render_content(messages[0].content) + '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
||||
{%- for message in messages[::-1] %}
|
||||
{%- set index = (messages|length - 1) - loop.index0 %}
|
||||
{%- if ns.multi_step_tool and message.role == "user" and render_content(message.content) is string and not(render_content(message.content).startswith('<tool_response>') and render_content(message.content).endswith('</tool_response>')) %}
|
||||
{%- set ns.multi_step_tool = false %}
|
||||
{%- set ns.last_query_index = index %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- for message in messages %}
|
||||
{%- set content = render_content(message.content) %}
|
||||
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
|
||||
{%- set role_name = 'observation' if (message.role == "system" and not loop.first and message.name == 'observation') else message.role %}
|
||||
{{- '<|im_start|>' + role_name + '\n' + content + '<|im_end|>' + '\n' }}
|
||||
{%- elif message.role == "assistant" %}
|
||||
{%- if message.reasoning_content is string %}
|
||||
{%- set reasoning_content = render_content(message.reasoning_content) %}
|
||||
{%- else %}
|
||||
{%- if '</think>' in content %}
|
||||
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
||||
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
|
||||
{%- else %}
|
||||
{%- set reasoning_content = '' %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if loop.index0 > ns.last_query_index %}
|
||||
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n' + content }}
|
||||
{%- else %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content }}
|
||||
{%- endif %}
|
||||
{%- if message.tool_calls %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if tool_call.function is defined %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
|
||||
{%- if tool_call.arguments is defined %}
|
||||
{%- set arguments = tool_call.arguments %}
|
||||
{%- for args_name, args_value in arguments|items %}
|
||||
{{- '<parameter=' + args_name + '>\n' }}
|
||||
{%- set args_value = args_value | tojson(ensure_ascii=False) | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
|
||||
{{- args_value }}
|
||||
{{- '\n</parameter>\n' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '</function>\n</tool_call>' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif message.role == "tool" %}
|
||||
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
||||
{{- '<|im_start|>tool_response\n' }}
|
||||
{%- endif %}
|
||||
{{- '<tool_response>' }}
|
||||
{{- content }}
|
||||
{{- '</tool_response>' }}
|
||||
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n<think>\n' }}
|
||||
{%- endif %}
|
||||
|
|
@ -1 +1,44 @@
|
|||
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|><think>\n'}}{% endif %}
|
||||
{% if not add_generation_prompt is defined -%}
|
||||
{%- set add_generation_prompt = false -%}
|
||||
{%- endif -%}
|
||||
{%- set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'system' -%}
|
||||
{%- set ns.system_prompt = message['content'] -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}{{bos_token}}{{ns.system_prompt}}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'user' -%}
|
||||
{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'assistant' and message['content'] is none -%}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- for tool in message['tool_calls']-%}
|
||||
{%- if not ns.is_first -%}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||
{%- set ns.is_first = true -%}
|
||||
{%- else -%}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'assistant' and message['content'] is not none -%}
|
||||
{%- if ns.is_tool -%}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- else -%}
|
||||
{%- set content = message['content'] -%}
|
||||
{%- if '</think>' in content -%}
|
||||
{%- set content = content.split('</think>')[-1] -%}
|
||||
{%- endif -%}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'tool' -%}
|
||||
{%- set ns.is_tool = true -%}
|
||||
{%- if ns.is_output_first -%}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
|
||||
{%- set ns.is_output_first = false -%}
|
||||
{%- else -%}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if ns.is_tool -%}{{'<|tool▁outputs▁end|>'}}
|
||||
{%- endif -%}
|
||||
{%- if add_generation_prompt and not ns.is_tool -%}{{'<|Assistant|><think>\n'}}
|
||||
{%- endif %}
|
||||
|
|
@ -1 +1,47 @@
|
|||
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|><think>\n'}}{% endif %}
|
||||
{% if not add_generation_prompt is defined -%}
|
||||
{%- set add_generation_prompt = false -%}
|
||||
{%- endif -%}
|
||||
{%- set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'system' -%}
|
||||
{%- set ns.system_prompt = message['content'] -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}{{bos_token}}{{ns.system_prompt}}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'user' -%}
|
||||
{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'assistant' and message['tool_calls'] -%}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- for tool in message['tool_calls']-%}
|
||||
{%- if not ns.is_first -%}
|
||||
{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||
{%- set ns.is_first = true -%}
|
||||
{%- else -%}
|
||||
{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'assistant' and message['content'] is not none -%}
|
||||
{%- if ns.is_tool -%}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- else -%}
|
||||
{%- set content = message['content'] -%}
|
||||
{%- if '</think>' in content -%}
|
||||
{%- set content = content.split('</think>')[-1] -%}
|
||||
{%- endif -%}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'tool' -%}
|
||||
{%- set ns.is_tool = true -%}
|
||||
{%- if ns.is_output_first -%}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
|
||||
{%- set ns.is_output_first = false -%}
|
||||
{%- else -%}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if ns.is_tool -%}{{'<|tool▁outputs▁end|>'}}
|
||||
{%- endif -%}
|
||||
{%- if add_generation_prompt and not ns.is_tool -%}{{'<|Assistant|><think>\n'}}{% if not enable_thinking %}{{- '</think>' -}}{% endif %}
|
||||
{%- endif %}
|
||||
|
|
@ -1,3 +1,71 @@
|
|||
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% if not thinking is defined %}{% set thinking = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '
|
||||
{% if not add_generation_prompt is defined -%}
|
||||
{%- set add_generation_prompt = false -%}
|
||||
{%- endif -%}
|
||||
{%- if not thinking is defined -%}
|
||||
{%- if enable_thinking is defined -%}
|
||||
{%- set thinking = enable_thinking -%}
|
||||
{%- else -%}
|
||||
{%- set thinking = false -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false) -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'system' -%}
|
||||
{%- if ns.is_first_sp -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + message['content'] -%}
|
||||
{%- set ns.is_first_sp = false -%}
|
||||
{%- else -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + '
|
||||
|
||||
' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{%- set ns.is_first = false -%}{%- set ns.is_last_user = true -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}{%- if ns.is_last_user %}{{'<|Assistant|></think>'}}{%- endif %}{%- set ns.is_last_user = false -%}{%- set ns.is_first = false %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}{%- else %}{{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}{%- endif %}{%- endfor %}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %}{%- if ns.is_last_user %}{{'<|Assistant|>'}}{%- if message['prefix'] is defined and message['prefix'] and thinking %}{{'<think>'}} {%- else %}{{'</think>'}}{%- endif %}{%- endif %}{%- set ns.is_last_user = false -%}{%- if ns.is_tool %}{{message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{%- set content = message['content'] -%}{%- if '</think>' in content %}{%- set content = content.split('</think>', 1)[1] -%}{%- endif %}{{content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_last_user = false -%}{%- set ns.is_tool = true -%}{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endfor -%}{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool %}{{'<|Assistant|>'}}{%- if not thinking %}{{'</think>'}}{%- else %}{{'<think>'}}{%- endif %}{% endif %}
|
||||
' + message['content'] -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}{{ bos_token }}{{ ns.system_prompt }}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'user' -%}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- set ns.is_first = false -%}
|
||||
{%- set ns.is_last_user = true -%}{{'<|User|>' + message['content']}}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'assistant' and message['tool_calls'] -%}
|
||||
{%- if ns.is_last_user -%}{{'<|Assistant|></think>'}}
|
||||
{%- endif -%}
|
||||
{%- set ns.is_last_user = false -%}
|
||||
{%- set ns.is_first = false -%}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- for tool in message['tool_calls'] -%}
|
||||
{%- if not ns.is_first -%}
|
||||
{%- if not message['content'] -%}{{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}
|
||||
{%- else -%}{{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}
|
||||
{%- endif -%}
|
||||
{%- set ns.is_first = true -%}
|
||||
{%- else -%}{{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'assistant' and not message['tool_calls'] -%}
|
||||
{%- if ns.is_last_user -%}{{'<|Assistant|>'}}
|
||||
{%- if message['prefix'] is defined and message['prefix'] and thinking -%}{{'<think>'}}
|
||||
{%- else -%}{{'</think>'}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- set ns.is_last_user = false -%}
|
||||
{%- if ns.is_tool -%}{{message['content'] + '<|end▁of▁sentence|>'}}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- else -%}
|
||||
{%- set content = message['content'] -%}
|
||||
{%- if '</think>' in content -%}
|
||||
{%- set content = content.split('</think>', 1)[1] -%}
|
||||
{%- endif -%}{{content + '<|end▁of▁sentence|>'}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'tool' -%}
|
||||
{%- set ns.is_last_user = false -%}
|
||||
{%- set ns.is_tool = true -%}{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool -%}{{'<|Assistant|>'}}
|
||||
{%- if not thinking -%}{{'</think>'}}
|
||||
{%- else -%}{{'<think>'}}
|
||||
{%- endif -%}
|
||||
{%- endif %}
|
||||
|
|
@ -46,7 +46,7 @@ Available functions as JSON spec:
|
|||
{%- if 'tool_calls' in message and message['tool_calls'] -%}
|
||||
{%- set tool = namespace(calls=[]) -%}
|
||||
{%- for call in message['tool_calls'] -%}
|
||||
{%- set tool.calls = tool.calls + ['{"name": "' + call['function']['name'] + '", "arguments": ' + call['function']['arguments'] + '}'] -%}
|
||||
{%- set tool.calls = tool.calls + ['{"name": "' + call['function']['name'] + '", "arguments": ' + call['function']['arguments']|tojson + '}'] -%}
|
||||
{%- endfor -%}
|
||||
{%- set ns.content = ns.content + ' functools[' + tool.calls | join(', ') + ']' -%}
|
||||
{%- endif -%}
|
||||
|
|
|
|||
|
|
@ -1,43 +1,43 @@
|
|||
{%- if tools -%}
|
||||
<|im_system|>tool_declare<|im_middle|>{{ tools | tojson }}<|im_end|>
|
||||
{%- endif -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if loop.first and messages[0]['role'] != 'system' -%}
|
||||
<|im_system|>system<|im_middle|>You are a helpful assistant<|im_end|>
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'system' -%}
|
||||
<|im_system|>system<|im_middle|>
|
||||
{%- elif message['role'] == 'user' -%}
|
||||
<|im_user|>user<|im_middle|>
|
||||
{%- elif message['role'] == 'assistant' -%}
|
||||
<|im_assistant|>assistant<|im_middle|>
|
||||
{%- elif message['role'] == 'tool' -%}
|
||||
<|im_system|>tool<|im_middle|>
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'assistant' and message.get('tool_calls') -%}
|
||||
{%- if message['content'] -%}{{ message['content'] }}{%- endif -%}
|
||||
<|tool_calls_section_begin|>
|
||||
{%- for tool_call in message['tool_calls'] -%}
|
||||
{%- set func_name = tool_call['function']['name'] -%}
|
||||
{%- set formatted_id = 'functions.' + func_name + ':' + loop.index0|string -%}
|
||||
<|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{{ tool_call['function']['arguments'] | tojson}}<|tool_call_end|>
|
||||
{%- endfor -%}
|
||||
<|tool_calls_section_end|>
|
||||
{%- elif message['role'] == 'tool' -%}
|
||||
## Return of {{ message.tool_call_id }}\n{{ message['content'] }}
|
||||
{%- elif message['content'] is string -%}
|
||||
{{ message['content'] }}
|
||||
{%- elif message['content'] is not none -%}
|
||||
{% for content in message['content'] -%}
|
||||
{% if content['type'] == 'image' or 'image' in content or 'image_url' in content -%}
|
||||
<|media_start|>image<|media_content|><|media_pad|><|media_end|>
|
||||
{% else -%}
|
||||
{{ content['text'] }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
<|im_end|>
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
<|im_assistant|>assistant<|im_middle|>
|
||||
{%- endif -%}
|
||||
{%- if tools -%}
|
||||
<|im_system|>tool_declare<|im_middle|>{{ tools | tojson }}<|im_end|>
|
||||
{%- endif -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if loop.first and messages[0]['role'] != 'system' -%}
|
||||
<|im_system|>system<|im_middle|>You are a helpful assistant<|im_end|>
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'system' -%}
|
||||
<|im_system|>system<|im_middle|>
|
||||
{%- elif message['role'] == 'user' -%}
|
||||
<|im_user|>user<|im_middle|>
|
||||
{%- elif message['role'] == 'assistant' -%}
|
||||
<|im_assistant|>assistant<|im_middle|>
|
||||
{%- elif message['role'] == 'tool' -%}
|
||||
<|im_system|>tool<|im_middle|>
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'assistant' and message.get('tool_calls') -%}
|
||||
{%- if message['content'] -%}{{ message['content'] }}{%- endif -%}
|
||||
<|tool_calls_section_begin|>
|
||||
{%- for tool_call in message['tool_calls'] -%}
|
||||
{%- set func_name = tool_call['function']['name'] -%}
|
||||
{%- set formatted_id = 'functions.' + func_name + ':' + loop.index0|string -%}
|
||||
<|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{{ tool_call['function']['arguments'] | tojson}}<|tool_call_end|>
|
||||
{%- endfor -%}
|
||||
<|tool_calls_section_end|>
|
||||
{%- elif message['role'] == 'tool' -%}
|
||||
## Return of {{ message.tool_call_id }}\n{{ message['content'] }}
|
||||
{%- elif message['content'] is string -%}
|
||||
{{ message['content'] }}
|
||||
{%- elif message['content'] is not none -%}
|
||||
{% for content in message['content'] -%}
|
||||
{% if content['type'] == 'image' or 'image' in content or 'image_url' in content -%}
|
||||
<|media_start|>image<|media_content|><|media_pad|><|media_end|>
|
||||
{% else -%}
|
||||
{{ content['text'] }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
<|im_end|>
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
<|im_assistant|>assistant<|im_middle|>
|
||||
{%- endif -%}
|
||||
|
|
|
|||
|
|
@ -86,22 +86,22 @@ Prior to generating the function calls, you should generate the reasoning for wh
|
|||
{%- set add_tool_id = false -%}
|
||||
{%- endif -%}
|
||||
{{- '<|assistant|>\n' -}}
|
||||
{%- if message['content'] is not none and message['content']|length > 0 -%}
|
||||
{%- if message['content'] is defined and message['content'] is not none and message['content']|length > 0 -%}
|
||||
{%- if message['content'] is not string and message['content'][0]['text'] is not none %}
|
||||
{{- message['content'][0]['text'] }}
|
||||
{%- else %}
|
||||
{{- message['content'] -}}
|
||||
{%- endif -%}
|
||||
{%- elif message['chosen'] is not none and message['chosen']|length > 0 -%}
|
||||
{%- elif message['chosen'] is defined and message['chosen'] is not none and message['chosen']|length > 0 -%}
|
||||
{{- message['chosen'][0] -}}
|
||||
{%- endif -%}
|
||||
{%- if add_thoughts and 'thought' in message and message['thought'] is not none -%}
|
||||
{{- '<thinking>' + message['thought'] + '</thinking>' -}}
|
||||
{%- endif -%}
|
||||
{%- if message['tool_calls'] is not none and message['tool_calls']|length > 0 -%}
|
||||
{%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 -%}
|
||||
{{- '\n<tool_calls>[' -}}
|
||||
{%- for tool_call in message["tool_calls"] -%}
|
||||
{{- '{"name": "' + tool_call['function']['name'] + '", "arguments": ' + tool_call['function']['arguments']|string -}}
|
||||
{{- '{"name": "' + tool_call['function']['name'] + '", "arguments": ' + tool_call['function']['arguments']|tojson -}}
|
||||
{%- if add_tool_id == true -%}
|
||||
{{- ', "id": "' + tool_call['id'] + '"' -}}
|
||||
{%- endif -%}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,202 @@
|
|||
import argparse
|
||||
import json
|
||||
import requests
|
||||
import logging
|
||||
import sys
|
||||
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.terminator = "" # ← no newline
|
||||
logging.basicConfig(level=logging.INFO, format='%(message)s', handlers=[handler])
|
||||
logger = logging.getLogger("server-test-model")
|
||||
|
||||
|
||||
def run_query(url, messages, tools=None, stream=False, tool_choice=None):
|
||||
payload = {
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
"max_tokens": 5000,
|
||||
}
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
if tool_choice:
|
||||
payload["tool_choice"] = tool_choice
|
||||
|
||||
try:
|
||||
response = requests.post(url, json=payload, stream=stream)
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.RequestException as e:
|
||||
if e.response is not None:
|
||||
logger.info(f"Response error: {e} for {e.response.content}\n")
|
||||
else:
|
||||
logger.info(f"Error connecting to server: {e}\n")
|
||||
return None
|
||||
|
||||
full_content = ""
|
||||
reasoning_content = ""
|
||||
tool_calls = []
|
||||
|
||||
if stream:
|
||||
logger.info(f"--- Streaming response (Tools: {bool(tools)}) ---\n")
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
decoded_line = line.decode("utf-8")
|
||||
if decoded_line.startswith("data: "):
|
||||
data_str = decoded_line[6:]
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
if "choices" in data and len(data["choices"]) > 0:
|
||||
delta = data["choices"][0].get("delta", {})
|
||||
|
||||
# Content
|
||||
content_chunk = delta.get("content", "")
|
||||
if content_chunk:
|
||||
full_content += content_chunk
|
||||
logger.info(content_chunk)
|
||||
|
||||
# Reasoning
|
||||
reasoning_chunk = delta.get("reasoning_content", "")
|
||||
if reasoning_chunk:
|
||||
reasoning_content += reasoning_chunk
|
||||
logger.info(f"\x1B[3m{reasoning_chunk}\x1B[0m")
|
||||
|
||||
# Tool calls
|
||||
if "tool_calls" in delta:
|
||||
for tc in delta["tool_calls"]:
|
||||
index = tc.get("index")
|
||||
if index is not None:
|
||||
while len(tool_calls) <= index:
|
||||
# Using "function" as type default but could be flexible
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": "",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
if "id" in tc:
|
||||
tool_calls[index]["id"] += tc["id"]
|
||||
if "function" in tc:
|
||||
if "name" in tc["function"]:
|
||||
tool_calls[index]["function"][
|
||||
"name"
|
||||
] += tc["function"]["name"]
|
||||
if "arguments" in tc["function"]:
|
||||
tool_calls[index]["function"][
|
||||
"arguments"
|
||||
] += tc["function"]["arguments"]
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.info(f"Failed to decode JSON: {data_str}\n")
|
||||
logger.info("\n--- End of Stream ---\n")
|
||||
else:
|
||||
logger.info(f"--- Non-streaming response (Tools: {bool(tools)}) ---\n")
|
||||
data = response.json()
|
||||
if "choices" in data and len(data["choices"]) > 0:
|
||||
message = data["choices"][0].get("message", {})
|
||||
full_content = message.get("content", "")
|
||||
reasoning_content = message.get("reasoning_content", "")
|
||||
tool_calls = message.get("tool_calls", [])
|
||||
logger.info(full_content)
|
||||
logger.info("--- End of Response ---\n")
|
||||
|
||||
return {
|
||||
"content": full_content,
|
||||
"reasoning_content": reasoning_content,
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
|
||||
|
||||
def test_chat(url, stream):
|
||||
logger.info(f"\n=== Testing Chat (Stream={stream}) ===\n")
|
||||
messages = [{"role": "user", "content": "What is the capital of France?"}]
|
||||
result = run_query(url, messages, stream=stream)
|
||||
|
||||
if result:
|
||||
if result["content"]:
|
||||
logger.info("PASS: Output received.\n")
|
||||
else:
|
||||
logger.info("WARN: No content received (valid if strict tool call, but unexpected here).\n")
|
||||
|
||||
if result.get("reasoning_content"):
|
||||
logger.info(f"INFO: Reasoning content detected ({len(result['reasoning_content'])} chars).\n")
|
||||
else:
|
||||
logger.info("INFO: No reasoning content detected (Standard model behavior).\n")
|
||||
else:
|
||||
logger.info("FAIL: No result.\n")
|
||||
|
||||
|
||||
def test_tool_call(url, stream):
|
||||
logger.info(f"\n=== Testing Tool Call (Stream={stream}) ===\n")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather in London? Please use the get_weather tool.",
|
||||
}
|
||||
]
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
result = run_query(url, messages, tools=tools, tool_choice="auto", stream=stream)
|
||||
|
||||
if result:
|
||||
tcs = result.get("tool_calls")
|
||||
if tcs and len(tcs) > 0:
|
||||
logger.info("PASS: Tool calls detected.")
|
||||
for tc in tcs:
|
||||
func = tc.get("function", {})
|
||||
logger.info(f" Tool: {func.get('name')}, Args: {func.get('arguments')}\n")
|
||||
else:
|
||||
logger.info(f"FAIL: No tool calls. Content: {result['content']}\n")
|
||||
|
||||
if result.get("reasoning_content"):
|
||||
logger.info(
|
||||
f"INFO: Reasoning content detected during tool call ({len(result['reasoning_content'])} chars).\n"
|
||||
)
|
||||
else:
|
||||
logger.info("FAIL: Query failed.\n")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Test llama-server functionality.")
|
||||
parser.add_argument("--host", default="localhost", help="Server host")
|
||||
parser.add_argument("--port", default=8080, type=int, help="Server port")
|
||||
args = parser.parse_args()
|
||||
|
||||
base_url = f"http://{args.host}:{args.port}/v1/chat/completions"
|
||||
logger.info(f"Testing server at {base_url}\n")
|
||||
|
||||
# Non-streaming tests
|
||||
test_chat(base_url, stream=False)
|
||||
test_tool_call(base_url, stream=False)
|
||||
|
||||
# Streaming tests
|
||||
test_chat(base_url, stream=True)
|
||||
test_tool_call(base_url, stream=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -394,11 +394,13 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
|
|||
clear();
|
||||
split_reset();
|
||||
|
||||
const int64_t n_pos_all = (int64_t) n_tokens*n_pos_per_embd;
|
||||
|
||||
auto udata = std::make_shared<llama_ubatch::data_t>();
|
||||
|
||||
udata->token .resize(n_tokens);
|
||||
udata->embd .clear();
|
||||
udata->pos .resize(n_tokens);
|
||||
udata->pos .resize(n_pos_all);
|
||||
udata->n_seq_id .resize(n_tokens);
|
||||
udata->seq_id .resize(n_tokens);
|
||||
udata->seq_id_unq.resize(0);
|
||||
|
|
|
|||
|
|
@ -1039,11 +1039,15 @@ void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_a
|
|||
bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) {
|
||||
LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);
|
||||
|
||||
if (n_adapters != loras->size()) {
|
||||
return false;
|
||||
}
|
||||
// Adapters with a zero scale are never added to `loras`, so also ignore them for the comparison.
|
||||
size_t n_non_zero = 0;
|
||||
|
||||
for (size_t i = 0; i < n_adapters; i ++) {
|
||||
if (scales[i] == 0.0f) {
|
||||
continue;
|
||||
}
|
||||
n_non_zero++;
|
||||
|
||||
auto it = loras->find(adapters[i]);
|
||||
|
||||
if (it == loras->end() || it->second != scales[i]) {
|
||||
|
|
@ -1051,6 +1055,10 @@ bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_
|
|||
}
|
||||
}
|
||||
|
||||
if (n_non_zero != loras->size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1760,8 +1760,10 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t
|
|||
io.write(&pos, sizeof(pos));
|
||||
io.write(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
// TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it
|
||||
// see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
|
||||
if (hparams.n_pos_per_embd() > 1) {
|
||||
const llama_kv_cell_ext ext = cells.ext_get(i);
|
||||
io.write(&ext, sizeof(ext));
|
||||
}
|
||||
|
||||
for (const auto & seq_id : seq_ids) {
|
||||
io.write(&seq_id, sizeof(seq_id));
|
||||
|
|
@ -1895,6 +1897,14 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
|||
return false;
|
||||
}
|
||||
|
||||
if (hparams.n_pos_per_embd() > 1) {
|
||||
llama_kv_cell_ext ext;
|
||||
io.read_to(&ext, sizeof(ext));
|
||||
|
||||
ubatch.pos[i + ubatch.n_tokens] = ext.y;
|
||||
ubatch.pos[i + ubatch.n_tokens*2] = ext.x;
|
||||
}
|
||||
|
||||
// read the sequence id, but directly discard it - we will use dest_seq_id instead
|
||||
{
|
||||
llama_seq_id seq_id;
|
||||
|
|
|
|||
|
|
@ -187,11 +187,11 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS)
|
|||
# llama_build_and_test(test-double-float.cpp) # SLOW
|
||||
endif()
|
||||
|
||||
llama_build_and_test(test-chat-parser.cpp)
|
||||
llama_build_and_test(test-chat-peg-parser.cpp peg-parser/simple-tokenize.cpp)
|
||||
llama_build_and_test(test-chat-template.cpp)
|
||||
llama_build_and_test(test-jinja.cpp)
|
||||
llama_test(test-jinja NAME test-jinja-py ARGS -py LABEL python)
|
||||
llama_build_and_test(test-chat-auto-parser.cpp WORKING_DIRECTORY ${PROJECT_SOURCE_DIR})
|
||||
llama_build_and_test(test-chat-template.cpp)
|
||||
llama_build_and_test(test-json-partial.cpp)
|
||||
llama_build_and_test(test-log.cpp)
|
||||
llama_build_and_test(
|
||||
|
|
@ -201,6 +201,7 @@ llama_build_and_test(
|
|||
peg-parser/test-gbnf-generation.cpp
|
||||
peg-parser/test-json-parser.cpp
|
||||
peg-parser/test-json-serialization.cpp
|
||||
peg-parser/test-python-dict-parser.cpp
|
||||
peg-parser/test-unicode.cpp
|
||||
peg-parser/tests.h
|
||||
)
|
||||
|
|
@ -279,3 +280,5 @@ target_link_libraries(${TEST_TARGET} PRIVATE llama)
|
|||
|
||||
llama_build_and_test(test-alloc.cpp)
|
||||
target_include_directories(test-alloc PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
#include "peg-parser.h"
|
||||
#include "tests.h"
|
||||
|
||||
void test_basic(testing & t) {
|
||||
|
|
@ -450,5 +451,21 @@ void test_basic(testing & t) {
|
|||
|
||||
t.assert_equal("result_is_fail", true, result.fail());
|
||||
});
|
||||
|
||||
// Test markers
|
||||
t.test("marker", [](testing &t) {
|
||||
auto bracket_parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.marker();
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx_square("[marker]", false);
|
||||
common_peg_parse_context ctx_sharp("<marker>", false);
|
||||
|
||||
auto result_square = bracket_parser.parse(ctx_square);
|
||||
auto result_sharp = bracket_parser.parse(ctx_sharp);
|
||||
|
||||
t.assert_true("result_square_is_success", result_square.success());
|
||||
t.assert_true("result_sharp_is_success", result_sharp.success());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,318 @@
|
|||
#include "tests.h"
|
||||
|
||||
void test_python_dict_parser(testing &t) {
|
||||
// Test parsing a simple Python dict object with single quotes
|
||||
t.test("simple Python dict object parsing", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); });
|
||||
|
||||
std::string input = "{'name': 'test', 'value': 42, 'flag': True}";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
t.assert_equal("result_end", input.size(), result.end);
|
||||
});
|
||||
|
||||
// Test parsing a Python array with mixed types
|
||||
t.test("Python array with mixed types", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); });
|
||||
|
||||
std::string input = "[1, 'hello', True, None, 3.14]";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
t.assert_equal("result_end", input.size(), result.end);
|
||||
});
|
||||
|
||||
// Test parsing nested Python dict with objects and arrays
|
||||
t.test("nested Python dict with objects and arrays", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); });
|
||||
|
||||
std::string input =
|
||||
"{'users': [{'id': 1, 'name': 'Alice'}, {'id': 2, 'name': 'Bob'}], 'count': 2, 'metadata': {'version': '1.0', 'tags': ['admin', 'user']}}";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
t.assert_equal("result_end", input.size(), result.end);
|
||||
});
|
||||
|
||||
// Test parsing Python dict with escaped single quotes
|
||||
t.test("Python dict with escaped single quotes", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); });
|
||||
|
||||
std::string input = "{'message': 'It\\'s working!'}";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
t.assert_equal("result_end", input.size(), result.end);
|
||||
});
|
||||
|
||||
// Test parsing Python dict with double quotes inside single quotes
|
||||
t.test("Python dict with double quotes inside single quotes", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); });
|
||||
|
||||
std::string input = "{'quote': 'He said \"Hello\"'}";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
t.assert_equal("result_end", input.size(), result.end);
|
||||
});
|
||||
|
||||
// Test the example from the requirements
|
||||
t.test("complex Python dict example from requirements", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); });
|
||||
|
||||
std::string input = "{ 'obj' : { 'something': 1, 'other \"something\"' : 'foo\\'s bar' } }";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
t.assert_equal("result_end", input.size(), result.end);
|
||||
});
|
||||
|
||||
// Test need_more_input() parsing - incomplete object
|
||||
t.test("need_more_input() parsing - incomplete object", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); });
|
||||
|
||||
std::string input = "{'name': 'test', 'value': ";
|
||||
common_peg_parse_context ctx(input, true);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_need_more_input", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Test need_more_input() parsing - incomplete single-quoted string
|
||||
t.test("need_more_input() parsing - incomplete single-quoted string", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); });
|
||||
|
||||
std::string input = "{'name': 'test";
|
||||
common_peg_parse_context ctx(input, true);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_need_more_input", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Test unicode in Python dict strings
|
||||
t.test("unicode in Python dict strings", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); });
|
||||
|
||||
std::string input = "{'message': 'Hello, 世界!'}";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
t.assert_equal("result_end", input.size(), result.end);
|
||||
});
|
||||
|
||||
// Test Python dict with unicode escapes
|
||||
t.test("Python dict with unicode escapes", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); });
|
||||
|
||||
std::string input = "{'unicode': 'Hello\\u0041'}";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
t.assert_equal("result_end", input.size(), result.end);
|
||||
});
|
||||
|
||||
// Test that Python parser accepts double-quoted strings too
|
||||
t.test("Python parser accepts double-quoted strings", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); });
|
||||
|
||||
std::string input = "{\"name\": \"test\"}";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
t.assert_equal("result_end", input.size(), result.end);
|
||||
});
|
||||
|
||||
// Test Python parser with mixed quote styles
|
||||
t.test("Python parser with mixed quote styles", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); });
|
||||
|
||||
std::string input = "{\"name\": 'test', 'value': \"hello\"}";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
t.assert_equal("result_end", input.size(), result.end);
|
||||
});
|
||||
|
||||
// Test Python True/False/None
|
||||
t.test("Python True/False/None", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); });
|
||||
|
||||
t.test("True", [&](testing &t) {
|
||||
std::string input = "True";
|
||||
common_peg_parse_context ctx(input);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_true("success", result.success());
|
||||
t.assert_equal("end", input.size(), result.end);
|
||||
});
|
||||
|
||||
t.test("False", [&](testing &t) {
|
||||
std::string input = "False";
|
||||
common_peg_parse_context ctx(input);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_true("success", result.success());
|
||||
t.assert_equal("end", input.size(), result.end);
|
||||
});
|
||||
|
||||
t.test("None", [&](testing &t) {
|
||||
std::string input = "None";
|
||||
common_peg_parse_context ctx(input);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_true("success", result.success());
|
||||
t.assert_equal("end", input.size(), result.end);
|
||||
});
|
||||
|
||||
t.test("rejects JSON-style true/false/null", [&](testing &t) {
|
||||
for (const auto & kw : {"true", "false", "null"}) {
|
||||
std::string input = kw;
|
||||
common_peg_parse_context ctx(input);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_true(std::string("rejects ") + kw, result.fail());
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Test single-quoted string content parser directly
|
||||
t.test("single-quoted string content parser", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.sequence({ p.literal("'"), p.single_quoted_string_content(), p.literal("'"), p.space() });
|
||||
});
|
||||
|
||||
t.test("simple string", [&](testing &t) {
|
||||
std::string input = "'hello'";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_true("success", result.success());
|
||||
t.assert_equal("end", input.size(), result.end);
|
||||
});
|
||||
|
||||
t.test("string with escaped single quote", [&](testing &t) {
|
||||
std::string input = "'it\\'s'";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_true("success", result.success());
|
||||
t.assert_equal("end", input.size(), result.end);
|
||||
});
|
||||
|
||||
t.test("string with double quotes", [&](testing &t) {
|
||||
std::string input = "'say \"hello\"'";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_true("success", result.success());
|
||||
t.assert_equal("end", input.size(), result.end);
|
||||
});
|
||||
|
||||
t.test("incomplete string", [&](testing &t) {
|
||||
std::string input = "'hello";
|
||||
common_peg_parse_context ctx(input, true);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_true("need_more_input", result.need_more_input());
|
||||
});
|
||||
});
|
||||
|
||||
// Test json() with pre-registered flexible json-string rule (python dict support)
|
||||
t.test("json() parser with flexible json-string rule", [](testing &t) {
|
||||
t.test("json() rejects single quotes by default", [&](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.json();
|
||||
});
|
||||
|
||||
std::string input = "{'name': 'test'}";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_true("fail", result.fail());
|
||||
});
|
||||
|
||||
t.test("json() accepts single quotes with pre-registered flexible json-string rule", [&](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
// Pre-register json-string rule with both quote styles
|
||||
p.rule("json-string", [&]() {
|
||||
return p.choice({ p.double_quoted_string(), p.single_quoted_string() });
|
||||
});
|
||||
return p.json();
|
||||
});
|
||||
|
||||
std::string input = "{'name': 'test'}";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_true("success", result.success());
|
||||
t.assert_equal("end", input.size(), result.end);
|
||||
});
|
||||
|
||||
t.test("json() still accepts double quotes with flexible json-string rule", [&](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
p.rule("json-string", [&]() {
|
||||
return p.choice({ p.double_quoted_string(), p.single_quoted_string() });
|
||||
});
|
||||
return p.json();
|
||||
});
|
||||
|
||||
std::string input = "{\"name\": \"test\"}";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_true("success", result.success());
|
||||
t.assert_equal("end", input.size(), result.end);
|
||||
});
|
||||
|
||||
t.test("json() accepts mixed quote styles with flexible json-string rule", [&](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
p.rule("json-string", [&]() {
|
||||
return p.choice({ p.double_quoted_string(), p.single_quoted_string() });
|
||||
});
|
||||
return p.json();
|
||||
});
|
||||
|
||||
std::string input = "{\"name\": 'test', 'value': \"hello\"}";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_true("success", result.success());
|
||||
t.assert_equal("end", input.size(), result.end);
|
||||
});
|
||||
|
||||
t.test("complex nested structure with flexible json-string rule", [&](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
p.rule("json-string", [&]() {
|
||||
return p.choice({ p.double_quoted_string(), p.single_quoted_string() });
|
||||
});
|
||||
return p.json();
|
||||
});
|
||||
|
||||
std::string input = "{ 'obj' : { 'something': 1, 'other \"something\"' : 'foo\\'s bar' } }";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_true("success", result.success());
|
||||
t.assert_equal("end", input.size(), result.end);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
|
@ -22,3 +22,4 @@ void test_json_parser(testing &t);
|
|||
void test_gbnf_generation(testing &t);
|
||||
void test_unicode(testing &t);
|
||||
void test_json_serialization(testing &t);
|
||||
void test_python_dict_parser(testing &t);
|
||||
|
|
|
|||
|
|
@ -7663,6 +7663,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {2 * d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}));
|
||||
// long token (n_t > 32, exercises the long_token kernel path)
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -7817,6 +7820,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
test_cases.emplace_back(new test_mul_mat(type_a, GGML_TYPE_F32, 1, 64, 256, {1, 1}, {1, 1}));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 6, 4096, 5120, {1, 1}, {1, 1}));
|
||||
|
||||
#if 0
|
||||
// test the mat-mat path for Metal
|
||||
for (int k = 1; k < 512; ++k) {
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,617 +0,0 @@
|
|||
// Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
|
||||
//
|
||||
// Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
|
||||
// e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
|
||||
//
|
||||
// cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
|
||||
//
|
||||
#include <exception>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "chat-parser.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
template <class T>
|
||||
static void assert_equals(const std::string_view label, const T & expected, const T & actual) {
|
||||
if (expected != actual) {
|
||||
std::cerr << label << std::endl;
|
||||
std::cerr << "Expected: " << expected << std::endl;
|
||||
std::cerr << "Actual: " << actual << std::endl;
|
||||
std::cerr << std::flush;
|
||||
throw std::runtime_error("Test failed");
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static void assert_equals(const T & expected, const T & actual) {
|
||||
assert_equals("", expected, actual);
|
||||
}
|
||||
static void assert_equals(const char * expected, const std::string & actual) {
|
||||
return assert_equals<std::string>(expected, actual);
|
||||
}
|
||||
|
||||
static void assert_throws(const std::function<void()> & fn, const std::string & expected_exception_pattern = "") {
|
||||
try {
|
||||
fn();
|
||||
} catch (const std::exception & e) {
|
||||
if (expected_exception_pattern.empty()) {
|
||||
return;
|
||||
}
|
||||
std::regex expected_exception_regex(expected_exception_pattern);
|
||||
std::string actual_message = e.what();
|
||||
if (std::regex_search(actual_message, expected_exception_regex)) {
|
||||
return;
|
||||
}
|
||||
throw std::runtime_error("Exception doesn't match expected pattern: " + actual_message + " (pattern: " + expected_exception_pattern + ")");
|
||||
throw std::runtime_error("Exception of unexpected type: " + std::string(e.what()));
|
||||
}
|
||||
throw std::runtime_error("Exception was expected but not thrown");
|
||||
}
|
||||
|
||||
static void test_reasoning() {
|
||||
//common_log_set_verbosity_thold(LOG_DEFAULT_DEBUG);
|
||||
{
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = false;
|
||||
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, params);
|
||||
assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
|
||||
assert_equals("<tnk>Cogito</tnk>Ergo sum", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = false;
|
||||
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, params);
|
||||
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
|
||||
assert_equals(std::string("Cogito"), builder.result().reasoning_content);
|
||||
assert_equals("Ergo sum", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = false;
|
||||
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, params);
|
||||
assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
|
||||
assert_equals("Cogito</tnk>Ergo sum", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = true;
|
||||
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, params);
|
||||
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
|
||||
assert_equals(std::string("Cogito"), builder.result().reasoning_content);
|
||||
assert_equals("Ergo sum", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = true;
|
||||
params.thinking_forced_open = true;
|
||||
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, params);
|
||||
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
|
||||
assert_equals("<think>Cogito</think>", builder.result().content);
|
||||
assert_equals("Ergo sum", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
const std::string variant("content_only_inline_think");
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = false;
|
||||
params.parse_tool_calls = false;
|
||||
const std::string input = "<think>Pense</think>Bonjour";
|
||||
auto msg = common_chat_parse(input, false, params);
|
||||
assert_equals(variant, std::string("Pense"), msg.reasoning_content);
|
||||
assert_equals(variant, std::string("Bonjour"), msg.content);
|
||||
}
|
||||
{
|
||||
const std::string variant("llama_3_inline_think");
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_LLAMA_3_X;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = false;
|
||||
params.parse_tool_calls = false;
|
||||
const std::string input = "<think>Plan</think>Réponse";
|
||||
auto msg = common_chat_parse(input, false, params);
|
||||
assert_equals(variant, std::string("Plan"), msg.reasoning_content);
|
||||
assert_equals(variant, std::string("Réponse"), msg.content);
|
||||
}
|
||||
// Test DeepSeek V3.1 parsing - reasoning content followed by "</think>" and then regular content
|
||||
{
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = true;
|
||||
params.parse_tool_calls = true;
|
||||
const std::string variant("deepseek_v3_1_reasoning_format_deepseek");
|
||||
common_chat_msg_parser builder("REASONING</think>ok", /* is_partial= */ false, params);
|
||||
assert_equals(variant, true, builder.try_parse_reasoning("<think>", "</think>"));
|
||||
assert_equals(variant, std::string("REASONING"), builder.result().reasoning_content);
|
||||
assert_equals(variant, std::string("ok"), builder.consume_rest());
|
||||
}
|
||||
// Test DeepSeek V3.1 parsing - reasoning_format none - reasoning content followed by "</think>" and then regular content
|
||||
{
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = true;
|
||||
params.parse_tool_calls = true;
|
||||
const std::string variant("deepseek_v3_1_reasoning_format_none");
|
||||
const std::string input = "REASONING</think>ok";
|
||||
auto msg = common_chat_parse(input, false, params);
|
||||
assert_equals(variant, std::string("REASONING</think>ok"), msg.content);
|
||||
assert_equals(variant, std::string(""), msg.reasoning_content);
|
||||
}
|
||||
}
|
||||
|
||||
static void test_regex() {
|
||||
auto test_throws = [](const std::string & input, const std::string & regex, const std::string & expected_exception_pattern = "") {
|
||||
common_chat_msg_parser builder(input, /* is_partial= */ false, {});
|
||||
assert_throws([&]() { builder.consume_regex(common_regex(regex)); }, expected_exception_pattern);
|
||||
};
|
||||
|
||||
test_throws("Hello, world!", "abc", "^abc$");
|
||||
test_throws("Hello, world!", "e", "^e$");
|
||||
|
||||
{
|
||||
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
|
||||
builder.consume_regex(common_regex("Hello"));
|
||||
assert_equals(", world!", builder.consume_rest());
|
||||
}
|
||||
|
||||
{
|
||||
// When in non partial mode, we can say whether the regex was consumed or not.
|
||||
common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
|
||||
assert_equals(false, builder.try_consume_regex(common_regex("Hello, world!")).has_value());
|
||||
}
|
||||
{
|
||||
common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
|
||||
auto res = builder.try_consume_regex(common_regex("H(el)l(?:o, world!)?"));
|
||||
assert_equals(true, res.has_value());
|
||||
// Verify captures
|
||||
assert_equals<size_t>(2, res->groups.size());
|
||||
assert_equals("Hell", builder.str(res->groups[0]));
|
||||
assert_equals("el", builder.str(res->groups[1]));
|
||||
// Verify position is after the match
|
||||
assert_equals<size_t>(4, builder.pos());
|
||||
assert_equals("o,", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
// But in partial mode, we have a partial final match / can't decide, so we throw a partial exception.
|
||||
common_chat_msg_parser builder("Hello,", /* is_partial= */ true, {});
|
||||
assert_throws([&]() {
|
||||
builder.try_consume_regex(common_regex("Hello, world!"));
|
||||
}, "^Hello, world!$");
|
||||
}
|
||||
|
||||
// Now regardless of the mode, we can tell these aren't a match.
|
||||
for (const auto is_partial : {false, true}) {
|
||||
common_chat_msg_parser builder("Hello,", is_partial, {});
|
||||
assert_equals(false, builder.try_consume_regex(common_regex("a(b|c)(d|e)f")).has_value());
|
||||
}
|
||||
for (const auto is_partial : {false, true}) {
|
||||
common_chat_msg_parser builder("Hello,", is_partial, {});
|
||||
assert_equals(false, builder.try_consume_literal("Oh"));
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<std::string> barely_healable_jsons = {
|
||||
"{",
|
||||
"{\"",
|
||||
"{\"\\",
|
||||
"{\"n",
|
||||
"{\"name\"",
|
||||
"{\"name\":",
|
||||
"{\"name\":\"",
|
||||
"{\"name\":\"\\",
|
||||
"{\"name\":\"python",
|
||||
"{\"name\":\"python\\",
|
||||
"{\",",
|
||||
"{\":",
|
||||
"{\"[",
|
||||
"{\"]",
|
||||
"{\"{",
|
||||
"{\"}",
|
||||
"{\"1",
|
||||
"{\"name\":\",",
|
||||
"{\"name\":\":",
|
||||
"{\"name\":\"[",
|
||||
"{\"name\":\"]",
|
||||
"{\"name\":\"{",
|
||||
"{\"name\":\"}",
|
||||
"{\"name\":\"1",
|
||||
};
|
||||
|
||||
static void test(const std::string & input, bool is_partial, const std::vector<std::vector<std::string>> & args_paths, const std::vector<std::vector<std::string>> & content_paths, const std::string & expected) {
|
||||
common_chat_msg_parser builder(input, is_partial, {});
|
||||
auto js = builder.try_consume_json_with_dumped_args(args_paths, content_paths);
|
||||
assert_equals(true, js.has_value());
|
||||
assert_equals(is_partial, js->is_partial);
|
||||
assert_equals(expected, args_paths.size() == 1 && args_paths[0].empty() ? js->value.get<std::string>() : js->value.dump());
|
||||
}
|
||||
|
||||
static void test_deepseek_v3_1_tool_calls() {
|
||||
//common_log_set_verbosity_thold(LOG_DEFAULT_DEBUG);
|
||||
// variant: happy path for when it works as the model card says it should
|
||||
const std::string variant("simple");
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = false;
|
||||
params.parse_tool_calls = true;
|
||||
const std::string input = "<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>";
|
||||
auto msg = common_chat_parse(input, false, params);
|
||||
assert_equals<std::size_t>(variant, 1, msg.tool_calls.size());
|
||||
assert_equals(variant, std::string("get_time"), msg.tool_calls[0].name);
|
||||
// JSON arguments are dumped without spaces
|
||||
assert_equals(variant, std::string("{\"city\":\"Tokyo\"}"), msg.tool_calls[0].arguments);
|
||||
assert_equals(variant, std::string(""), msg.content);
|
||||
assert_equals(variant, std::string(""), msg.reasoning_content);
|
||||
|
||||
// variant: simple + thinking open
|
||||
{
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = true;
|
||||
params.parse_tool_calls = true;
|
||||
const std::string variant("simple_thinking");
|
||||
const std::string in = "REASONING</think><|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>";
|
||||
auto m = common_chat_parse(in, false, params);
|
||||
assert_equals<std::size_t>(variant, 1, m.tool_calls.size());
|
||||
assert_equals(variant, std::string("get_time"), m.tool_calls[0].name);
|
||||
assert_equals(variant, std::string("{\"city\":\"Tokyo\"}"), m.tool_calls[0].arguments);
|
||||
assert_equals(variant, std::string(""), m.content);
|
||||
assert_equals(variant, std::string("REASONING"), m.reasoning_content);
|
||||
}
|
||||
// variant: simple + multiple tool calls
|
||||
{
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = false;
|
||||
params.parse_tool_calls = true;
|
||||
const std::string variant("simple_multiple_tool_calls");
|
||||
const std::string in = "CONTENT<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Paris\"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"city\": \"Paris\"}<|tool▁call▁end|><|tool▁calls▁end|>";
|
||||
auto m = common_chat_parse(in, false, params);
|
||||
assert_equals<std::size_t>(variant, 2, m.tool_calls.size());
|
||||
assert_equals(variant, std::string("get_time"), m.tool_calls[0].name);
|
||||
assert_equals(variant, std::string("{\"city\":\"Paris\"}"), m.tool_calls[0].arguments);
|
||||
assert_equals(variant, std::string("get_weather"), m.tool_calls[1].name);
|
||||
assert_equals(variant, std::string("{\"city\":\"Paris\"}"), m.tool_calls[1].arguments);
|
||||
assert_equals(variant, std::string("CONTENT"), m.content);
|
||||
assert_equals(variant, std::string(""), m.reasoning_content);
|
||||
}
|
||||
|
||||
|
||||
// variant: thinking forced open + tool call in reasoning content
|
||||
{
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = true;
|
||||
params.parse_tool_calls = true;
|
||||
const std::string variant("thinking_forced_open_tool_call_in_reasoning");
|
||||
const std::string in = "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time2<|tool▁sep|>{\"city\": \"Tokyo2\"}<|tool▁call▁end|><|tool▁calls▁end|>REASONING</think><|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>";
|
||||
auto m = common_chat_parse(in, false, params);
|
||||
assert_equals<std::size_t>(variant, 1, m.tool_calls.size());
|
||||
assert_equals(variant, std::string("get_time"), m.tool_calls[0].name);
|
||||
assert_equals(variant, std::string("{\"city\":\"Tokyo\"}"), m.tool_calls[0].arguments);
|
||||
assert_equals(variant, std::string(""), m.content);
|
||||
assert_equals(variant, std::string("REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time2<|tool▁sep|>{\"city\": \"Tokyo2\"}<|tool▁call▁end|><|tool▁calls▁end|>REASONING"), m.reasoning_content);
|
||||
}
|
||||
|
||||
// variant: thinking forced open + tool call in reasoning content + no closing think + not partial
|
||||
// This is a bit of a fine tuning issue on the model's part IMO. It really should not be attempting
|
||||
// to make tool calls in reasoning content according to the model card, but it does sometimes, so
|
||||
// add the reasoning content as regular content and parse the tool calls.
|
||||
{
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = true;
|
||||
params.parse_tool_calls = true;
|
||||
const std::string variant("thinking_forced_open_tool_call_in_reasoning_no_closing_think_not_partial");
|
||||
const std::string in = "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>";
|
||||
auto m = common_chat_parse(in, false, params);
|
||||
assert_equals(variant, std::string("REASONING"), m.content);
|
||||
assert_equals(variant, std::string(""), m.reasoning_content);
|
||||
assert_equals<std::size_t>(variant, 1, m.tool_calls.size());
|
||||
assert_equals(variant, std::string("get_time"), m.tool_calls[0].name);
|
||||
assert_equals(variant, std::string("{\"city\":\"Tokyo\"}"), m.tool_calls[0].arguments);
|
||||
}
|
||||
|
||||
// variant: thinking forced open + tool call in reasoning content + no closing think + partial
|
||||
{
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = true;
|
||||
params.parse_tool_calls = true;
|
||||
const std::string variant("thinking_forced_open_tool_call_in_reasoning_no_closing_think_partial");
|
||||
const std::string in = "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>";
|
||||
auto m = common_chat_parse(in, /* is_partial= */ true, params);
|
||||
assert_equals(variant, std::string("REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>"), m.reasoning_content);
|
||||
assert_equals(variant, std::string(""), m.content);
|
||||
assert_equals<std::size_t>(variant, 0, m.tool_calls.size());
|
||||
}
|
||||
|
||||
// variant: thinking not forced open + reasoning + regular content + no tool calls
|
||||
{
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = true;
|
||||
params.parse_tool_calls = true;
|
||||
const std::string variant("thinking_forced_open_reasoning_regular_content_no_tool_calls");
|
||||
const std::string in = "REASONING</think>CONTENT";
|
||||
auto m = common_chat_parse(in, false, params);
|
||||
assert_equals<std::size_t>(variant, 0, m.tool_calls.size());
|
||||
assert_equals(variant, std::string("CONTENT"), m.content);
|
||||
assert_equals(variant, std::string("REASONING"), m.reasoning_content);
|
||||
}
|
||||
// variant: thinking not forced open + missing reasoning + no tool calls
|
||||
{
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = false;
|
||||
params.parse_tool_calls = true;
|
||||
const std::string variant("thinking_not_forced_open_missing_reasoning_no_tool_calls");
|
||||
const std::string in = "CONTENT";
|
||||
auto m = common_chat_parse(in, false, params);
|
||||
assert_equals<std::size_t>(variant, 0, m.tool_calls.size());
|
||||
assert_equals(variant, std::string("CONTENT"), m.content);
|
||||
assert_equals(variant, std::string(""), m.reasoning_content);
|
||||
}
|
||||
}
|
||||
|
||||
static void test_with_args(const std::string & input, const std::string & expected, bool parse_as_partial = true, bool is_partial = true) {
|
||||
common_chat_msg_parser builder(input, parse_as_partial, {});
|
||||
auto js = builder.try_consume_json_with_dumped_args({{"args"}}, {});
|
||||
assert_equals(true, js.has_value());
|
||||
assert_equals(is_partial, js->is_partial);
|
||||
assert_equals(expected, js->value.dump());
|
||||
}
|
||||
|
||||
static void test_json_with_dumped_args_no_args() {
|
||||
// Normal JSON, nothing to heal, nothing to dump
|
||||
test("{\"name\": \"python\"}", false, {}, {}, "{\"name\":\"python\"}");
|
||||
// Full json is args
|
||||
test("{\"name\": \"python\"}", false, {{}}, {}, "{\"name\":\"python\"}");
|
||||
|
||||
// If the arguments are further down, don't heal partial content.
|
||||
for (const auto & src : barely_healable_jsons) {
|
||||
test(src, true, {{"arguments"}}, {}, "{}");
|
||||
}
|
||||
// But heal content that isn't partial.
|
||||
test("{\"name\": \"python\"", true, {{"arguments"}}, {}, "{\"name\":\"python\"}");
|
||||
}
|
||||
|
||||
static void test_json_with_dumped_args() {
|
||||
|
||||
// Partial content.
|
||||
test("{\"content\": \"t", true, {}, {{"content"}}, "{\"content\":\"t\"}");
|
||||
test("{\"content\": \"", true, {}, {{"content"}}, "{\"content\":\"\"}");
|
||||
test("{\"content\": ", true, {}, {{"content"}}, "{}");
|
||||
|
||||
// If the entire JSON is the arguments, healing it them dumping it produces the same output as the input (just reformatted).
|
||||
test("{\"name\": \"python", true, {{}}, {}, "{\"name\":\"python");
|
||||
for (const auto & src : barely_healable_jsons) {
|
||||
test(src, true, {{}}, {}, src);
|
||||
}
|
||||
|
||||
// Full JSON w/ args
|
||||
for (auto parse_as_partial : {true, false}) {
|
||||
test_with_args(
|
||||
R"({"name": "python", "args": {"arg1": 1}})",
|
||||
R"({"name":"python","args":"{\"arg1\":1}"})",
|
||||
parse_as_partial,
|
||||
/* is_partial= */ false
|
||||
);
|
||||
}
|
||||
|
||||
// Partial JSON w/ partial args
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {")",
|
||||
R"({"foo":"bar","args":"{\""})"
|
||||
);
|
||||
// Partial args broken in object key
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"ar)",
|
||||
R"({"foo":"bar","args":"{\"ar"})"
|
||||
);
|
||||
// Partial args broken after object key
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1")",
|
||||
R"({"foo":"bar","args":"{\"arg1\""})"
|
||||
);
|
||||
// Partial args broken before object value
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1":)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":"})"
|
||||
);
|
||||
// Partial args broken before object value (space)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": )",
|
||||
R"({"foo":"bar","args":"{\"arg1\":"})"
|
||||
);
|
||||
// Partial args broken in object value that may not be complete (int)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": 1)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":"})"
|
||||
);
|
||||
// Partial args broken in object value that is complete (int)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": 1 )",
|
||||
R"({"foo":"bar","args":"{\"arg1\":1"})"
|
||||
);
|
||||
// Partial args broken in object value that is incomplete (string)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": ")",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\""})"
|
||||
);
|
||||
// Partial args broken in object value that is complete (string)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "1")",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"1\""})"
|
||||
);
|
||||
// Partial args broken on array opening
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": [)",
|
||||
R"({"foo":"bar","args":"["})"
|
||||
);
|
||||
// Partial args broken on array value that is incomplete (int)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": [1)",
|
||||
R"({"foo":"bar","args":"["})"
|
||||
);
|
||||
// Partial args broken on array value that is complete (int)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": [1 )",
|
||||
R"({"foo":"bar","args":"[1"})"
|
||||
);
|
||||
// Partial args broken on array value that is complete (string)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": ["1")",
|
||||
R"({"foo":"bar","args":"[\"1\""})"
|
||||
);
|
||||
// Partial args broken after array value
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": [1,)",
|
||||
R"({"foo":"bar","args":"[1,"})"
|
||||
);
|
||||
// Partial args broken on nested array
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": [)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":["})"
|
||||
);
|
||||
|
||||
// Unicode tests
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\u)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\u"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\u0)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\u0"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\u00)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\u00"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\u000)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\u000"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\u0000)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\u0000"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud8)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud8"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud80)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud80"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\u)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\u"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\ud)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\ud"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\udc)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\udc0)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc0"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\udc00)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc00"})"
|
||||
);
|
||||
}
|
||||
|
||||
static void test_positions() {
|
||||
{
|
||||
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
|
||||
assert_equals<size_t>(0, builder.pos());
|
||||
assert_throws([&]() { builder.move_to(100); });
|
||||
assert_equals<size_t>(0, builder.pos());
|
||||
assert_throws([&]() { builder.move_back(1); });
|
||||
assert_equals<size_t>(0, builder.pos());
|
||||
|
||||
builder.move_to(8);
|
||||
assert_equals<size_t>(8, builder.pos());
|
||||
builder.move_back(1);
|
||||
assert_equals<size_t>(7, builder.pos());
|
||||
assert_equals("world!", builder.consume_rest());
|
||||
|
||||
builder.move_to(0);
|
||||
assert_equals<size_t>(0, builder.pos());
|
||||
|
||||
assert_throws([&]() { builder.finish(); });
|
||||
assert_equals<size_t>(0, builder.pos());
|
||||
|
||||
builder.move_to(builder.input().size());
|
||||
builder.finish();
|
||||
}
|
||||
{
|
||||
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ true, {});
|
||||
|
||||
builder.move_to(builder.input().size());
|
||||
assert_equals<size_t>(builder.input().size(), builder.pos());
|
||||
builder.finish();
|
||||
}
|
||||
}
|
||||
|
||||
int main() {
|
||||
test_positions();
|
||||
test_json_with_dumped_args_no_args();
|
||||
test_json_with_dumped_args();
|
||||
test_reasoning();
|
||||
test_regex();
|
||||
test_deepseek_v3_1_tool_calls();
|
||||
std::cout << "All tests passed!\n";
|
||||
return 0;
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,4 +1,5 @@
|
|||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
#include <regex>
|
||||
|
|
@ -21,17 +22,16 @@
|
|||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
int main_automated_tests(void);
|
||||
static int main_automated_tests(void);
|
||||
|
||||
void run_multiple(std::string dir_path, bool stop_on_first_failure, json input, bool use_common = false);
|
||||
void run_single(std::string contents, json input, bool use_common = false, const std::string & output_path = "");
|
||||
static void run_multiple(const std::string& dir_path, bool stop_on_first_failure, const json& input, bool use_common = false);
|
||||
static void run_single(const std::string& contents, json input, bool use_common = false, const std::string & output_path = "");
|
||||
|
||||
|
||||
|
||||
std::string HELP = R"(
|
||||
static std::string HELP = R"(
|
||||
Usage: test-chat-template [OPTIONS] PATH_TO_TEMPLATE
|
||||
Options:
|
||||
-h, --help Show this help message and exit.
|
||||
--with-tools Add a tool and a tool call to the default JSON input
|
||||
--json <path> Path to the JSON input file.
|
||||
--stop-on-first-fail Stop testing on the first failure (default: false).
|
||||
--no-common Use direct Jinja engine instead of common chat templates (default: use common).
|
||||
|
|
@ -41,7 +41,7 @@ If PATH_TO_TEMPLATE is a directory, runs all .jinja files in that directory.
|
|||
If PATH_TO_TEMPLATE is omitted, runs automated tests (default CI mode).
|
||||
)";
|
||||
|
||||
std::string DEFAULT_JSON = R"({
|
||||
static std::string DEFAULT_JSON = R"({
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
|
|
@ -57,12 +57,65 @@ std::string DEFAULT_JSON = R"({
|
|||
"add_generation_prompt": true
|
||||
})";
|
||||
|
||||
static std::string DEFAULT_JSON_WITH_TOOLS = R"({
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, how are you?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I am fine, thank you!"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Call a tool!"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call00001",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test",
|
||||
"arguments": { "arg": "hello" }
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test",
|
||||
"description": "Test",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"arg": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["arg"]
|
||||
}
|
||||
}
|
||||
],
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"add_generation_prompt": true
|
||||
})";
|
||||
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::vector<std::string> args(argv, argv + argc);
|
||||
|
||||
std::string tmpl_path;
|
||||
std::string json_path;
|
||||
std::string output_path;
|
||||
std::string & json_to_use = DEFAULT_JSON;
|
||||
bool stop_on_first_fail = false;
|
||||
bool use_common = true;
|
||||
|
||||
|
|
@ -70,9 +123,12 @@ int main(int argc, char ** argv) {
|
|||
if (args[i] == "--help" || args[i] == "-h") {
|
||||
std::cout << HELP << "\n";
|
||||
return 0;
|
||||
} else if (args[i] == "--json" && i + 1 < args.size()) {
|
||||
}
|
||||
if (args[i] == "--json" && i + 1 < args.size()) {
|
||||
json_path = args[i + 1];
|
||||
i++;
|
||||
} else if (args[i] == "--with-tools") {
|
||||
json_to_use = DEFAULT_JSON_WITH_TOOLS;
|
||||
} else if (args[i] == "--stop-on-first-fail") {
|
||||
stop_on_first_fail = true;
|
||||
} else if (args[i] == "--output" && i + 1 < args.size()) {
|
||||
|
|
@ -105,7 +161,7 @@ int main(int argc, char ** argv) {
|
|||
std::istreambuf_iterator<char>());
|
||||
input_json = json::parse(content);
|
||||
} else {
|
||||
input_json = json::parse(DEFAULT_JSON);
|
||||
input_json = json::parse(json_to_use);
|
||||
}
|
||||
|
||||
std::filesystem::path p(tmpl_path);
|
||||
|
|
@ -125,7 +181,7 @@ int main(int argc, char ** argv) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
void run_multiple(std::string dir_path, bool stop_on_first_fail, json input, bool use_common) {
|
||||
void run_multiple(const std::string& dir_path, bool stop_on_first_fail, const json& input, bool use_common) {
|
||||
std::vector<std::string> failed_tests;
|
||||
|
||||
// list all files in models/templates/ and run each
|
||||
|
|
@ -180,7 +236,7 @@ static std::string format_using_common(
|
|||
common_chat_templates_inputs inputs;
|
||||
inputs.use_jinja = true;
|
||||
inputs.messages = messages;
|
||||
inputs.tools = tools;
|
||||
inputs.tools = std::move(tools);
|
||||
inputs.add_generation_prompt = true;
|
||||
auto output = common_chat_templates_apply(tmpls.get(), inputs).prompt;
|
||||
output = normalize_newlines(output);
|
||||
|
|
@ -209,7 +265,7 @@ static jinja::value_string format_using_direct_engine(
|
|||
|
||||
jinja::runtime runtime(ctx);
|
||||
const jinja::value results = runtime.execute(ast);
|
||||
auto parts = runtime.gather_string_parts(results);
|
||||
auto parts = jinja::runtime::gather_string_parts(results);
|
||||
|
||||
std::cout << "\n=== RESULTS ===\n";
|
||||
for (const auto & part : parts->as_string().parts) {
|
||||
|
|
@ -220,7 +276,7 @@ static jinja::value_string format_using_direct_engine(
|
|||
}
|
||||
|
||||
|
||||
void run_single(std::string contents, json input, bool use_common, const std::string & output_path) {
|
||||
void run_single(const std::string& contents, json input, bool use_common, const std::string & output_path) {
|
||||
jinja::enable_debug(true);
|
||||
|
||||
jinja::value_string output_parts;
|
||||
|
|
@ -560,7 +616,7 @@ int main_automated_tests(void) {
|
|||
supported_tmpl.resize(res);
|
||||
res = llama_chat_builtin_templates(supported_tmpl.data(), supported_tmpl.size());
|
||||
std::cout << "Built-in chat templates:\n";
|
||||
for (auto tmpl : supported_tmpl) {
|
||||
for (const auto *tmpl : supported_tmpl) {
|
||||
std::cout << " " << tmpl << "\n";
|
||||
}
|
||||
|
||||
|
|
@ -592,6 +648,7 @@ int main_automated_tests(void) {
|
|||
}
|
||||
|
||||
std::vector<common_chat_msg> messages;
|
||||
messages.reserve(conversation.size());
|
||||
for (const auto & msg : conversation) {
|
||||
messages.push_back(simple_msg(msg.role, msg.content));
|
||||
}
|
||||
|
|
@ -622,58 +679,6 @@ int main_automated_tests(void) {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: llama_chat_format_single will be deprecated, remove these tests later
|
||||
|
||||
// test llama_chat_format_single for system message
|
||||
std::cout << "\n\n=== llama_chat_format_single (system message) ===\n\n";
|
||||
std::vector<common_chat_msg> chat2;
|
||||
auto sys_msg = simple_msg("system", "You are a helpful assistant");
|
||||
|
||||
auto fmt_sys = [&](std::string tmpl_str) {
|
||||
auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str);
|
||||
auto output = common_chat_format_single(tmpls.get(), chat2, sys_msg, false, /* use_jinja= */ false);
|
||||
std::cout << "fmt_sys(" << tmpl_str << ") : " << output << "\n";
|
||||
std::cout << "-------------------------\n";
|
||||
return output;
|
||||
};
|
||||
assert(fmt_sys("chatml") == "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n");
|
||||
assert(fmt_sys("mistral-v1") == " [INST] You are a helpful assistant\n\n");
|
||||
assert(fmt_sys("mistral-v3") == "[INST] You are a helpful assistant\n\n");
|
||||
assert(fmt_sys("mistral-v3-tekken") == "[INST]You are a helpful assistant\n\n");
|
||||
assert(fmt_sys("mistral-v7") == "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT]");
|
||||
assert(fmt_sys("llama2") == "[INST] You are a helpful assistant\n");
|
||||
assert(fmt_sys("llama2-sys") == "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\n");
|
||||
assert(fmt_sys("mistral") == "[INST] You are a helpful assistant\n"); // for old pre-v1 templates
|
||||
assert(fmt_sys("gemma") == ""); // for gemma, system message is merged with user message
|
||||
assert(fmt_sys("llama3") == "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>");
|
||||
assert(fmt_sys("gigachat") == "<s>You are a helpful assistant<|message_sep|>");
|
||||
|
||||
|
||||
// test llama_chat_format_single for user message
|
||||
std::cout << "\n\n=== llama_chat_format_single (user message) ===\n\n";
|
||||
chat2.push_back(simple_msg("system", "You are a helpful assistant"));
|
||||
chat2.push_back(simple_msg("user", "Hello"));
|
||||
chat2.push_back(simple_msg("assistant", "I am assistant"));
|
||||
auto new_msg = simple_msg("user", "How are you");
|
||||
|
||||
auto fmt_single = [&](const std::string & tmpl_str) {
|
||||
auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str.c_str());
|
||||
auto output = common_chat_format_single(tmpls.get(), chat2, new_msg, true, /* use_jinja= */ false);
|
||||
std::cout << "fmt_single(" << tmpl_str << ") : " << output << "\n";
|
||||
std::cout << "-------------------------\n";
|
||||
return output;
|
||||
};
|
||||
assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
|
||||
assert(fmt_single("mistral-v1") == " [INST] How are you [/INST]");
|
||||
assert(fmt_single("mistral-v3") == "[INST] How are you[/INST]");
|
||||
assert(fmt_single("mistral-v3-tekken") == "[INST]How are you[/INST]");
|
||||
assert(fmt_single("mistral-v7") == "[INST] How are you[/INST]");
|
||||
assert(fmt_single("llama2") == "[INST] How are you [/INST]");
|
||||
assert(fmt_single("mistral") == "[INST] How are you [/INST]"); // for old pre-v1 templates
|
||||
assert(fmt_single("gemma") == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
|
||||
assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
|
||||
// assert(fmt_single("gigachat") == "user<|role_sep|>How are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>");
|
||||
|
||||
std::cout << "\nOK: All tests passed successfully.\n";
|
||||
|
||||
return 0;
|
||||
|
|
|
|||
5791
tests/test-chat.cpp
5791
tests/test-chat.cpp
File diff suppressed because it is too large
Load Diff
|
|
@ -1340,6 +1340,26 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
|||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"description only (no type) treated as unconstrained",
|
||||
R"""({"description": "The 0-based index of the last line to be retrieved (inclusive). If None, read until the end of the file."})""",
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
root ::= value
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"literal string with escapes",
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ int main(int argc, char *argv[]) {
|
|||
t.test("json", test_json_parser);
|
||||
t.test("gbnf", test_gbnf_generation);
|
||||
t.test("serialization", test_json_serialization);
|
||||
t.test("python-dict", test_python_dict_parser);
|
||||
|
||||
return t.summary();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ else()
|
|||
add_subdirectory(server)
|
||||
endif()
|
||||
add_subdirectory(tokenize)
|
||||
add_subdirectory(parser)
|
||||
add_subdirectory(tts)
|
||||
add_subdirectory(mtmd)
|
||||
if (GGML_RPC)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "arg.h"
|
||||
#include "console.h"
|
||||
|
|
@ -191,7 +192,8 @@ struct cli_context {
|
|||
inputs.use_jinja = chat_params.use_jinja;
|
||||
inputs.parallel_tool_calls = false;
|
||||
inputs.add_generation_prompt = true;
|
||||
inputs.enable_thinking = chat_params.enable_thinking;
|
||||
inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
inputs.enable_thinking = common_chat_templates_support_enable_thinking(chat_params.tmpls.get());
|
||||
|
||||
// Apply chat template to the list of messages
|
||||
return common_chat_templates_apply(chat_params.tmpls.get(), inputs);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,20 @@
|
|||
if (NOT WIN32 OR NOT BUILD_SHARED_LIBS)
|
||||
# this tool is disabled on Windows when building with shared libraries because it uses internal functions not exported with LLAMA_API
|
||||
set(TARGET llama-debug-template-parser)
|
||||
add_executable(${TARGET} debug-template-parser.cpp)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(TARGET llama-template-analysis)
|
||||
add_executable(${TARGET} template-analysis.cpp)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
endif()
|
||||
|
|
@ -0,0 +1,452 @@
|
|||
#include "../src/llama-grammar.h"
|
||||
#include "chat-auto-parser.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "gguf.h"
|
||||
#include "jinja/runtime.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "peg-parser.h"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
enum class output_mode {
|
||||
ANALYSIS, // Only output analysis results (default)
|
||||
TEMPLATE, // Only output rendered template
|
||||
BOTH // Output both
|
||||
};
|
||||
|
||||
enum class input_message_type {
|
||||
NONE, // Don't render any message scenarios (only analysis)
|
||||
CONTENT_ONLY, // Simple assistant message with content
|
||||
REASONING_CONTENT, // Message with reasoning_content + content
|
||||
TOOL_CALL_ONLY, // Message with tool_calls only
|
||||
CONTENT_TOOL_CALL, // Message with content + tool_calls
|
||||
REASONING_TOOL_CALL, // Message with reasoning_content + tool_calls
|
||||
CONTENT_FAKE_TOOL_CALL, // Message with content but no actual tool_calls (for testing)
|
||||
ALL // Render all scenarios
|
||||
};
|
||||
|
||||
struct debug_options {
|
||||
std::string template_path;
|
||||
bool with_tools = true;
|
||||
bool generation_prompt = true;
|
||||
bool enable_reasoning = true;
|
||||
bool debug_jinja = false;
|
||||
bool force_tool_call = false;
|
||||
output_mode mode = output_mode::BOTH;
|
||||
input_message_type input_message = input_message_type::NONE;
|
||||
};
|
||||
|
||||
static std::string read_file(const std::string & path) {
|
||||
std::ifstream fin(path, std::ios::binary);
|
||||
if (!fin.is_open()) {
|
||||
throw std::runtime_error("Could not open file: " + path);
|
||||
}
|
||||
std::ostringstream buf;
|
||||
buf << fin.rdbuf();
|
||||
return buf.str();
|
||||
}
|
||||
|
||||
static std::string read_gguf_chat_template(const std::string & path) {
|
||||
struct gguf_init_params params = { /*no_alloc =*/true, // We only need metadata, not tensor data
|
||||
/*ctx=*/nullptr };
|
||||
|
||||
struct gguf_context * ctx = gguf_init_from_file(path.c_str(), params);
|
||||
if (ctx == nullptr) {
|
||||
throw std::runtime_error("Could not open GGUF file: " + path);
|
||||
}
|
||||
|
||||
const char * key = "tokenizer.chat_template";
|
||||
int64_t key_id = gguf_find_key(ctx, key);
|
||||
|
||||
if (key_id == -1) {
|
||||
gguf_free(ctx);
|
||||
throw std::runtime_error("GGUF file does not contain chat template key: " + std::string(key));
|
||||
}
|
||||
|
||||
const char * template_str = gguf_get_val_str(ctx, key_id);
|
||||
if (template_str == nullptr) {
|
||||
gguf_free(ctx);
|
||||
throw std::runtime_error("GGUF file contains chat template key but value is null");
|
||||
}
|
||||
|
||||
std::string result = template_str;
|
||||
gguf_free(ctx);
|
||||
return result;
|
||||
}
|
||||
|
||||
static void print_usage(const char * program_name) {
|
||||
LOG_ERR("Usage: %s <template_or_gguf_path> [options]\n", program_name);
|
||||
LOG_ERR("\nOptions:\n");
|
||||
LOG_ERR(" --no-tools Disable tool definitions\n");
|
||||
LOG_ERR(" --force-tool-call Set tool calls to forced\n");
|
||||
LOG_ERR(" --generation-prompt=0|1 Set add_generation_prompt (default: 1)\n");
|
||||
LOG_ERR(" --enable-reasoning=0|1 Enable reasoning parsing (default: 1)\n");
|
||||
LOG_ERR(" --output=MODE Output mode: analysis, template, both (default: both)\n");
|
||||
LOG_ERR(" --debug-jinja Enable Jinja fine-grained debug\n");
|
||||
LOG_ERR(" --input-message=TYPE Message type to render:\n");
|
||||
LOG_ERR(" content_only, reasoning_content, tool_call_only,\n");
|
||||
LOG_ERR(" content_tool_call, reasoning_tool_call,\n");
|
||||
LOG_ERR(" content_fake_tool_call, all\n");
|
||||
LOG_ERR("\nExamples:\n");
|
||||
LOG_ERR(" %s template.jinja --input-message=all --generation-prompt=1\n", program_name);
|
||||
LOG_ERR(" %s template.jinja --output=template --input-message=tool_call_only\n", program_name);
|
||||
}
|
||||
|
||||
static bool parse_bool_option(const std::string & value) {
|
||||
return value == "1" || value == "true" || value == "yes";
|
||||
}
|
||||
|
||||
static bool parse_options(int argc, char ** argv, debug_options & opts) {
|
||||
if (argc < 2) {
|
||||
print_usage(argv[0]);
|
||||
return false;
|
||||
}
|
||||
|
||||
opts.template_path = argv[1];
|
||||
|
||||
for (int i = 2; i < argc; ++i) {
|
||||
std::string arg = argv[i];
|
||||
|
||||
if (arg == "--force-tool-call") {
|
||||
opts.force_tool_call = true;
|
||||
} else if (arg == "--debug-jinja") {
|
||||
opts.debug_jinja = true;
|
||||
} else if (arg == "--no-tools") {
|
||||
opts.with_tools = false;
|
||||
} else if (arg.rfind("--generation-prompt=", 0) == 0) {
|
||||
opts.generation_prompt = parse_bool_option(arg.substr(20));
|
||||
} else if (arg.rfind("--enable-reasoning=", 0) == 0) {
|
||||
opts.enable_reasoning = parse_bool_option(arg.substr(19));
|
||||
} else if (arg.rfind("--output=", 0) == 0) {
|
||||
std::string mode = arg.substr(9);
|
||||
if (mode == "analysis") {
|
||||
opts.mode = output_mode::ANALYSIS;
|
||||
} else if (mode == "template") {
|
||||
opts.mode = output_mode::TEMPLATE;
|
||||
} else if (mode == "both") {
|
||||
opts.mode = output_mode::BOTH;
|
||||
} else {
|
||||
LOG_ERR("Unknown output mode: %s\n", mode.c_str());
|
||||
return false;
|
||||
}
|
||||
} else if (arg.rfind("--input-message=", 0) == 0) {
|
||||
std::string type = arg.substr(16);
|
||||
if (type == "content_only") {
|
||||
opts.input_message = input_message_type::CONTENT_ONLY;
|
||||
} else if (type == "reasoning_content") {
|
||||
opts.input_message = input_message_type::REASONING_CONTENT;
|
||||
} else if (type == "tool_call_only") {
|
||||
opts.input_message = input_message_type::TOOL_CALL_ONLY;
|
||||
} else if (type == "content_tool_call") {
|
||||
opts.input_message = input_message_type::CONTENT_TOOL_CALL;
|
||||
} else if (type == "reasoning_tool_call") {
|
||||
opts.input_message = input_message_type::REASONING_TOOL_CALL;
|
||||
} else if (type == "content_fake_tool_call") {
|
||||
opts.input_message = input_message_type::CONTENT_FAKE_TOOL_CALL;
|
||||
} else if (type == "all") {
|
||||
opts.input_message = input_message_type::ALL;
|
||||
} else {
|
||||
LOG_ERR("Unknown input message type: %s\n", type.c_str());
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
LOG_ERR("Unknown option: %s\n", arg.c_str());
|
||||
print_usage(argv[0]);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static json build_user_message() {
|
||||
return json{
|
||||
{ "role", "user" },
|
||||
{ "content", "Hello, please help me with a task." }
|
||||
};
|
||||
}
|
||||
|
||||
static json build_content_only_message() {
|
||||
return json{
|
||||
{ "role", "assistant" },
|
||||
{ "content", "Hello! I'm here to help you with your task." }
|
||||
};
|
||||
}
|
||||
|
||||
static json build_reasoning_content_message() {
|
||||
return json{
|
||||
{ "role", "assistant" },
|
||||
{ "content", "Hello! I'm here to help you with your task." },
|
||||
{ "reasoning_content", "The user is greeting me and asking for help. I should respond politely." }
|
||||
};
|
||||
}
|
||||
|
||||
static json build_tool_call_only_message() {
|
||||
return json{
|
||||
{ "role", "assistant" },
|
||||
{ "content", nullptr },
|
||||
{ "tool_calls",
|
||||
json::array({ json{
|
||||
{ "type", "function" },
|
||||
{ "function", json{ { "name", "test_function_name" },
|
||||
{ "arguments", json::object({ { "param1", "value1" }, { "param2", "value2" } }) } } },
|
||||
{ "id", "123456789" } } }) }
|
||||
};
|
||||
}
|
||||
|
||||
static json build_content_tool_call_message() {
|
||||
return json{
|
||||
{ "role", "assistant" },
|
||||
{ "content", "I'll help you by calling a function." },
|
||||
{ "tool_calls",
|
||||
json::array({ json{
|
||||
{ "type", "function" },
|
||||
{ "function",
|
||||
json{ { "name", "test_function_name" },
|
||||
{ "arguments", json::object({ { "param1", "value1" }, { "param2", "value2" } }) } } } } }) }
|
||||
};
|
||||
}
|
||||
|
||||
static json build_reasoning_tool_call_message() {
|
||||
return json{
|
||||
{ "role", "assistant" },
|
||||
{ "content", nullptr },
|
||||
{ "reasoning_content", "I need to call a function to help with this task." },
|
||||
{ "tool_calls",
|
||||
json::array({ json{
|
||||
{ "type", "function" },
|
||||
{ "function",
|
||||
json{ { "name", "test_function_name" },
|
||||
{ "arguments", json::object({ { "param1", "value1" }, { "param2", "value2" } }) } } } } }) }
|
||||
};
|
||||
}
|
||||
|
||||
static json build_content_fake_tool_call_message() {
|
||||
// This message has content but NO tool_calls field
|
||||
// It's used to test if a template renders tool definitions but not tool calls
|
||||
return json{
|
||||
{ "role", "assistant" },
|
||||
{ "content", "I'll help you by calling a function." }
|
||||
};
|
||||
}
|
||||
|
||||
static json build_tools_definition() {
|
||||
json parameters_schema = json::object();
|
||||
parameters_schema["type"] = "object";
|
||||
parameters_schema["properties"] = json::object();
|
||||
parameters_schema["properties"]["param1"] = json::object({
|
||||
{ "type", "string" },
|
||||
{ "description", "First parameter" }
|
||||
});
|
||||
parameters_schema["properties"]["param2"] = json::object({
|
||||
{ "type", "string" },
|
||||
{ "description", "Second parameter" }
|
||||
});
|
||||
parameters_schema["required"] = json::array({ "param1" });
|
||||
|
||||
return json::array({
|
||||
json{ { "type", "function" },
|
||||
{ "function", json{ { "name", "test_function_name" },
|
||||
{ "description", "A test function for debugging" },
|
||||
{ "parameters", parameters_schema } } } }
|
||||
});
|
||||
}
|
||||
|
||||
static void render_scenario(const common_chat_template & tmpl,
|
||||
const std::string & scenario_name,
|
||||
const json & messages,
|
||||
const json & tools,
|
||||
bool add_generation_prompt,
|
||||
bool enable_thinking) {
|
||||
LOG_ERR("\n=== Scenario: %s ===\n", scenario_name.c_str());
|
||||
LOG_ERR("add_generation_prompt: %s, enable_thinking: %s\n", add_generation_prompt ? "true" : "false",
|
||||
enable_thinking ? "true" : "false");
|
||||
|
||||
// When add_generation_prompt is true, add a trailing user message to trigger the prompt
|
||||
json final_messages = messages;
|
||||
if (add_generation_prompt && !messages.empty() && messages.back().value("role", "") == "assistant") {
|
||||
final_messages.push_back(json{
|
||||
{ "role", "user" },
|
||||
{ "content", "Now please continue with another response." }
|
||||
});
|
||||
}
|
||||
|
||||
LOG_ERR("Messages:\n%s\n", final_messages.dump(2).c_str());
|
||||
|
||||
try {
|
||||
autoparser::templates_params inputs;
|
||||
inputs.messages = final_messages;
|
||||
inputs.add_generation_prompt = add_generation_prompt;
|
||||
inputs.extra_context["enable_thinking"] = enable_thinking;
|
||||
|
||||
if (!tools.is_null() && tools.is_array() && !tools.empty()) {
|
||||
inputs.tools = tools;
|
||||
}
|
||||
|
||||
std::string output = common_chat_template_direct_apply(tmpl, inputs);
|
||||
|
||||
LOG_ERR("\n--- Rendered Output ---\n");
|
||||
LOG_ERR("%s\n", output.c_str());
|
||||
LOG_ERR("--- End Output (length: %zu) ---\n", output.length());
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("Rendering failed: %s\n", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
static void render_all_scenarios(const common_chat_template & tmpl,
|
||||
const json & tools,
|
||||
bool add_generation_prompt,
|
||||
bool enable_thinking,
|
||||
input_message_type message_type) {
|
||||
json user_msg = build_user_message();
|
||||
|
||||
auto render_if = [&](input_message_type type, const std::string & name, const json & assistant_msg) {
|
||||
if (message_type == input_message_type::ALL || message_type == type) {
|
||||
json messages = json::array({ user_msg, assistant_msg });
|
||||
render_scenario(tmpl, name, messages, tools, add_generation_prompt, enable_thinking);
|
||||
}
|
||||
};
|
||||
|
||||
render_if(input_message_type::CONTENT_ONLY, "content_only", build_content_only_message());
|
||||
render_if(input_message_type::REASONING_CONTENT, "reasoning_content", build_reasoning_content_message());
|
||||
render_if(input_message_type::TOOL_CALL_ONLY, "tool_call_only", build_tool_call_only_message());
|
||||
render_if(input_message_type::CONTENT_TOOL_CALL, "content_tool_call", build_content_tool_call_message());
|
||||
render_if(input_message_type::REASONING_TOOL_CALL, "reasoning_tool_call", build_reasoning_tool_call_message());
|
||||
render_if(input_message_type::CONTENT_FAKE_TOOL_CALL, "content_fake_tool_call",
|
||||
build_content_fake_tool_call_message());
|
||||
|
||||
// Also render with add_generation_prompt=true to show the prompt ending
|
||||
if (message_type == input_message_type::ALL) {
|
||||
LOG_ERR("\n\n=== Generation Prompt Scenarios (add_generation_prompt=true) ===\n");
|
||||
|
||||
json prompt_messages = json::array({ user_msg });
|
||||
render_scenario(tmpl, "generation_prompt_only", prompt_messages, tools, true, enable_thinking);
|
||||
|
||||
// With enable_thinking toggled
|
||||
render_scenario(tmpl, "generation_prompt_thinking_disabled", prompt_messages, tools, true, false);
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
// Set log level to most verbose to capture all debug output
|
||||
common_log_set_verbosity_thold(99);
|
||||
|
||||
debug_options opts;
|
||||
if (!parse_options(argc, argv, opts)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (opts.debug_jinja || std::getenv("LLAMA_DEBUG_JINJA") != nullptr) {
|
||||
jinja::enable_debug(true);
|
||||
}
|
||||
|
||||
std::string template_source;
|
||||
try {
|
||||
// Check if the file is a GGUF file
|
||||
if (opts.template_path.size() >= 5 &&
|
||||
opts.template_path.compare(opts.template_path.size() - 5, 5, ".gguf") == 0) {
|
||||
template_source = read_gguf_chat_template(opts.template_path);
|
||||
} else {
|
||||
template_source = read_file(opts.template_path);
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("Error reading template: %s\n", e.what());
|
||||
return 1;
|
||||
}
|
||||
|
||||
LOG_ERR("Analyzing template: %s\n", opts.template_path.c_str());
|
||||
LOG_ERR("Options: with_tools=%s, generation_prompt=%s, enable_reasoning=%s\n", opts.with_tools ? "true" : "false",
|
||||
opts.generation_prompt ? "true" : "false", opts.enable_reasoning ? "true" : "false");
|
||||
|
||||
try {
|
||||
common_chat_template chat_template(template_source, "", "");
|
||||
|
||||
// Build tools definition
|
||||
json tools = opts.with_tools ? build_tools_definition() : json();
|
||||
|
||||
// Render template scenarios if requested
|
||||
if (opts.input_message != input_message_type::NONE &&
|
||||
(opts.mode == output_mode::TEMPLATE || opts.mode == output_mode::BOTH)) {
|
||||
LOG_ERR("\n");
|
||||
LOG_ERR("================================================================================\n");
|
||||
LOG_ERR(" TEMPLATE RENDERING OUTPUT\n");
|
||||
LOG_ERR("================================================================================\n");
|
||||
|
||||
render_all_scenarios(chat_template, tools, opts.generation_prompt, opts.enable_reasoning,
|
||||
opts.input_message);
|
||||
}
|
||||
|
||||
// Output analysis if requested
|
||||
if (opts.mode == output_mode::ANALYSIS || opts.mode == output_mode::BOTH) {
|
||||
LOG_ERR("\n");
|
||||
LOG_ERR("================================================================================\n");
|
||||
LOG_ERR(" TEMPLATE ANALYSIS\n");
|
||||
LOG_ERR("================================================================================\n");
|
||||
|
||||
autoparser::autoparser analysis;
|
||||
analysis.analyze_template(chat_template);
|
||||
|
||||
// Generate Parser
|
||||
autoparser::templates_params params;
|
||||
params.messages = json::array({ build_user_message() });
|
||||
params.reasoning_format =
|
||||
opts.enable_reasoning ? COMMON_REASONING_FORMAT_DEEPSEEK : COMMON_REASONING_FORMAT_NONE;
|
||||
params.enable_thinking = opts.enable_reasoning;
|
||||
params.add_generation_prompt = opts.generation_prompt;
|
||||
|
||||
if (opts.with_tools) {
|
||||
params.tools = tools;
|
||||
params.tool_choice = opts.force_tool_call ? COMMON_CHAT_TOOL_CHOICE_REQUIRED : COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
} else {
|
||||
params.tools = json();
|
||||
params.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
}
|
||||
params.parallel_tool_calls = false;
|
||||
|
||||
auto parser_data = autoparser::peg_generator::generate_parser(chat_template, params, analysis);
|
||||
|
||||
LOG_ERR("\n=== Generated Parser ===\n");
|
||||
common_peg_arena arena;
|
||||
arena.load(parser_data.parser);
|
||||
LOG_ERR("%s\n", arena.dump(arena.root()).c_str());
|
||||
|
||||
LOG_ERR("\n=== Generated Grammar ===\n");
|
||||
LOG_ERR("%s\n", parser_data.grammar.c_str());
|
||||
|
||||
LOG_ERR("\n=== Generated Lazy Grammar ===\n");
|
||||
LOG_ERR("%d\n", parser_data.grammar_lazy);
|
||||
|
||||
LOG_ERR("\n=== Generated Grammar Triggers ===\n");
|
||||
for (const common_grammar_trigger & cgt : parser_data.grammar_triggers) {
|
||||
LOG_ERR("Token: %d | Type: %d | Value: %s\n", cgt.token, cgt.type, cgt.value.c_str());
|
||||
}
|
||||
|
||||
LOG_ERR("\n=== Preserved Tokens ===\n");
|
||||
for (const std::string & token : parser_data.preserved_tokens) {
|
||||
LOG_ERR(" '%s'\n", token.c_str());
|
||||
}
|
||||
|
||||
if (!parser_data.grammar.empty()) {
|
||||
LOG_ERR("\n=== Verifying created grammar ===\n");
|
||||
auto * grammar = llama_grammar_init_impl(nullptr, parser_data.grammar.c_str(), "root",
|
||||
parser_data.grammar_lazy, nullptr, 0, nullptr, 0);
|
||||
if (grammar != nullptr) {
|
||||
LOG_ERR("\n=== Grammar successfully created ===\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("Analysis failed: %s\n", e.what());
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -0,0 +1,611 @@
|
|||
#include "chat-auto-parser.h"
|
||||
#include "chat-auto-parser-helpers.h"
|
||||
#include "chat.h"
|
||||
#include "log.h"
|
||||
#include "jinja/caps.h"
|
||||
#include "jinja/runtime.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
// ANSI color codes - using 256-color palette for brighter colors (all bold)
|
||||
#define ANSI_RESET "\033[0m"
|
||||
#define ANSI_PURPLE "\033[1m\x1b[38;5;126m" // Bold bright purple for main headers
|
||||
#define ANSI_CYAN "\033[1m\x1b[38;5;81m" // Bold bright cyan for section headers
|
||||
#define ANSI_BLUE "\033[1m\x1b[38;5;12m" // Bold bright blue for labels
|
||||
#define ANSI_ORANGE "\033[1m\x1b[38;5;209m" // Bold orange for right differences
|
||||
#define ANSI_GREEN "\033[1m\x1b[38;5;83m" // Bold bright green for left differences
|
||||
#define ANSI_GRAY "\033[1m\x1b[38;5;240m" // Bold gray (used for "no variables" message)
|
||||
#define ANSI_BOLD "\033[1m" // Standalone bold
|
||||
#define ANSI_PREFIX "\033[1m\x1b[38;5;176m" // Bold color for common prefix
|
||||
#define ANSI_SUFFIX "\033[1m\x1b[38;5;61m" // Bold color for common suffix
|
||||
|
||||
// All template paths extracted from tests/test-chat.cpp
|
||||
static const std::vector<std::string> ALL_TEMPLATE_PATHS = {
|
||||
"models/templates/Apertus-8B-Instruct.jinja",
|
||||
"models/templates/Apriel-1.6-15b-Thinker-fixed.jinja",
|
||||
"models/templates/ByteDance-Seed-OSS.jinja",
|
||||
"models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja",
|
||||
"models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja",
|
||||
"models/templates/GLM-4.6.jinja",
|
||||
"models/templates/GLM-4.7-Flash.jinja",
|
||||
"models/templates/Kimi-K2-Instruct.jinja",
|
||||
"models/templates/Kimi-K2-Thinking.jinja",
|
||||
"models/templates/MiMo-VL.jinja",
|
||||
"models/templates/MiniMax-M2.jinja",
|
||||
"models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja",
|
||||
"models/templates/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.jinja",
|
||||
"models/templates/NVIDIA-Nemotron-Nano-v2.jinja",
|
||||
"models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja",
|
||||
"models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja",
|
||||
"models/templates/Qwen-QwQ-32B.jinja",
|
||||
"models/templates/Qwen-Qwen2.5-7B-Instruct.jinja",
|
||||
"models/templates/Qwen3-Coder.jinja",
|
||||
"models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja",
|
||||
"models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja",
|
||||
"models/templates/deepseek-ai-DeepSeek-V3.1.jinja",
|
||||
"models/templates/fireworks-ai-llama-3-firefunction-v2.jinja",
|
||||
"models/templates/google-gemma-2-2b-it.jinja",
|
||||
"models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja",
|
||||
"models/templates/llama-cpp-deepseek-r1.jinja",
|
||||
"models/templates/meetkai-functionary-medium-v3.1.jinja",
|
||||
"models/templates/meetkai-functionary-medium-v3.2.jinja",
|
||||
"models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja",
|
||||
"models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja",
|
||||
"models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja",
|
||||
"models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja",
|
||||
"models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja",
|
||||
"models/templates/moonshotai-Kimi-K2.jinja",
|
||||
"models/templates/openai-gpt-oss-120b.jinja",
|
||||
"models/templates/unsloth-Apriel-1.5.jinja",
|
||||
"models/templates/unsloth-mistral-Devstral-Small-2507.jinja",
|
||||
};
|
||||
|
||||
struct analysis_options {
|
||||
std::vector<std::string> template_paths;
|
||||
bool analyze_all = false;
|
||||
};
|
||||
|
||||
static std::string read_file(const std::string & path) {
|
||||
std::ifstream fin(path, std::ios::binary);
|
||||
if (!fin.is_open()) {
|
||||
throw std::runtime_error("Could not open file: " + path);
|
||||
}
|
||||
std::ostringstream buf;
|
||||
buf << fin.rdbuf();
|
||||
return buf.str();
|
||||
}
|
||||
|
||||
static void print_usage(const char * program_name) {
|
||||
LOG_ERR("Usage: %s [options]\n", program_name);
|
||||
LOG_ERR("\nOptions:\n");
|
||||
LOG_ERR(" --template <name> Analyze specific template from test suite (e.g., 'deepseek' or 'DeepSeek-V3.1')\n");
|
||||
LOG_ERR(" --template-file <path> Analyze custom template file\n");
|
||||
LOG_ERR(" --all Analyze all templates from test suite\n");
|
||||
LOG_ERR("\nExamples:\n");
|
||||
LOG_ERR(" %s --all\n", program_name);
|
||||
LOG_ERR(" %s --template deepseek\n", program_name);
|
||||
LOG_ERR(" %s --template-file my-template.jinja\n", program_name);
|
||||
}
|
||||
|
||||
static bool parse_options(int argc, char ** argv, analysis_options & opts) {
|
||||
if (argc < 2) {
|
||||
print_usage(argv[0]);
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 1; i < argc; ++i) {
|
||||
std::string arg = argv[i];
|
||||
|
||||
if (arg == "--all") {
|
||||
opts.analyze_all = true;
|
||||
} else if (arg == "--template") {
|
||||
if (i + 1 >= argc) {
|
||||
LOG_ERR("--template requires an argument\n");
|
||||
return false;
|
||||
}
|
||||
std::string pattern = argv[++i];
|
||||
std::transform(pattern.begin(), pattern.end(), pattern.begin(), ::tolower);
|
||||
|
||||
// Find matching templates
|
||||
bool found = false;
|
||||
for (const auto & path : ALL_TEMPLATE_PATHS) {
|
||||
std::string path_lower = path;
|
||||
std::transform(path_lower.begin(), path_lower.end(), path_lower.begin(), ::tolower);
|
||||
if (path_lower.find(pattern) != std::string::npos) {
|
||||
opts.template_paths.push_back(path);
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!found) {
|
||||
LOG_ERR("No templates found matching: %s\n", pattern.c_str());
|
||||
return false;
|
||||
}
|
||||
} else if (arg == "--template-file") {
|
||||
if (i + 1 >= argc) {
|
||||
LOG_ERR("--template-file requires an argument\n");
|
||||
return false;
|
||||
}
|
||||
opts.template_paths.push_back(argv[++i]);
|
||||
} else {
|
||||
LOG_ERR("Unknown option: %s\n", arg.c_str());
|
||||
print_usage(argv[0]);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (opts.analyze_all) {
|
||||
opts.template_paths = ALL_TEMPLATE_PATHS;
|
||||
}
|
||||
|
||||
if (opts.template_paths.empty()) {
|
||||
LOG_ERR("No templates specified\n");
|
||||
print_usage(argv[0]);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static json build_tools_definition() {
|
||||
json parameters_schema = json::object();
|
||||
parameters_schema["type"] = "object";
|
||||
parameters_schema["properties"] = json::object();
|
||||
parameters_schema["properties"]["param1"] = json::object({
|
||||
{ "type", "string" },
|
||||
{ "description", "First parameter" }
|
||||
});
|
||||
parameters_schema["properties"]["param2"] = json::object({
|
||||
{ "type", "string" },
|
||||
{ "description", "Second parameter" }
|
||||
});
|
||||
parameters_schema["required"] = json::array({ "param1", "param2" });
|
||||
|
||||
return json::array({
|
||||
json{ { "type", "function" },
|
||||
{ "function", json{ { "name", "test_function_name" },
|
||||
{ "description", "A test function for debugging" },
|
||||
{ "parameters", parameters_schema } } } }
|
||||
});
|
||||
}
|
||||
|
||||
// Helper to create a tool call with arguments as JSON object
|
||||
static json build_tool_call(const std::string & name, const json & args_object, const std::string & id = "call_001") {
|
||||
return json{
|
||||
{"id", id},
|
||||
{"type", "function"},
|
||||
{"function", json{
|
||||
{"name", name},
|
||||
{"arguments", args_object} // Pass as JSON object, not serialized string
|
||||
}}
|
||||
};
|
||||
}
|
||||
|
||||
// Helper functions to create repeating message definitions
|
||||
static json make_user_msg() {
|
||||
return json{
|
||||
{"role", "user"},
|
||||
{"content", "Hello, please help me."}
|
||||
};
|
||||
}
|
||||
|
||||
static json make_user_msg2() {
|
||||
return json{
|
||||
{"role", "user"},
|
||||
{"content", "Thank you."}
|
||||
};
|
||||
}
|
||||
|
||||
static json make_user_msg2_continue() {
|
||||
return json{
|
||||
{"role", "user"},
|
||||
{"content", "Continue."}
|
||||
};
|
||||
}
|
||||
|
||||
static json make_assistant_no_tool() {
|
||||
return json{
|
||||
{"role", "assistant"},
|
||||
{"content", "Let me help you."}
|
||||
};
|
||||
}
|
||||
|
||||
static json make_assistant_one_tool() {
|
||||
return json{
|
||||
{"role", "assistant"},
|
||||
{"content", nullptr},
|
||||
{"tool_calls", json::array({
|
||||
build_tool_call("test_function_name", json::object({{"param1", "value1"}, {"param2", "value2"}}))
|
||||
})}
|
||||
};
|
||||
}
|
||||
|
||||
static json make_assistant_two_tools() {
|
||||
return json{
|
||||
{"role", "assistant"},
|
||||
{"content", nullptr},
|
||||
{"tool_calls", json::array({
|
||||
build_tool_call("test_function_name", json::object({{"param1", "value1"}, {"param2", "value2"}})),
|
||||
build_tool_call("test_function_name", json::object({{"param1", "value3"}, {"param2", "value4"}}), "call_002")
|
||||
})}
|
||||
};
|
||||
}
|
||||
|
||||
static json make_assistant_no_reasoning() {
|
||||
return json{
|
||||
{"role", "assistant"},
|
||||
{"content", "I can help you with that."}
|
||||
};
|
||||
}
|
||||
|
||||
static json make_assistant_with_reasoning() {
|
||||
return json{
|
||||
{"role", "assistant"},
|
||||
{"content", "I can help you with that."},
|
||||
{"reasoning_content", "The user is asking for help. I should respond positively."}
|
||||
};
|
||||
}
|
||||
|
||||
static json make_assistant_one_tool_with_reasoning() {
|
||||
return json{
|
||||
{"role", "assistant"},
|
||||
{"content", nullptr},
|
||||
{"tool_calls", json::array({
|
||||
build_tool_call("test_function_name", json::object({{"param1", "value1"}, {"param2", "value2"}}))
|
||||
})},
|
||||
{"reasoning_content", "I need to call the tool first."}
|
||||
};
|
||||
}
|
||||
|
||||
static void print_diff_split(const std::string & title, const diff_split & diff) {
|
||||
LOG_ERR("\n%s=== %s ===%s\n", ANSI_CYAN, title.c_str(), ANSI_RESET);
|
||||
LOG_ERR("%sCommon Prefix:%s '%s'\n", ANSI_PREFIX, ANSI_RESET, diff.prefix.c_str());
|
||||
LOG_ERR("%sCommon Suffix:%s '%s'\n", ANSI_SUFFIX, ANSI_RESET, diff.suffix.c_str());
|
||||
LOG_ERR("%sLeft (difference):%s '%s'\n", ANSI_GREEN, ANSI_RESET, diff.left.c_str());
|
||||
LOG_ERR("%sRight (difference):%s '%s'\n", ANSI_ORANGE, ANSI_RESET, diff.right.c_str());
|
||||
}
|
||||
|
||||
static void check_reasoning_variables(const common_chat_template & tmpl) {
|
||||
LOG_ERR("\n%s=== Checking Reasoning Variables ===%s\n", ANSI_CYAN, ANSI_RESET);
|
||||
|
||||
try {
|
||||
// Create a list of candidate reasoning/thinking variable names to probe
|
||||
std::vector<std::string> candidate_vars = {
|
||||
"enable_reasoning",
|
||||
"use_reasoning",
|
||||
"reasoning_enabled",
|
||||
"has_reasoning",
|
||||
"reasoning_mode",
|
||||
"reasoning_format",
|
||||
"reasoning_active",
|
||||
"with_reasoning",
|
||||
"use_thinking",
|
||||
"thinking_enabled",
|
||||
"has_thinking",
|
||||
"thinking_mode",
|
||||
"thinking_format",
|
||||
"thinking_active",
|
||||
"with_thinking",
|
||||
"enable_reason",
|
||||
"reason_enabled",
|
||||
"enable_think",
|
||||
"think_enabled",
|
||||
};
|
||||
|
||||
jinja::context ctx;
|
||||
ctx.is_get_stats = true;
|
||||
|
||||
json messages = json::array({
|
||||
json{
|
||||
{"role", "user"},
|
||||
{"content", "Test message"}
|
||||
},
|
||||
json{
|
||||
{"role", "assistant"},
|
||||
{"content", "Response"},
|
||||
{"reasoning_content", "Some reasoning"}
|
||||
}
|
||||
});
|
||||
|
||||
// Set up base context
|
||||
jinja::global_from_json(ctx, json{
|
||||
{"messages", messages},
|
||||
{"tools", json::array()},
|
||||
{"bos_token", ""},
|
||||
{"eos_token", ""},
|
||||
{"add_generation_prompt", false},
|
||||
{"enable_thinking", true} // Already passed, so we'll exclude this from results
|
||||
}, true);
|
||||
|
||||
// Add candidate variables as undefined to probe which ones are accessed
|
||||
for (const auto & var_name : candidate_vars) {
|
||||
ctx.set_val(var_name, jinja::mk_val<jinja::value_undefined_t>(var_name));
|
||||
}
|
||||
|
||||
try {
|
||||
jinja::runtime runtime(ctx);
|
||||
runtime.execute(tmpl.prog);
|
||||
} catch (const std::exception & e) {
|
||||
// Execution may fail, that's okay - we just want to see what variables were accessed
|
||||
}
|
||||
|
||||
// Check which candidate variables were accessed (stats.used = true)
|
||||
std::vector<std::string> accessed_vars;
|
||||
for (const auto & var_name : candidate_vars) {
|
||||
auto val = ctx.get_val(var_name);
|
||||
if (!val->is_undefined()) {
|
||||
// Variable was overwritten, skip it
|
||||
continue;
|
||||
}
|
||||
if (val->stats.used) {
|
||||
accessed_vars.push_back(var_name);
|
||||
}
|
||||
}
|
||||
|
||||
if (accessed_vars.empty()) {
|
||||
LOG_ERR("%sNo reasoning/thinking-related variables were queried by the template%s\n", ANSI_GRAY, ANSI_RESET);
|
||||
} else {
|
||||
LOG_ERR("Template queries the following reasoning/thinking-related variables:\n");
|
||||
for (const auto & var : accessed_vars) {
|
||||
LOG_ERR(" %s- %s%s\n", ANSI_ORANGE, var.c_str(), ANSI_RESET);
|
||||
}
|
||||
}
|
||||
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("Error checking reasoning variables: %s\n", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
static void analyze_template(const std::string & template_path) {
|
||||
LOG_ERR("\n");
|
||||
LOG_ERR("%s", ANSI_PURPLE);
|
||||
LOG_ERR("================================================================================\n");
|
||||
LOG_ERR(" ANALYZING TEMPLATE: %s\n", template_path.c_str());
|
||||
LOG_ERR("================================================================================\n");
|
||||
LOG_ERR("%s", ANSI_RESET);
|
||||
|
||||
std::string template_source;
|
||||
try {
|
||||
template_source = read_file(template_path);
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("Error reading template: %s\n", e.what());
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
common_chat_template chat_template(template_source, "", "");
|
||||
json tools = build_tools_definition();
|
||||
|
||||
// ===== CAPABILITIES ANALYSIS =====
|
||||
LOG_ERR("\n%s=== Template Capabilities (from jinja::caps) ===%s\n", ANSI_CYAN, ANSI_RESET);
|
||||
auto caps = chat_template.original_caps();
|
||||
LOG_ERR("%ssupports_tools:%s %s\n", ANSI_BLUE, ANSI_RESET, caps.supports_tools ? "true" : "false");
|
||||
LOG_ERR("%ssupports_tool_calls:%s %s\n", ANSI_BLUE, ANSI_RESET, caps.supports_tool_calls ? "true" : "false");
|
||||
LOG_ERR("%ssupports_system_role:%s %s\n", ANSI_BLUE, ANSI_RESET, caps.supports_system_role ? "true" : "false");
|
||||
LOG_ERR("%ssupports_parallel_tool_calls:%s %s\n", ANSI_BLUE, ANSI_RESET, caps.supports_parallel_tool_calls ? "true" : "false");
|
||||
LOG_ERR("%ssupports_typed_content:%s %s\n", ANSI_BLUE, ANSI_RESET, caps.supports_typed_content ? "true" : "false");
|
||||
LOG_ERR("%ssupports_string_content:%s %s\n", ANSI_BLUE, ANSI_RESET, caps.supports_string_content ? "true" : "false");
|
||||
|
||||
// ===== DIFFERENTIAL ANALYSIS =====
|
||||
|
||||
// Test 1: With and without tools (single user message)
|
||||
{
|
||||
json user_msg = make_user_msg();
|
||||
|
||||
autoparser::templates_params params_no_tools;
|
||||
params_no_tools.messages = json::array({ user_msg });
|
||||
params_no_tools.add_generation_prompt = false;
|
||||
params_no_tools.tools = json::array();
|
||||
|
||||
autoparser::templates_params params_with_tools = params_no_tools;
|
||||
params_with_tools.tools = tools;
|
||||
|
||||
std::string output_no_tools = common_chat_template_direct_apply(chat_template, params_no_tools);
|
||||
std::string output_with_tools = common_chat_template_direct_apply(chat_template, params_with_tools);
|
||||
|
||||
auto diff = calculate_diff_split(output_no_tools, output_with_tools);
|
||||
print_diff_split("Diff: With vs Without Tools (single user message)", diff);
|
||||
}
|
||||
|
||||
// Test 2: With and without add_generation_prompt (single user message)
|
||||
{
|
||||
json user_msg = make_user_msg();
|
||||
|
||||
autoparser::templates_params params_no_prompt;
|
||||
params_no_prompt.messages = json::array({ user_msg });
|
||||
params_no_prompt.add_generation_prompt = false;
|
||||
params_no_prompt.tools = json::array();
|
||||
|
||||
autoparser::templates_params params_with_prompt = params_no_prompt;
|
||||
params_with_prompt.add_generation_prompt = true;
|
||||
|
||||
std::string output_no_prompt = common_chat_template_direct_apply(chat_template, params_no_prompt);
|
||||
std::string output_with_prompt = common_chat_template_direct_apply(chat_template, params_with_prompt);
|
||||
|
||||
auto diff = calculate_diff_split(output_no_prompt, output_with_prompt);
|
||||
print_diff_split("Diff: With vs Without add_generation_prompt (single user message)", diff);
|
||||
}
|
||||
|
||||
// Test 3: Assistant with reasoning_content (user, assistant)
|
||||
{
|
||||
json user_msg = make_user_msg();
|
||||
|
||||
autoparser::templates_params params_no_reasoning;
|
||||
params_no_reasoning.messages = json::array({ user_msg, make_assistant_no_reasoning() });
|
||||
params_no_reasoning.add_generation_prompt = false;
|
||||
params_no_reasoning.enable_thinking = true;
|
||||
|
||||
autoparser::templates_params params_with_reasoning = params_no_reasoning;
|
||||
params_with_reasoning.messages = json::array({ user_msg, make_assistant_with_reasoning() });
|
||||
|
||||
std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning);
|
||||
std::string output_with_reasoning = common_chat_template_direct_apply(chat_template, params_with_reasoning);
|
||||
|
||||
auto diff = calculate_diff_split(output_no_reasoning, output_with_reasoning);
|
||||
print_diff_split("Diff: With vs Without reasoning_content (user, assistant)", diff);
|
||||
}
|
||||
|
||||
// Test 4: Assistant with reasoning_content (user, assistant, user)
|
||||
{
|
||||
json user_msg = make_user_msg();
|
||||
json user_msg2 = make_user_msg2();
|
||||
|
||||
autoparser::templates_params params_no_reasoning;
|
||||
params_no_reasoning.messages = json::array({ user_msg, make_assistant_no_reasoning(), user_msg2 });
|
||||
params_no_reasoning.add_generation_prompt = false;
|
||||
params_no_reasoning.enable_thinking = true;
|
||||
|
||||
autoparser::templates_params params_with_reasoning = params_no_reasoning;
|
||||
params_with_reasoning.messages = json::array({ user_msg, make_assistant_with_reasoning(), user_msg2 });
|
||||
|
||||
std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning);
|
||||
std::string output_with_reasoning = common_chat_template_direct_apply(chat_template, params_with_reasoning);
|
||||
|
||||
auto diff = calculate_diff_split(output_no_reasoning, output_with_reasoning);
|
||||
print_diff_split("Diff: With vs Without reasoning_content (user, assistant, user)", diff);
|
||||
}
|
||||
|
||||
// Test 5: Tool call in last assistant message (user, assistant)
|
||||
{
|
||||
json user_msg = make_user_msg();
|
||||
|
||||
autoparser::templates_params params_no_tool;
|
||||
params_no_tool.messages = json::array({ user_msg, make_assistant_no_tool() });
|
||||
params_no_tool.add_generation_prompt = false;
|
||||
params_no_tool.tools = tools;
|
||||
|
||||
autoparser::templates_params params_with_tool = params_no_tool;
|
||||
params_with_tool.messages = json::array({ user_msg, make_assistant_one_tool() });
|
||||
|
||||
std::string output_no_tool = common_chat_template_direct_apply(chat_template, params_no_tool);
|
||||
std::string output_with_tool = common_chat_template_direct_apply(chat_template, params_with_tool);
|
||||
|
||||
auto diff = calculate_diff_split(output_no_tool, output_with_tool);
|
||||
print_diff_split("Diff: With vs Without tool call (user, assistant)", diff);
|
||||
}
|
||||
|
||||
// Test 6: Tool call in last assistant message (user, assistant, user)
|
||||
{
|
||||
json user_msg = make_user_msg();
|
||||
json user_msg2 = make_user_msg2_continue();
|
||||
|
||||
autoparser::templates_params params_no_tool;
|
||||
params_no_tool.messages = json::array({ user_msg, make_assistant_no_tool(), user_msg2 });
|
||||
params_no_tool.add_generation_prompt = false;
|
||||
params_no_tool.tools = tools;
|
||||
|
||||
autoparser::templates_params params_with_tool = params_no_tool;
|
||||
params_with_tool.messages = json::array({ user_msg, make_assistant_one_tool(), user_msg2 });
|
||||
|
||||
std::string output_no_tool = common_chat_template_direct_apply(chat_template, params_no_tool);
|
||||
std::string output_with_tool = common_chat_template_direct_apply(chat_template, params_with_tool);
|
||||
|
||||
auto diff = calculate_diff_split(output_no_tool, output_with_tool);
|
||||
print_diff_split("Diff: With vs Without tool call (user, assistant, user)", diff);
|
||||
}
|
||||
|
||||
// Test 7: One vs two tool calls (user, assistant)
|
||||
{
|
||||
json user_msg = make_user_msg();
|
||||
|
||||
autoparser::templates_params params_one_tool;
|
||||
params_one_tool.messages = json::array({ user_msg, make_assistant_one_tool() });
|
||||
params_one_tool.add_generation_prompt = false;
|
||||
params_one_tool.tools = tools;
|
||||
|
||||
autoparser::templates_params params_two_tools = params_one_tool;
|
||||
params_two_tools.messages = json::array({ user_msg, make_assistant_two_tools() });
|
||||
|
||||
std::string output_one_tool = common_chat_template_direct_apply(chat_template, params_one_tool);
|
||||
std::string output_two_tools = common_chat_template_direct_apply(chat_template, params_two_tools);
|
||||
|
||||
auto diff = calculate_diff_split(output_one_tool, output_two_tools);
|
||||
print_diff_split("Diff: One vs Two tool calls (user, assistant)", diff);
|
||||
}
|
||||
|
||||
// Test 8: One vs two tool calls (user, assistant, user)
|
||||
{
|
||||
json user_msg = make_user_msg();
|
||||
json user_msg2 = make_user_msg2_continue();
|
||||
|
||||
autoparser::templates_params params_one_tool;
|
||||
params_one_tool.messages = json::array({ user_msg, make_assistant_one_tool(), user_msg2 });
|
||||
params_one_tool.add_generation_prompt = false;
|
||||
params_one_tool.tools = tools;
|
||||
|
||||
autoparser::templates_params params_two_tools = params_one_tool;
|
||||
params_two_tools.messages = json::array({ user_msg, make_assistant_two_tools(), user_msg2 });
|
||||
|
||||
std::string output_one_tool = common_chat_template_direct_apply(chat_template, params_one_tool);
|
||||
std::string output_two_tools = common_chat_template_direct_apply(chat_template, params_two_tools);
|
||||
|
||||
auto diff = calculate_diff_split(output_one_tool, output_two_tools);
|
||||
print_diff_split("Diff: One vs Two tool calls (user, assistant, user)", diff);
|
||||
}
|
||||
|
||||
// Test 9: Tool call with vs without reasoning_content (user, assistant)
|
||||
{
|
||||
json user_msg = make_user_msg();
|
||||
|
||||
autoparser::templates_params params_no_reasoning;
|
||||
params_no_reasoning.messages = json::array({ user_msg, make_assistant_one_tool() });
|
||||
params_no_reasoning.add_generation_prompt = false;
|
||||
params_no_reasoning.tools = tools;
|
||||
params_no_reasoning.enable_thinking = true;
|
||||
|
||||
autoparser::templates_params params_with_reasoning = params_no_reasoning;
|
||||
params_with_reasoning.messages = json::array({ user_msg, make_assistant_one_tool_with_reasoning() });
|
||||
|
||||
std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning);
|
||||
std::string output_with_reasoning = common_chat_template_direct_apply(chat_template, params_with_reasoning);
|
||||
|
||||
auto diff = calculate_diff_split(output_no_reasoning, output_with_reasoning);
|
||||
print_diff_split("Diff: Tool call with vs without reasoning_content (user, assistant)", diff);
|
||||
}
|
||||
|
||||
// Check reasoning variables
|
||||
check_reasoning_variables(chat_template);
|
||||
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("Analysis failed: %s\n", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
// Set log level to capture all output
|
||||
common_log_set_verbosity_thold(99);
|
||||
|
||||
analysis_options opts;
|
||||
if (!parse_options(argc, argv, opts)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
LOG_ERR("\n");
|
||||
LOG_ERR("%s", ANSI_PURPLE);
|
||||
LOG_ERR("================================================================================\n");
|
||||
LOG_ERR(" TEMPLATE ANALYSIS TOOL\n");
|
||||
LOG_ERR("================================================================================\n");
|
||||
LOG_ERR("%s", ANSI_RESET);
|
||||
LOG_ERR("Analyzing %s%zu%s template(s)\n", ANSI_CYAN, opts.template_paths.size(), ANSI_RESET);
|
||||
|
||||
for (const auto & path : opts.template_paths) {
|
||||
analyze_template(path);
|
||||
}
|
||||
|
||||
LOG_ERR("\n");
|
||||
LOG_ERR("%s", ANSI_GREEN);
|
||||
LOG_ERR("================================================================================\n");
|
||||
LOG_ERR(" ANALYSIS COMPLETE\n");
|
||||
LOG_ERR("================================================================================\n");
|
||||
LOG_ERR("%s", ANSI_RESET);
|
||||
|
||||
return 0;
|
||||
}
|
||||
Binary file not shown.
|
|
@ -729,6 +729,10 @@ export class SchemaConverter {
|
|||
return this._addRule(ruleName, out.join(''));
|
||||
} else if ((schemaType === 'object') || (Object.keys(schema).length === 0)) {
|
||||
return this._addRule(ruleName, this._addPrimitive('object', PRIMITIVE_RULES['object']));
|
||||
} else if (schemaType === undefined && typeof schema === 'object' && !Array.isArray(schema) && schema !== null) {
|
||||
// No type constraint and no recognized structural keywords (e.g. {"description": "..."}).
|
||||
// Per JSON Schema semantics this is equivalent to {} and accepts any value.
|
||||
return this._addRule(ruleName, this._addPrimitive('value', PRIMITIVE_RULES['value']));
|
||||
} else {
|
||||
if (!(schemaType in PRIMITIVE_RULES)) {
|
||||
throw new Error(`Unrecognized schema: ${JSON.stringify(schema)}`);
|
||||
|
|
|
|||
|
|
@ -1463,6 +1463,7 @@ json convert_anthropic_to_oai(const json & body) {
|
|||
json tool_calls = json::array();
|
||||
json converted_content = json::array();
|
||||
json tool_results = json::array();
|
||||
std::string reasoning_content;
|
||||
bool has_tool_calls = false;
|
||||
|
||||
for (const auto & block : content) {
|
||||
|
|
@ -1470,6 +1471,8 @@ json convert_anthropic_to_oai(const json & body) {
|
|||
|
||||
if (type == "text") {
|
||||
converted_content.push_back(block);
|
||||
} else if (type == "thinking") {
|
||||
reasoning_content += json_value(block, "thinking", std::string());
|
||||
} else if (type == "image") {
|
||||
json source = json_value(block, "source", json::object());
|
||||
std::string source_type = json_value(source, "type", std::string());
|
||||
|
|
@ -1528,16 +1531,19 @@ json convert_anthropic_to_oai(const json & body) {
|
|||
}
|
||||
}
|
||||
|
||||
if (!converted_content.empty() || has_tool_calls) {
|
||||
if (!converted_content.empty() || has_tool_calls || !reasoning_content.empty()) {
|
||||
json new_msg = {{"role", role}};
|
||||
if (!converted_content.empty()) {
|
||||
new_msg["content"] = converted_content;
|
||||
} else if (has_tool_calls) {
|
||||
} else if (has_tool_calls || !reasoning_content.empty()) {
|
||||
new_msg["content"] = "";
|
||||
}
|
||||
if (!tool_calls.empty()) {
|
||||
new_msg["tool_calls"] = tool_calls;
|
||||
}
|
||||
if (!reasoning_content.empty()) {
|
||||
new_msg["reasoning_content"] = reasoning_content;
|
||||
}
|
||||
oai_messages.push_back(new_msg);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
#include "mtmd.h"
|
||||
#include "mtmd-helper.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <cinttypes>
|
||||
#include <memory>
|
||||
|
|
@ -2348,8 +2349,10 @@ private:
|
|||
const auto it = std::find_if(
|
||||
slot.prompt.checkpoints.rbegin(),
|
||||
slot.prompt.checkpoints.rend(),
|
||||
[&](const auto & cur) {
|
||||
[&, func_name = __func__](const auto & cur) {
|
||||
// guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
|
||||
LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d] against %d...\n", 12,
|
||||
func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, pos_min_thold);
|
||||
return cur.pos_min < pos_min_thold;
|
||||
}
|
||||
);
|
||||
|
|
@ -2533,48 +2536,66 @@ private:
|
|||
slot.i_batch = batch.n_tokens - 1;
|
||||
|
||||
slot.init_sampler();
|
||||
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
|
||||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
|
||||
|
||||
// no need for empty or small checkpoints
|
||||
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
|
||||
|
||||
// no need to create checkpoints that are too close together
|
||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
|
||||
|
||||
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
|
||||
// yet processed and therefore it is not part of the checkpoint.
|
||||
if (do_checkpoint) {
|
||||
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
|
||||
// make room for the new checkpoint, if needed
|
||||
const auto & cur = slot.prompt.checkpoints.front();
|
||||
|
||||
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
|
||||
|
||||
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
|
||||
}
|
||||
|
||||
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
|
||||
/*.pos_min = */ pos_min,
|
||||
/*.pos_max = */ pos_max,
|
||||
/*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens,
|
||||
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
||||
});
|
||||
|
||||
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
|
||||
}
|
||||
|
||||
SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
|
||||
} else {
|
||||
// only do non-end checkpoints if the "checkpoint every n tokens" option is set
|
||||
do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0;
|
||||
if (do_checkpoint) {
|
||||
llama_pos last_checkpoint = 0;
|
||||
if (!slot.prompt.checkpoints.empty()) {
|
||||
last_checkpoint = slot.prompt.checkpoints.back().n_tokens;
|
||||
}
|
||||
do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt;
|
||||
if (do_checkpoint) {
|
||||
SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, slot.prompt.n_tokens());
|
||||
}
|
||||
}
|
||||
SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
|
||||
}
|
||||
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
|
||||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
|
||||
|
||||
// no need for empty or small checkpoints
|
||||
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
|
||||
|
||||
// no need to create checkpoints that are too close together
|
||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
|
||||
|
||||
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
|
||||
// yet processed and therefore it is not part of the checkpoint.
|
||||
if (do_checkpoint) {
|
||||
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
|
||||
// make room for the new checkpoint, if needed
|
||||
const auto & cur = slot.prompt.checkpoints.front();
|
||||
|
||||
SLT_WRN(slot,
|
||||
"erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64
|
||||
", size = %.3f MiB)\n",
|
||||
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
|
||||
|
||||
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
|
||||
}
|
||||
|
||||
const size_t checkpoint_size =
|
||||
llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
|
||||
/*.pos_min = */ pos_min,
|
||||
/*.pos_max = */ pos_max,
|
||||
/*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens,
|
||||
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
||||
});
|
||||
|
||||
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id,
|
||||
LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
SLT_WRN(slot,
|
||||
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64
|
||||
", size = %.3f MiB)\n",
|
||||
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min,
|
||||
cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
|
||||
}
|
||||
}
|
||||
|
||||
if (!slot_batched) {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,56 @@
|
|||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "http.h"
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <list>
|
||||
#include <map>
|
||||
|
||||
#include "server-http.h"
|
||||
|
||||
static server_http_res_ptr proxy_request(const server_http_req & req, std::string method) {
|
||||
std::string target_url = req.get_param("url");
|
||||
common_http_url parsed_url = common_http_parse_url(target_url);
|
||||
|
||||
if (parsed_url.host.empty()) {
|
||||
throw std::runtime_error("invalid target URL: missing host");
|
||||
}
|
||||
|
||||
if (parsed_url.path.empty()) {
|
||||
parsed_url.path = "/";
|
||||
}
|
||||
|
||||
if (!parsed_url.password.empty()) {
|
||||
throw std::runtime_error("authentication in target URL is not supported");
|
||||
}
|
||||
|
||||
if (parsed_url.scheme != "http" && parsed_url.scheme != "https") {
|
||||
throw std::runtime_error("unsupported URL scheme in target URL: " + parsed_url.scheme);
|
||||
}
|
||||
|
||||
SRV_INF("proxying %s request to %s://%s%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.path.c_str());
|
||||
|
||||
auto proxy = std::make_unique<server_http_proxy>(
|
||||
method,
|
||||
parsed_url.host,
|
||||
parsed_url.scheme == "http" ? 80 : 443,
|
||||
parsed_url.path,
|
||||
req.headers,
|
||||
req.body,
|
||||
req.should_stop,
|
||||
600, // timeout_read (default to 10 minutes)
|
||||
600 // timeout_write (default to 10 minutes)
|
||||
);
|
||||
|
||||
return proxy;
|
||||
}
|
||||
|
||||
static server_http_context::handler_t proxy_handler_post = [](const server_http_req & req) -> server_http_res_ptr {
|
||||
return proxy_request(req, "POST");
|
||||
};
|
||||
|
||||
static server_http_context::handler_t proxy_handler_get = [](const server_http_req & req) -> server_http_res_ptr {
|
||||
return proxy_request(req, "GET");
|
||||
};
|
||||
|
|
@ -1089,11 +1089,20 @@ server_http_proxy::server_http_proxy(
|
|||
int32_t timeout_write
|
||||
) {
|
||||
// shared between reader and writer threads
|
||||
auto cli = std::make_shared<httplib::Client>(host, port);
|
||||
auto cli = std::make_shared<httplib::ClientImpl>(host, port);
|
||||
auto pipe = std::make_shared<pipe_t<msg_t>>();
|
||||
|
||||
if (port == 443) {
|
||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
cli.reset(new httplib::SSLClient(host, port));
|
||||
#else
|
||||
throw std::runtime_error("HTTPS requested but CPPHTTPLIB_OPENSSL_SUPPORT is not defined");
|
||||
#endif
|
||||
}
|
||||
|
||||
// setup Client
|
||||
cli->set_connection_timeout(0, 200000); // 200 milliseconds
|
||||
cli->set_follow_location(true);
|
||||
cli->set_connection_timeout(5, 0); // 5 seconds
|
||||
cli->set_write_timeout(timeout_read, 0); // reversed for cli (client) vs srv (server)
|
||||
cli->set_read_timeout(timeout_write, 0);
|
||||
this->status = 500; // to be overwritten upon response
|
||||
|
|
@ -1142,7 +1151,15 @@ server_http_proxy::server_http_proxy(
|
|||
req.method = method;
|
||||
req.path = path;
|
||||
for (const auto & [key, value] : headers) {
|
||||
req.set_header(key, value);
|
||||
if (key == "Accept-Encoding") {
|
||||
// disable Accept-Encoding to avoid compressed responses
|
||||
continue;
|
||||
}
|
||||
if (key == "Host" || key == "host") {
|
||||
req.set_header(key, host);
|
||||
} else {
|
||||
req.set_header(key, value);
|
||||
}
|
||||
}
|
||||
req.body = body;
|
||||
req.response_handler = response_handler;
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
#include "server-common.h"
|
||||
#include "server-task.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "llama.h"
|
||||
#include "sampling.h"
|
||||
#include "speculative.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "server-common.h"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
|
|
@ -157,7 +157,8 @@ json task_params::to_json(bool only_metrics) const {
|
|||
common_chat_msg task_result_state::update_chat_msg(
|
||||
const std::string & text_added,
|
||||
bool is_partial,
|
||||
std::vector<common_chat_msg_diff> & diffs) {
|
||||
std::vector<common_chat_msg_diff> & diffs,
|
||||
bool filter_tool_calls) {
|
||||
generated_text += text_added;
|
||||
auto msg_prv_copy = chat_msg;
|
||||
SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
|
||||
|
|
@ -168,7 +169,64 @@ common_chat_msg task_result_state::update_chat_msg(
|
|||
if (!new_msg.empty()) {
|
||||
new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
|
||||
chat_msg = new_msg;
|
||||
diffs = common_chat_msg_diff::compute_diffs(msg_prv_copy, new_msg.empty() ? msg_prv_copy : new_msg);
|
||||
auto all_diffs = common_chat_msg_diff::compute_diffs(msg_prv_copy, chat_msg);
|
||||
|
||||
if (!filter_tool_calls) {
|
||||
diffs = std::move(all_diffs);
|
||||
} else {
|
||||
for (auto & d : all_diffs) {
|
||||
// If this is a new type of delta, flush all currently pending tool call names
|
||||
for (size_t i = 0; i < chat_msg.tool_calls.size(); ++i) {
|
||||
if (sent_tool_call_names.count(i) || chat_msg.tool_calls[i].name.empty()) {
|
||||
continue;
|
||||
}
|
||||
if (d.tool_call_index != i || !d.tool_call_delta.arguments.empty()) {
|
||||
common_chat_msg_diff header;
|
||||
header.tool_call_index = i;
|
||||
header.tool_call_delta.id = chat_msg.tool_calls[i].id;
|
||||
header.tool_call_delta.name = chat_msg.tool_calls[i].name;
|
||||
diffs.push_back(std::move(header));
|
||||
sent_tool_call_names.insert(i);
|
||||
}
|
||||
}
|
||||
|
||||
if (d.tool_call_index == std::string::npos) {
|
||||
diffs.push_back(std::move(d));
|
||||
} else {
|
||||
size_t i = d.tool_call_index;
|
||||
if (sent_tool_call_names.count(i)) {
|
||||
if (!d.tool_call_delta.arguments.empty()) {
|
||||
d.tool_call_delta.name = "";
|
||||
d.tool_call_delta.id = "";
|
||||
diffs.push_back(std::move(d));
|
||||
}
|
||||
} else {
|
||||
// Not sent yet.
|
||||
if (!d.tool_call_delta.arguments.empty() || !is_partial) {
|
||||
d.tool_call_delta.name = chat_msg.tool_calls[i].name;
|
||||
d.tool_call_delta.id = chat_msg.tool_calls[i].id;
|
||||
diffs.push_back(std::move(d));
|
||||
sent_tool_call_names.insert(i);
|
||||
} else {
|
||||
// Suppress
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Final check at EOF
|
||||
if (!is_partial) {
|
||||
for (size_t i = 0; i < chat_msg.tool_calls.size(); ++i) {
|
||||
if (!sent_tool_call_names.count(i) && !chat_msg.tool_calls[i].name.empty()) {
|
||||
common_chat_msg_diff header;
|
||||
header.tool_call_index = i;
|
||||
header.tool_call_delta.id = chat_msg.tool_calls[i].id;
|
||||
header.tool_call_delta.name = chat_msg.tool_calls[i].name;
|
||||
diffs.push_back(std::move(header));
|
||||
sent_tool_call_names.insert(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return chat_msg;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -98,6 +98,7 @@ struct task_result_state {
|
|||
common_chat_msg chat_msg;
|
||||
std::string generated_text; // append new chunks of generated text here
|
||||
std::vector<std::string> generated_tool_call_ids;
|
||||
std::unordered_set<size_t> sent_tool_call_names;
|
||||
|
||||
// for OpenAI Responses and Anthropic streaming API:
|
||||
// track output item / content block state across chunks
|
||||
|
|
@ -120,7 +121,8 @@ struct task_result_state {
|
|||
common_chat_msg update_chat_msg(
|
||||
const std::string & text_added,
|
||||
bool is_partial,
|
||||
std::vector<common_chat_msg_diff> & diffs);
|
||||
std::vector<common_chat_msg_diff> & diffs,
|
||||
bool filter_tool_calls = false);
|
||||
};
|
||||
|
||||
struct server_task {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#include "server-context.h"
|
||||
#include "server-http.h"
|
||||
#include "server-models.h"
|
||||
#include "server-cors-proxy.h"
|
||||
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
|
|
@ -201,6 +202,15 @@ int main(int argc, char ** argv) {
|
|||
// Save & load slots
|
||||
ctx_http.get ("/slots", ex_wrapper(routes.get_slots));
|
||||
ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots));
|
||||
// CORS proxy (EXPERIMENTAL, only used by the Web UI for MCP)
|
||||
if (params.webui_mcp_proxy) {
|
||||
SRV_WRN("%s", "-----------------\n");
|
||||
SRV_WRN("%s", "CORS proxy is enabled, do not expose server to untrusted environments\n");
|
||||
SRV_WRN("%s", "This feature is EXPERIMENTAL and may be removed or changed in future versions\n");
|
||||
SRV_WRN("%s", "-----------------\n");
|
||||
ctx_http.get ("/cors-proxy", ex_wrapper(proxy_handler_get));
|
||||
ctx_http.post("/cors-proxy", ex_wrapper(proxy_handler_post));
|
||||
}
|
||||
|
||||
//
|
||||
// Start the server
|
||||
|
|
|
|||
|
|
@ -809,6 +809,139 @@ def test_anthropic_vs_openai_different_response_format():
|
|||
|
||||
# Extended thinking tests with reasoning models
|
||||
|
||||
# The next two tests cover the input path (conversation history):
|
||||
# Client sends thinking blocks -> convert_anthropic_to_oai -> reasoning_content -> template
|
||||
|
||||
def test_anthropic_thinking_history_in_count_tokens():
|
||||
"""Test that interleaved thinking blocks in conversation history are not dropped during conversion."""
|
||||
global server
|
||||
server.jinja = True
|
||||
server.chat_template_file = '../../../models/templates/Qwen-Qwen3-0.6B.jinja'
|
||||
server.start()
|
||||
|
||||
tool = {
|
||||
"name": "list_files",
|
||||
"description": "List files",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {"path": {"type": "string"}},
|
||||
"required": ["path"]
|
||||
}
|
||||
}
|
||||
|
||||
messages_without_thinking = [
|
||||
{"role": "user", "content": "Fix the bug"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "tool_use", "id": "call_1", "name": "list_files", "input": {"path": "."}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "tool_result", "tool_use_id": "call_1", "content": "main.py"}
|
||||
]
|
||||
},
|
||||
]
|
||||
|
||||
messages_with_thinking = [
|
||||
{"role": "user", "content": "Fix the bug"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "I should check the project structure first to understand the codebase layout."},
|
||||
{"type": "tool_use", "id": "call_1", "name": "list_files", "input": {"path": "."}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "tool_result", "tool_use_id": "call_1", "content": "main.py"}
|
||||
]
|
||||
},
|
||||
]
|
||||
|
||||
res_without = server.make_request("POST", "/v1/messages/count_tokens", data={
|
||||
"model": "test",
|
||||
"messages": messages_without_thinking,
|
||||
"tools": [tool],
|
||||
})
|
||||
assert res_without.status_code == 200, f"Expected 200: {res_without.body}"
|
||||
|
||||
res_with = server.make_request("POST", "/v1/messages/count_tokens", data={
|
||||
"model": "test",
|
||||
"messages": messages_with_thinking,
|
||||
"tools": [tool],
|
||||
})
|
||||
assert res_with.status_code == 200, f"Expected 200: {res_with.body}"
|
||||
|
||||
# Thinking blocks should increase the token count
|
||||
assert res_with.body["input_tokens"] > res_without.body["input_tokens"], \
|
||||
f"Expected more tokens with thinking ({res_with.body['input_tokens']}) than without ({res_without.body['input_tokens']})"
|
||||
|
||||
|
||||
def test_anthropic_thinking_history_in_template():
|
||||
"""Test that reasoning_content from converted interleaved thinking blocks renders in the prompt."""
|
||||
global server
|
||||
server.jinja = True
|
||||
server.chat_template_file = '../../../models/templates/Qwen-Qwen3-0.6B.jinja'
|
||||
server.start()
|
||||
|
||||
reasoning_1 = "I should check the project structure first."
|
||||
reasoning_2 = "Now I need to read the main file."
|
||||
|
||||
res = server.make_request("POST", "/apply-template", data={
|
||||
"messages": [
|
||||
{"role": "user", "content": "Fix the bug in main.py"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": reasoning_1,
|
||||
"tool_calls": [{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "list_files", "arguments": "{\"path\": \".\"}"}
|
||||
}]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "main.py\nutils.py"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": reasoning_2,
|
||||
"tool_calls": [{
|
||||
"id": "call_2",
|
||||
"type": "function",
|
||||
"function": {"name": "read_file", "arguments": "{\"path\": \"main.py\"}"}
|
||||
}]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_2", "content": "print('hello')"},
|
||||
],
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "list_files",
|
||||
"description": "List files",
|
||||
"parameters": {"type": "object", "properties": {"path": {"type": "string"}}, "required": ["path"]}
|
||||
}
|
||||
}, {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "read_file",
|
||||
"description": "Read a file",
|
||||
"parameters": {"type": "object", "properties": {"path": {"type": "string"}}, "required": ["path"]}
|
||||
}
|
||||
}],
|
||||
})
|
||||
assert res.status_code == 200, f"Expected 200, got {res.status_code}: {res.body}"
|
||||
prompt = res.body["prompt"]
|
||||
|
||||
# Both reasoning_content values should be rendered in <think> tags
|
||||
assert reasoning_1 in prompt, f"Expected first reasoning text in prompt: {prompt}"
|
||||
assert reasoning_2 in prompt, f"Expected second reasoning text in prompt: {prompt}"
|
||||
assert prompt.count("<think>") >= 2, f"Expected at least 2 <think> blocks in prompt: {prompt}"
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("stream", [False, True])
|
||||
def test_anthropic_thinking_with_reasoning_model(stream):
|
||||
|
|
|
|||
|
|
@ -100,18 +100,19 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
|
|||
assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
|
||||
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
||||
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
|
||||
assert expected_function_name == tool_call["function"]["name"]
|
||||
assert expected_function_name == tool_call["function"]["name"], f'Expected tool name to be {tool_call["function"]["name"]} in {choice["message"]}'
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
assert isinstance(actual_arguments, str)
|
||||
assert isinstance(actual_arguments, dict) or isinstance(actual_arguments, str), f'Expected arguments to be a dict or str, got: {actual_arguments}'
|
||||
if argument_key is not None:
|
||||
actual_arguments = json.loads(actual_arguments)
|
||||
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
|
||||
if (isinstance(actual_arguments, str)):
|
||||
actual_arguments = json.loads(actual_arguments)
|
||||
assert argument_key in actual_arguments, f"tool arguments: {actual_arguments}, expected: {argument_key}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("template_name,tool,argument_key", [
|
||||
("google-gemma-2-2b-it", TEST_TOOL, "success"),
|
||||
("google-gemma-2-2b-it", TEST_TOOL, "success"),
|
||||
("Qwen3-Coder", TEST_TOOL, "success"),
|
||||
("Qwen3-Coder", TEST_TOOL, "success"),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
|
||||
|
|
|
|||
|
|
@ -12,9 +12,13 @@ flowchart TB
|
|||
C_Form["ChatForm"]
|
||||
C_Messages["ChatMessages"]
|
||||
C_Message["ChatMessage"]
|
||||
C_ChatMessageAgenticContent["ChatMessageAgenticContent"]
|
||||
C_MessageEditForm["ChatMessageEditForm"]
|
||||
C_ModelsSelector["ModelsSelector"]
|
||||
C_Settings["ChatSettings"]
|
||||
C_McpSettings["McpServersSettings"]
|
||||
C_McpResourceBrowser["McpResourceBrowser"]
|
||||
C_McpServersSelector["McpServersSelector"]
|
||||
end
|
||||
|
||||
subgraph Hooks["🪝 Hooks"]
|
||||
|
|
@ -24,10 +28,13 @@ flowchart TB
|
|||
|
||||
subgraph Stores["🗄️ Stores"]
|
||||
S1["chatStore<br/><i>Chat interactions & streaming</i>"]
|
||||
S2["conversationsStore<br/><i>Conversation data & messages</i>"]
|
||||
SA["agenticStore<br/><i>Multi-turn agentic loop orchestration</i>"]
|
||||
S2["conversationsStore<br/><i>Conversation data, messages & MCP overrides</i>"]
|
||||
S3["modelsStore<br/><i>Model selection & loading</i>"]
|
||||
S4["serverStore<br/><i>Server props & role detection</i>"]
|
||||
S5["settingsStore<br/><i>User configuration</i>"]
|
||||
S5["settingsStore<br/><i>User configuration incl. MCP</i>"]
|
||||
S6["mcpStore<br/><i>MCP servers, tools, prompts</i>"]
|
||||
S7["mcpResourceStore<br/><i>MCP resources & attachments</i>"]
|
||||
end
|
||||
|
||||
subgraph Services["⚙️ Services"]
|
||||
|
|
@ -36,11 +43,12 @@ flowchart TB
|
|||
SV3["PropsService"]
|
||||
SV4["DatabaseService"]
|
||||
SV5["ParameterSyncService"]
|
||||
SV6["MCPService<br/><i>protocol operations</i>"]
|
||||
end
|
||||
|
||||
subgraph Storage["💾 Storage"]
|
||||
ST1["IndexedDB<br/><i>conversations, messages</i>"]
|
||||
ST2["LocalStorage<br/><i>config, userOverrides</i>"]
|
||||
ST2["LocalStorage<br/><i>config, userOverrides, mcpServers</i>"]
|
||||
end
|
||||
|
||||
subgraph APIs["🌐 llama-server API"]
|
||||
|
|
@ -50,15 +58,27 @@ flowchart TB
|
|||
API4["/v1/models"]
|
||||
end
|
||||
|
||||
subgraph ExternalMCP["🔌 External MCP Servers"]
|
||||
EXT1["MCP Server 1<br/><i>WebSocket/HTTP/SSE</i>"]
|
||||
EXT2["MCP Server N"]
|
||||
end
|
||||
|
||||
%% Routes → Components
|
||||
R1 & R2 --> C_Screen
|
||||
RL --> C_Sidebar
|
||||
|
||||
%% Layout runs MCP health checks
|
||||
RL --> S6
|
||||
|
||||
%% Component hierarchy
|
||||
C_Screen --> C_Form & C_Messages & C_Settings
|
||||
C_Messages --> C_Message
|
||||
C_Message --> C_ChatMessageAgenticContent
|
||||
C_Message --> C_MessageEditForm
|
||||
C_Form & C_MessageEditForm --> C_ModelsSelector
|
||||
C_Form --> C_McpServersSelector
|
||||
C_Settings --> C_McpSettings
|
||||
C_McpSettings --> C_McpResourceBrowser
|
||||
|
||||
%% Components → Hooks → Stores
|
||||
C_Form & C_Messages --> H1 & H2
|
||||
|
|
@ -70,6 +90,15 @@ flowchart TB
|
|||
C_Sidebar --> S2
|
||||
C_ModelsSelector --> S3 & S4
|
||||
C_Settings --> S5
|
||||
C_McpSettings --> S6
|
||||
C_McpResourceBrowser --> S6 & S7
|
||||
C_McpServersSelector --> S6
|
||||
C_Form --> S6
|
||||
|
||||
%% chatStore → agenticStore → mcpStore (agentic loop)
|
||||
S1 --> SA
|
||||
SA --> SV1
|
||||
SA --> S6
|
||||
|
||||
%% Stores → Services
|
||||
S1 --> SV1 & SV4
|
||||
|
|
@ -77,6 +106,8 @@ flowchart TB
|
|||
S3 --> SV2 & SV3
|
||||
S4 --> SV3
|
||||
S5 --> SV5
|
||||
S6 --> SV6
|
||||
S7 --> SV6
|
||||
|
||||
%% Services → Storage
|
||||
SV4 --> ST1
|
||||
|
|
@ -87,6 +118,9 @@ flowchart TB
|
|||
SV2 --> API3 & API4
|
||||
SV3 --> API2
|
||||
|
||||
%% MCP → External Servers
|
||||
SV6 --> EXT1 & EXT2
|
||||
|
||||
%% Styling
|
||||
classDef routeStyle fill:#e1f5fe,stroke:#01579b,stroke-width:2px
|
||||
classDef componentStyle fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px
|
||||
|
|
@ -95,12 +129,17 @@ flowchart TB
|
|||
classDef serviceStyle fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px
|
||||
classDef storageStyle fill:#fce4ec,stroke:#c2185b,stroke-width:2px
|
||||
classDef apiStyle fill:#e3f2fd,stroke:#1565c0,stroke-width:2px
|
||||
classDef mcpStyle fill:#e0f2f1,stroke:#00695c,stroke-width:2px
|
||||
classDef agenticStyle fill:#e8eaf6,stroke:#283593,stroke-width:2px
|
||||
classDef externalStyle fill:#f3e5f5,stroke:#6a1b9a,stroke-width:2px,stroke-dasharray: 5 5
|
||||
|
||||
class R1,R2,RL routeStyle
|
||||
class C_Sidebar,C_Screen,C_Form,C_Messages,C_Message,C_MessageEditForm,C_ModelsSelector,C_Settings componentStyle
|
||||
class C_Sidebar,C_Screen,C_Form,C_Messages,C_Message,C_ChatMessageAgenticContent,C_MessageEditForm,C_ModelsSelector,C_Settings componentStyle
|
||||
class C_McpSettings,C_McpResourceBrowser,C_McpServersSelector componentStyle
|
||||
class H1,H2 hookStyle
|
||||
class S1,S2,S3,S4,S5 storeStyle
|
||||
class SV1,SV2,SV3,SV4,SV5 serviceStyle
|
||||
class S1,S2,S3,S4,S5,SA,S6,S7 storeStyle
|
||||
class SV1,SV2,SV3,SV4,SV5,SV6 serviceStyle
|
||||
class ST1,ST2 storageStyle
|
||||
class API1,API2,API3,API4 apiStyle
|
||||
class EXT1,EXT2 externalStyle
|
||||
```
|
||||
|
|
|
|||
|
|
@ -22,6 +22,13 @@ end
|
|||
C_ModelsSelector["ModelsSelector"]
|
||||
C_Settings["ChatSettings"]
|
||||
end
|
||||
subgraph MCPComponents["MCP UI"]
|
||||
C_McpSettings["McpServersSettings"]
|
||||
C_McpServerCard["McpServerCard"]
|
||||
C_McpResourceBrowser["McpResourceBrowser"]
|
||||
C_McpResourcePreview["McpResourcePreview"]
|
||||
C_McpServersSelector["McpServersSelector"]
|
||||
end
|
||||
end
|
||||
|
||||
subgraph Hooks["🪝 Hooks"]
|
||||
|
|
@ -43,14 +50,20 @@ end
|
|||
S1Edit["<b>Editing:</b><br/>editAssistantMessage()<br/>editUserMessagePreserveResponses()<br/>editMessageWithBranching()<br/>clearEditMode()<br/>isEditModeActive()<br/>getAddFilesHandler()<br/>setEditModeActive()"]
|
||||
S1Utils["<b>Utilities:</b><br/>getApiOptions()<br/>parseTimingData()<br/>getOrCreateAbortController()<br/>getConversationModel()"]
|
||||
end
|
||||
subgraph SA["agenticStore"]
|
||||
SAState["<b>State:</b><br/>sessions (Map)<br/>isAnyRunning"]
|
||||
SASession["<b>Session Management:</b><br/>getSession()<br/>updateSession()<br/>clearSession()<br/>getActiveSessions()<br/>isRunning()<br/>currentTurn()<br/>totalToolCalls()<br/>lastError()<br/>streamingToolCall()"]
|
||||
SAConfig["<b>Configuration:</b><br/>getConfig()<br/>maxTurns, maxToolPreviewLines"]
|
||||
SAFlow["<b>Agentic Loop:</b><br/>runAgenticFlow()<br/>executeAgenticLoop()<br/>normalizeToolCalls()<br/>emitToolCallResult()<br/>extractBase64Attachments()"]
|
||||
end
|
||||
subgraph S2["conversationsStore"]
|
||||
S2State["<b>State:</b><br/>conversations<br/>activeConversation<br/>activeMessages<br/>usedModalities<br/>isInitialized<br/>titleUpdateConfirmationCallback"]
|
||||
S2Modal["<b>Modalities:</b><br/>getModalitiesUpToMessage()<br/>calculateModalitiesFromMessages()"]
|
||||
S2State["<b>State:</b><br/>conversations<br/>activeConversation<br/>activeMessages<br/>isInitialized<br/>pendingMcpServerOverrides<br/>titleUpdateConfirmationCallback"]
|
||||
S2Lifecycle["<b>Lifecycle:</b><br/>initialize()<br/>loadConversations()<br/>clearActiveConversation()"]
|
||||
S2ConvCRUD["<b>Conversation CRUD:</b><br/>createConversation()<br/>loadConversation()<br/>deleteConversation()<br/>updateConversationName()<br/>updateConversationTitleWithConfirmation()"]
|
||||
S2ConvCRUD["<b>Conversation CRUD:</b><br/>createConversation()<br/>loadConversation()<br/>deleteConversation()<br/>deleteAll()<br/>updateConversationName()<br/>updateConversationTitleWithConfirmation()"]
|
||||
S2MsgMgmt["<b>Message Management:</b><br/>refreshActiveMessages()<br/>addMessageToActive()<br/>updateMessageAtIndex()<br/>findMessageIndex()<br/>sliceActiveMessages()<br/>removeMessageAtIndex()<br/>getConversationMessages()"]
|
||||
S2Nav["<b>Navigation:</b><br/>navigateToSibling()<br/>updateCurrentNode()<br/>updateConversationTimestamp()"]
|
||||
S2Export["<b>Import/Export:</b><br/>downloadConversation()<br/>exportAllConversations()<br/>importConversations()<br/>triggerDownload()"]
|
||||
S2McpOverrides["<b>MCP Per-Chat Overrides:</b><br/>getMcpServerOverride()<br/>getAllMcpServerOverrides()<br/>setMcpServerOverride()<br/>toggleMcpServerForChat()<br/>removeMcpServerOverride()<br/>isMcpServerEnabledForChat()<br/>clearPendingMcpServerOverrides()"]
|
||||
S2Export["<b>Import/Export:</b><br/>downloadConversation()<br/>exportAllConversations()<br/>importConversations()<br/>importConversationsData()<br/>triggerDownload()"]
|
||||
S2Utils["<b>Utilities:</b><br/>setTitleUpdateConfirmationCallback()"]
|
||||
end
|
||||
subgraph S3["modelsStore"]
|
||||
|
|
@ -77,6 +90,21 @@ end
|
|||
S5Sync["<b>Server Sync:</b><br/>syncWithServerDefaults()<br/>forceSyncWithServerDefaults()"]
|
||||
S5Utils["<b>Utilities:</b><br/>getConfig()<br/>getAllConfig()<br/>getParameterInfo()<br/>getParameterDiff()<br/>getServerDefaults()<br/>clearAllUserOverrides()"]
|
||||
end
|
||||
subgraph S6["mcpStore"]
|
||||
S6State["<b>State:</b><br/>isInitializing, error<br/>toolCount, connectedServers<br/>healthChecks (Map)<br/>connections (Map)<br/>toolsIndex (Map)"]
|
||||
S6Lifecycle["<b>Lifecycle:</b><br/>ensureInitialized()<br/>initialize()<br/>shutdown()<br/>acquireConnection()<br/>releaseConnection()"]
|
||||
S6Health["<b>Health Checks:</b><br/>runHealthCheck()<br/>runHealthChecksForServers()<br/>updateHealthCheck()<br/>getHealthCheckState()<br/>clearHealthCheck()"]
|
||||
S6Servers["<b>Server Management:</b><br/>getServers()<br/>addServer()<br/>updateServer()<br/>removeServer()<br/>getServerById()<br/>getServerDisplayName()"]
|
||||
S6Tools["<b>Tool Operations:</b><br/>getToolDefinitionsForLLM()<br/>getToolNames()<br/>hasTool()<br/>getToolServer()<br/>executeTool()<br/>executeToolByName()"]
|
||||
S6Prompts["<b>Prompt Operations:</b><br/>getAllPrompts()<br/>getPrompt()<br/>hasPromptsCapability()<br/>getPromptCompletions()"]
|
||||
end
|
||||
subgraph S7["mcpResourceStore"]
|
||||
S7State["<b>State:</b><br/>serverResources (Map)<br/>cachedResources (Map)<br/>subscriptions (Map)<br/>attachments[]<br/>isLoading"]
|
||||
S7Resources["<b>Resource Discovery:</b><br/>setServerResources()<br/>getServerResources()<br/>getAllResourceInfos()<br/>getAllTemplateInfos()<br/>clearServerResources()"]
|
||||
S7Cache["<b>Caching:</b><br/>cacheResourceContent()<br/>getCachedContent()<br/>invalidateCache()<br/>clearCache()"]
|
||||
S7Subs["<b>Subscriptions:</b><br/>addSubscription()<br/>removeSubscription()<br/>isSubscribed()<br/>handleResourceUpdate()"]
|
||||
S7Attach["<b>Attachments:</b><br/>addAttachment()<br/>updateAttachmentContent()<br/>removeAttachment()<br/>clearAttachments()<br/>toMessageExtras()"]
|
||||
end
|
||||
|
||||
subgraph ReactiveExports["⚡ Reactive Exports"]
|
||||
direction LR
|
||||
|
|
@ -95,12 +123,19 @@ end
|
|||
RE9c["setEditModeActive()"]
|
||||
RE9d["clearEditMode()"]
|
||||
end
|
||||
subgraph AgenticExports["agenticStore"]
|
||||
REA1["agenticIsRunning()"]
|
||||
REA2["agenticCurrentTurn()"]
|
||||
REA3["agenticTotalToolCalls()"]
|
||||
REA4["agenticLastError()"]
|
||||
REA5["agenticStreamingToolCall()"]
|
||||
REA6["agenticIsAnyRunning()"]
|
||||
end
|
||||
subgraph ConvExports["conversationsStore"]
|
||||
RE10["conversations()"]
|
||||
RE11["activeConversation()"]
|
||||
RE12["activeMessages()"]
|
||||
RE13["isConversationsInitialized()"]
|
||||
RE14["usedModalities()"]
|
||||
end
|
||||
subgraph ModelsExports["modelsStore"]
|
||||
RE15["modelOptions()"]
|
||||
|
|
@ -131,6 +166,13 @@ end
|
|||
RE36["theme()"]
|
||||
RE37["isInitialized()"]
|
||||
end
|
||||
subgraph MCPExports["mcpStore / mcpResourceStore"]
|
||||
RE38["mcpResources()"]
|
||||
RE39["mcpResourceAttachments()"]
|
||||
RE40["mcpHasResourceAttachments()"]
|
||||
RE41["mcpTotalResourceCount()"]
|
||||
RE42["mcpResourcesLoading()"]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
|
@ -138,9 +180,9 @@ end
|
|||
direction TB
|
||||
subgraph SV1["ChatService"]
|
||||
SV1Msg["<b>Messaging:</b><br/>sendMessage()"]
|
||||
SV1Stream["<b>Streaming:</b><br/>handleStreamResponse()<br/>parseSSEChunk()"]
|
||||
SV1Convert["<b>Conversion:</b><br/>convertMessageToChatData()<br/>convertExtraToApiFormat()"]
|
||||
SV1Utils["<b>Utilities:</b><br/>extractReasoningContent()<br/>getServerProps()<br/>getModels()"]
|
||||
SV1Stream["<b>Streaming:</b><br/>handleStreamResponse()<br/>handleNonStreamResponse()"]
|
||||
SV1Convert["<b>Conversion:</b><br/>convertDbMessageToApiChatMessageData()<br/>mergeToolCallDeltas()"]
|
||||
SV1Utils["<b>Utilities:</b><br/>stripReasoningContent()<br/>extractModelName()<br/>parseErrorResponse()"]
|
||||
end
|
||||
subgraph SV2["ModelsService"]
|
||||
SV2List["<b>Listing:</b><br/>list()<br/>listRouter()"]
|
||||
|
|
@ -152,7 +194,7 @@ end
|
|||
end
|
||||
subgraph SV4["DatabaseService"]
|
||||
SV4Conv["<b>Conversations:</b><br/>createConversation()<br/>getConversation()<br/>getAllConversations()<br/>updateConversation()<br/>deleteConversation()"]
|
||||
SV4Msg["<b>Messages:</b><br/>createMessageBranch()<br/>createRootMessage()<br/>getConversationMessages()<br/>updateMessage()<br/>deleteMessage()<br/>deleteMessageCascading()"]
|
||||
SV4Msg["<b>Messages:</b><br/>createMessageBranch()<br/>createRootMessage()<br/>createSystemMessage()<br/>getConversationMessages()<br/>updateMessage()<br/>deleteMessage()<br/>deleteMessageCascading()"]
|
||||
SV4Node["<b>Navigation:</b><br/>updateCurrentNode()"]
|
||||
SV4Import["<b>Import:</b><br/>importConversations()"]
|
||||
end
|
||||
|
|
@ -162,6 +204,19 @@ end
|
|||
SV5Info["<b>Info:</b><br/>getParameterInfo()<br/>canSyncParameter()<br/>getSyncableParameterKeys()<br/>validateServerParameter()"]
|
||||
SV5Diff["<b>Diff:</b><br/>createParameterDiff()"]
|
||||
end
|
||||
subgraph SV6["MCPService"]
|
||||
SV6Transport["<b>Transport:</b><br/>createTransport()<br/>WebSocket / StreamableHTTP / SSE"]
|
||||
SV6Conn["<b>Connection:</b><br/>connect()<br/>disconnect()"]
|
||||
SV6Tools["<b>Tools:</b><br/>listTools()<br/>callTool()"]
|
||||
SV6Prompts["<b>Prompts:</b><br/>listPrompts()<br/>getPrompt()"]
|
||||
SV6Resources["<b>Resources:</b><br/>listResources()<br/>listResourceTemplates()<br/>readResource()<br/>subscribeResource()<br/>unsubscribeResource()"]
|
||||
SV6Complete["<b>Completions:</b><br/>complete()"]
|
||||
end
|
||||
end
|
||||
|
||||
subgraph ExternalMCP["🔌 External MCP Servers"]
|
||||
EXT1["MCP Server 1<br/>(WebSocket/StreamableHTTP/SSE)"]
|
||||
EXT2["MCP Server N"]
|
||||
end
|
||||
|
||||
subgraph Storage["💾 Storage"]
|
||||
|
|
@ -171,6 +226,7 @@ end
|
|||
ST5["LocalStorage"]
|
||||
ST6["config"]
|
||||
ST7["userOverrides"]
|
||||
ST8["mcpServers"]
|
||||
end
|
||||
|
||||
subgraph APIs["🌐 llama-server API"]
|
||||
|
|
@ -185,6 +241,9 @@ end
|
|||
R2 --> C_Screen
|
||||
RL --> C_Sidebar
|
||||
|
||||
%% Layout runs MCP health checks on startup
|
||||
RL --> S6
|
||||
|
||||
%% Component hierarchy
|
||||
C_Screen --> C_Form & C_Messages & C_Settings
|
||||
C_Messages --> C_Message
|
||||
|
|
@ -194,8 +253,15 @@ end
|
|||
C_MessageEditForm --> C_Attach
|
||||
C_Form --> C_ModelsSelector
|
||||
C_Form --> C_Attach
|
||||
C_Form --> C_McpServersSelector
|
||||
C_Message --> C_Attach
|
||||
|
||||
%% MCP Components hierarchy
|
||||
C_Settings --> C_McpSettings
|
||||
C_McpSettings --> C_McpServerCard
|
||||
C_McpServerCard --> C_McpResourceBrowser
|
||||
C_McpResourceBrowser --> C_McpResourcePreview
|
||||
|
||||
%% Components use Hooks
|
||||
C_Form --> H1
|
||||
C_Message --> H1 & H2
|
||||
|
|
@ -210,17 +276,29 @@ end
|
|||
C_Screen --> S1 & S2
|
||||
C_Messages --> S2
|
||||
C_Message --> S1 & S2 & S3
|
||||
C_Form --> S1 & S3
|
||||
C_Form --> S1 & S3 & S6
|
||||
C_Sidebar --> S2
|
||||
C_ModelsSelector --> S3 & S4
|
||||
C_Settings --> S5
|
||||
C_McpSettings --> S6
|
||||
C_McpServerCard --> S6
|
||||
C_McpResourceBrowser --> S6 & S7
|
||||
C_McpServersSelector --> S6
|
||||
|
||||
%% Stores export Reactive State
|
||||
S1 -. exports .-> ChatExports
|
||||
SA -. exports .-> AgenticExports
|
||||
S2 -. exports .-> ConvExports
|
||||
S3 -. exports .-> ModelsExports
|
||||
S4 -. exports .-> ServerExports
|
||||
S5 -. exports .-> SettingsExports
|
||||
S6 -. exports .-> MCPExports
|
||||
S7 -. exports .-> MCPExports
|
||||
|
||||
%% chatStore → agenticStore (agentic loop orchestration)
|
||||
S1 --> SA
|
||||
SA --> SV1
|
||||
SA --> S6
|
||||
|
||||
%% Stores use Services
|
||||
S1 --> SV1 & SV4
|
||||
|
|
@ -228,28 +306,35 @@ end
|
|||
S3 --> SV2 & SV3
|
||||
S4 --> SV3
|
||||
S5 --> SV5
|
||||
S6 --> SV6
|
||||
S7 --> SV6
|
||||
|
||||
%% Services to Storage
|
||||
SV4 --> ST1
|
||||
ST1 --> ST2 & ST3
|
||||
SV5 --> ST5
|
||||
ST5 --> ST6 & ST7
|
||||
ST5 --> ST6 & ST7 & ST8
|
||||
|
||||
%% Services to APIs
|
||||
SV1 --> API1
|
||||
SV2 --> API3 & API4
|
||||
SV3 --> API2
|
||||
|
||||
%% MCP → External Servers
|
||||
SV6 --> EXT1 & EXT2
|
||||
|
||||
%% Styling
|
||||
classDef routeStyle fill:#e1f5fe,stroke:#01579b,stroke-width:2px
|
||||
classDef componentStyle fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px
|
||||
classDef componentGroupStyle fill:#e1bee7,stroke:#7b1fa2,stroke-width:1px
|
||||
classDef hookStyle fill:#fff8e1,stroke:#ff8f00,stroke-width:2px
|
||||
classDef storeStyle fill:#fff3e0,stroke:#e65100,stroke-width:2px
|
||||
classDef stateStyle fill:#ffe0b2,stroke:#e65100,stroke-width:1px
|
||||
classDef methodStyle fill:#ffecb3,stroke:#e65100,stroke-width:1px
|
||||
classDef reactiveStyle fill:#fffde7,stroke:#f9a825,stroke-width:1px
|
||||
classDef serviceStyle fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px
|
||||
classDef serviceMStyle fill:#c8e6c9,stroke:#2e7d32,stroke-width:1px
|
||||
classDef externalStyle fill:#f3e5f5,stroke:#6a1b9a,stroke-width:2px,stroke-dasharray: 5 5
|
||||
classDef storageStyle fill:#fce4ec,stroke:#c2185b,stroke-width:2px
|
||||
classDef apiStyle fill:#e3f2fd,stroke:#1565c0,stroke-width:2px
|
||||
|
||||
|
|
@ -257,23 +342,32 @@ end
|
|||
class C_Sidebar,C_Screen,C_Form,C_Messages,C_Message,C_MessageUser,C_MessageEditForm componentStyle
|
||||
class C_ModelsSelector,C_Settings componentStyle
|
||||
class C_Attach componentStyle
|
||||
class H1,H2,H3 methodStyle
|
||||
class LayoutComponents,ChatUIComponents componentGroupStyle
|
||||
class Hooks storeStyle
|
||||
class S1,S2,S3,S4,S5 storeStyle
|
||||
class S1State,S2State,S3State,S4State,S5State stateStyle
|
||||
class C_McpSettings,C_McpServerCard,C_McpResourceBrowser,C_McpResourcePreview,C_McpServersSelector componentStyle
|
||||
class H1,H2,H3 hookStyle
|
||||
class LayoutComponents,ChatUIComponents,MCPComponents componentGroupStyle
|
||||
class Hooks hookStyle
|
||||
classDef agenticStyle fill:#e8eaf6,stroke:#283593,stroke-width:2px
|
||||
classDef agenticMethodStyle fill:#c5cae9,stroke:#283593,stroke-width:1px
|
||||
|
||||
class S1,S2,S3,S4,S5,SA,S6,S7 storeStyle
|
||||
class S1State,S2State,S3State,S4State,S5State,SAState,S6State,S7State stateStyle
|
||||
class S1Msg,S1Regen,S1Edit,S1Stream,S1LoadState,S1ProcState,S1Error,S1Utils methodStyle
|
||||
class S2Lifecycle,S2ConvCRUD,S2MsgMgmt,S2Nav,S2Modal,S2Export,S2Utils methodStyle
|
||||
class SASession,SAConfig,SAFlow methodStyle
|
||||
class S2Lifecycle,S2ConvCRUD,S2MsgMgmt,S2Nav,S2McpOverrides,S2Export,S2Utils methodStyle
|
||||
class S3Getters,S3Modal,S3Status,S3Fetch,S3Select,S3LoadUnload,S3Utils methodStyle
|
||||
class S4Getters,S4Data,S4Utils methodStyle
|
||||
class S5Lifecycle,S5Update,S5Reset,S5Sync,S5Utils methodStyle
|
||||
class ChatExports,ConvExports,ModelsExports,ServerExports,SettingsExports reactiveStyle
|
||||
class SV1,SV2,SV3,SV4,SV5 serviceStyle
|
||||
class S6Lifecycle,S6Health,S6Servers,S6Tools,S6Prompts methodStyle
|
||||
class S7Resources,S7Cache,S7Subs,S7Attach methodStyle
|
||||
class ChatExports,AgenticExports,ConvExports,ModelsExports,ServerExports,SettingsExports,MCPExports reactiveStyle
|
||||
class SV1,SV2,SV3,SV4,SV5,SV6 serviceStyle
|
||||
class SV6Transport,SV6Conn,SV6Tools,SV6Prompts,SV6Resources,SV6Complete serviceMStyle
|
||||
class EXT1,EXT2 externalStyle
|
||||
class SV1Msg,SV1Stream,SV1Convert,SV1Utils serviceMStyle
|
||||
class SV2List,SV2LoadUnload,SV2Status serviceMStyle
|
||||
class SV3Fetch serviceMStyle
|
||||
class SV4Conv,SV4Msg,SV4Node,SV4Import serviceMStyle
|
||||
class SV5Extract,SV5Merge,SV5Info,SV5Diff serviceMStyle
|
||||
class ST1,ST2,ST3,ST5,ST6,ST7 storageStyle
|
||||
class ST1,ST2,ST3,ST5,ST6,ST7,ST8 storageStyle
|
||||
class API1,API2,API3,API4 apiStyle
|
||||
```
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@
|
|||
sequenceDiagram
|
||||
participant UI as 🧩 ChatForm / ChatMessage
|
||||
participant chatStore as 🗄️ chatStore
|
||||
participant agenticStore as 🗄️ agenticStore
|
||||
participant convStore as 🗄️ conversationsStore
|
||||
participant settingsStore as 🗄️ settingsStore
|
||||
participant mcpStore as 🗄️ mcpStore
|
||||
participant ChatSvc as ⚙️ ChatService
|
||||
participant DbSvc as ⚙️ DatabaseService
|
||||
participant API as 🌐 /v1/chat/completions
|
||||
|
|
@ -25,6 +27,9 @@ sequenceDiagram
|
|||
Note over convStore: → see conversations-flow.mmd
|
||||
end
|
||||
|
||||
chatStore->>mcpStore: consumeResourceAttachmentsAsExtras()
|
||||
Note right of mcpStore: Converts pending MCP resource<br/>attachments into message extras
|
||||
|
||||
chatStore->>chatStore: addMessage("user", content, extras)
|
||||
chatStore->>DbSvc: createMessageBranch(userMsg, parentId)
|
||||
chatStore->>convStore: addMessageToActive(userMsg)
|
||||
|
|
@ -38,7 +43,7 @@ sequenceDiagram
|
|||
deactivate chatStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,API: 🌊 STREAMING
|
||||
Note over UI,API: 🌊 STREAMING (with agentic flow detection)
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
activate chatStore
|
||||
|
|
@ -52,10 +57,17 @@ sequenceDiagram
|
|||
chatStore->>chatStore: getApiOptions()
|
||||
Note right of chatStore: Merge from settingsStore.config:<br/>temperature, max_tokens, top_p, etc.
|
||||
|
||||
chatStore->>ChatSvc: sendMessage(messages, options, signal)
|
||||
alt agenticConfig.enabled && mcpStore has connected servers
|
||||
chatStore->>agenticStore: runAgenticFlow(convId, messages, assistantMsg, options, signal)
|
||||
Note over agenticStore: Multi-turn agentic loop:<br/>1. Call ChatService.sendMessage()<br/>2. If response has tool_calls → execute via mcpStore<br/>3. Append tool results as messages<br/>4. Loop until no more tool_calls or maxTurns<br/>→ see agentic flow details below
|
||||
agenticStore-->>chatStore: final response with timings
|
||||
else standard (non-agentic) flow
|
||||
chatStore->>ChatSvc: sendMessage(messages, options, signal)
|
||||
end
|
||||
|
||||
activate ChatSvc
|
||||
|
||||
ChatSvc->>ChatSvc: convertMessageToChatData(messages)
|
||||
ChatSvc->>ChatSvc: convertDbMessageToApiChatMessageData(messages)
|
||||
Note right of ChatSvc: DatabaseMessage[] → ApiChatMessageData[]<br/>Process attachments (images, PDFs, audio)
|
||||
|
||||
ChatSvc->>API: POST /v1/chat/completions
|
||||
|
|
@ -63,7 +75,7 @@ sequenceDiagram
|
|||
|
||||
loop SSE chunks
|
||||
API-->>ChatSvc: data: {"choices":[{"delta":{...}}]}
|
||||
ChatSvc->>ChatSvc: parseSSEChunk(line)
|
||||
ChatSvc->>ChatSvc: handleStreamResponse(response)
|
||||
|
||||
alt content chunk
|
||||
ChatSvc-->>chatStore: onChunk(content)
|
||||
|
|
@ -154,12 +166,15 @@ sequenceDiagram
|
|||
Note over UI,API: ✏️ EDIT USER MESSAGE
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>chatStore: editUserMessagePreserveResponses(msgId, newContent)
|
||||
UI->>chatStore: editMessageWithBranching(msgId, newContent, extras)
|
||||
activate chatStore
|
||||
chatStore->>chatStore: Get parent of target message
|
||||
chatStore->>DbSvc: createMessageBranch(editedMsg, parentId)
|
||||
chatStore->>convStore: refreshActiveMessages()
|
||||
Note right of chatStore: Creates new branch, original preserved
|
||||
chatStore->>chatStore: createAssistantMessage(editedMsg.id)
|
||||
chatStore->>chatStore: streamChatCompletion(...)
|
||||
Note right of chatStore: Automatically regenerates response
|
||||
deactivate chatStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
|
@ -171,4 +186,43 @@ sequenceDiagram
|
|||
Note right of chatStore: errorDialogState = {type: 'timeout'|'server', message}
|
||||
chatStore->>convStore: removeMessageAtIndex(failedMsgIdx)
|
||||
chatStore->>DbSvc: deleteMessage(failedMsgId)
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,API: 🤖 AGENTIC LOOP (when agenticConfig.enabled)
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Note over agenticStore: agenticStore.runAgenticFlow(convId, messages, assistantMsg, options, signal)
|
||||
activate agenticStore
|
||||
agenticStore->>agenticStore: getSession(convId) or create new
|
||||
agenticStore->>agenticStore: updateSession(turn: 0, running: true)
|
||||
|
||||
loop executeAgenticLoop (until no tool_calls or maxTurns)
|
||||
agenticStore->>agenticStore: turn++
|
||||
agenticStore->>ChatSvc: sendMessage(messages, options, signal)
|
||||
ChatSvc->>API: POST /v1/chat/completions
|
||||
API-->>ChatSvc: response with potential tool_calls
|
||||
ChatSvc-->>agenticStore: onComplete(content, reasoning, timings, toolCalls)
|
||||
|
||||
alt response has tool_calls
|
||||
agenticStore->>agenticStore: normalizeToolCalls(toolCalls)
|
||||
loop for each tool_call
|
||||
agenticStore->>agenticStore: updateSession(streamingToolCall)
|
||||
agenticStore->>mcpStore: executeTool(mcpCall, signal)
|
||||
mcpStore-->>agenticStore: tool result
|
||||
agenticStore->>agenticStore: extractBase64Attachments(result)
|
||||
agenticStore->>agenticStore: emitToolCallResult(convId, ...)
|
||||
agenticStore->>convStore: addMessageToActive(toolResultMsg)
|
||||
agenticStore->>DbSvc: createMessageBranch(toolResultMsg)
|
||||
end
|
||||
agenticStore->>agenticStore: Create new assistantMsg for next turn
|
||||
Note right of agenticStore: Continue loop with updated messages
|
||||
else no tool_calls (final response)
|
||||
agenticStore->>agenticStore: buildFinalTimings(allTurns)
|
||||
Note right of agenticStore: Break loop, return final response
|
||||
end
|
||||
end
|
||||
|
||||
agenticStore->>agenticStore: updateSession(running: false)
|
||||
agenticStore-->>chatStore: final content, timings, model
|
||||
deactivate agenticStore
|
||||
```
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ sequenceDiagram
|
|||
participant DbSvc as ⚙️ DatabaseService
|
||||
participant IDB as 💾 IndexedDB
|
||||
|
||||
Note over convStore: State:<br/>conversations: DatabaseConversation[]<br/>activeConversation: DatabaseConversation | null<br/>activeMessages: DatabaseMessage[]<br/>isInitialized: boolean<br/>usedModalities: $derived({vision, audio})
|
||||
Note over convStore: State:<br/>conversations: DatabaseConversation[]<br/>activeConversation: DatabaseConversation | null<br/>activeMessages: DatabaseMessage[]<br/>isInitialized: boolean<br/>pendingMcpServerOverrides: Map<string, McpServerOverride>
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,IDB: 🚀 INITIALIZATION
|
||||
|
|
@ -37,6 +37,13 @@ sequenceDiagram
|
|||
convStore->>convStore: conversations.unshift(conversation)
|
||||
convStore->>convStore: activeConversation = $state(conversation)
|
||||
convStore->>convStore: activeMessages = $state([])
|
||||
|
||||
alt pendingMcpServerOverrides has entries
|
||||
loop each pending override
|
||||
convStore->>DbSvc: Store MCP server override for new conversation
|
||||
end
|
||||
convStore->>convStore: clearPendingMcpServerOverrides()
|
||||
end
|
||||
deactivate convStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
|
@ -58,8 +65,7 @@ sequenceDiagram
|
|||
Note right of convStore: Filter to show only current branch path
|
||||
convStore->>convStore: activeMessages = $state(filtered)
|
||||
|
||||
convStore->>chatStore: syncLoadingStateForChat(convId)
|
||||
Note right of chatStore: Sync isLoading/currentResponse if streaming
|
||||
Note right of convStore: Route (+page.svelte) then calls:<br/>chatStore.syncLoadingStateForChat(convId)
|
||||
deactivate convStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
|
@ -121,16 +127,36 @@ sequenceDiagram
|
|||
end
|
||||
deactivate convStore
|
||||
|
||||
UI->>convStore: deleteAll()
|
||||
activate convStore
|
||||
convStore->>DbSvc: Delete all conversations and messages
|
||||
convStore->>convStore: conversations = []
|
||||
convStore->>convStore: clearActiveConversation()
|
||||
deactivate convStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,IDB: 📊 MODALITY TRACKING
|
||||
Note over UI,IDB: <EFBFBD> MCP SERVER PER-CHAT OVERRIDES
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Note over convStore: usedModalities = $derived.by(() => {<br/> calculateModalitiesFromMessages(activeMessages)<br/>})
|
||||
Note over convStore: Conversations can override which MCP servers are enabled.
|
||||
Note over convStore: Uses pendingMcpServerOverrides before conversation<br/>is created, then persists to conversation metadata.
|
||||
|
||||
Note over convStore: Scans activeMessages for attachments:<br/>- IMAGE → vision: true<br/>- PDF (processedAsImages) → vision: true<br/>- AUDIO → audio: true
|
||||
UI->>convStore: setMcpServerOverride(convId, serverName, override)
|
||||
Note right of convStore: override = {enabled: boolean}
|
||||
|
||||
UI->>convStore: getModalitiesUpToMessage(msgId)
|
||||
Note right of convStore: Used for regeneration validation<br/>Only checks messages BEFORE target
|
||||
UI->>convStore: toggleMcpServerForChat(convId, serverName, enabled)
|
||||
activate convStore
|
||||
convStore->>convStore: setMcpServerOverride(convId, serverName, {enabled})
|
||||
deactivate convStore
|
||||
|
||||
UI->>convStore: isMcpServerEnabledForChat(convId, serverName)
|
||||
Note right of convStore: Check override → fall back to global MCP config
|
||||
|
||||
UI->>convStore: getAllMcpServerOverrides(convId)
|
||||
Note right of convStore: Returns all overrides for a conversation
|
||||
|
||||
UI->>convStore: removeMcpServerOverride(convId, serverName)
|
||||
UI->>convStore: getMcpServerOverride(convId, serverName)
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,IDB: 📤 EXPORT / 📥 IMPORT
|
||||
|
|
@ -148,8 +174,10 @@ sequenceDiagram
|
|||
UI->>convStore: importConversations(file)
|
||||
activate convStore
|
||||
convStore->>convStore: Parse JSON file
|
||||
convStore->>convStore: importConversationsData(parsed)
|
||||
convStore->>DbSvc: importConversations(parsed)
|
||||
DbSvc->>IDB: Bulk INSERT conversations + messages
|
||||
Note right of DbSvc: Skips duplicate conversations<br/>(checks existing by ID)
|
||||
DbSvc->>IDB: INSERT conversations + messages (skip existing)
|
||||
convStore->>convStore: loadConversations()
|
||||
deactivate convStore
|
||||
```
|
||||
|
|
|
|||
|
|
@ -66,6 +66,14 @@ sequenceDiagram
|
|||
DbSvc-->>Store: rootMessageId
|
||||
deactivate DbSvc
|
||||
|
||||
Store->>DbSvc: createSystemMessage(convId, content, parentId)
|
||||
activate DbSvc
|
||||
DbSvc->>DbSvc: Create message {role: "system", parent: parentId}
|
||||
DbSvc->>Dexie: db.messages.add(systemMsg)
|
||||
Dexie->>IDB: INSERT
|
||||
DbSvc-->>Store: DatabaseMessage
|
||||
deactivate DbSvc
|
||||
|
||||
Store->>DbSvc: createMessageBranch(message, parentId)
|
||||
activate DbSvc
|
||||
DbSvc->>DbSvc: Generate UUID for new message
|
||||
|
|
@ -116,6 +124,13 @@ sequenceDiagram
|
|||
end
|
||||
DbSvc->>Dexie: db.messages.delete(msgId)
|
||||
Dexie->>IDB: DELETE target message
|
||||
|
||||
alt target message has a parent
|
||||
DbSvc->>Dexie: db.messages.get(parentId)
|
||||
DbSvc->>DbSvc: parent.children.filter(id !== msgId)
|
||||
DbSvc->>Dexie: db.messages.update(parentId, {children})
|
||||
Note right of DbSvc: Remove deleted message from parent's children[]
|
||||
end
|
||||
deactivate DbSvc
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
|
@ -125,12 +140,16 @@ sequenceDiagram
|
|||
Store->>DbSvc: importConversations(data)
|
||||
activate DbSvc
|
||||
loop each conversation in data
|
||||
DbSvc->>DbSvc: Generate new UUIDs (avoid conflicts)
|
||||
DbSvc->>Dexie: db.conversations.add(conversation)
|
||||
Dexie->>IDB: INSERT conversation
|
||||
loop each message
|
||||
DbSvc->>Dexie: db.messages.add(message)
|
||||
Dexie->>IDB: INSERT message
|
||||
DbSvc->>Dexie: db.conversations.get(conv.id)
|
||||
alt conversation already exists
|
||||
Note right of DbSvc: Skip duplicate (keep existing)
|
||||
else conversation is new
|
||||
DbSvc->>Dexie: db.conversations.add(conversation)
|
||||
Dexie->>IDB: INSERT conversation
|
||||
loop each message
|
||||
DbSvc->>Dexie: db.messages.add(message)
|
||||
Dexie->>IDB: INSERT message
|
||||
end
|
||||
end
|
||||
end
|
||||
deactivate DbSvc
|
||||
|
|
|
|||
|
|
@ -0,0 +1,226 @@
|
|||
```mermaid
|
||||
sequenceDiagram
|
||||
participant UI as 🧩 McpServersSettings / ChatForm
|
||||
participant chatStore as 🗄️ chatStore
|
||||
participant mcpStore as 🗄️ mcpStore
|
||||
participant mcpResStore as 🗄️ mcpResourceStore
|
||||
participant convStore as 🗄️ conversationsStore
|
||||
participant MCPSvc as ⚙️ MCPService
|
||||
participant LS as 💾 LocalStorage
|
||||
participant ExtMCP as 🔌 External MCP Server
|
||||
|
||||
Note over mcpStore: State:<br/>isInitializing, error<br/>toolCount, connectedServers<br/>healthChecks (Map)<br/>connections (Map)<br/>toolsIndex (Map)<br/>serverConfigs (Map)
|
||||
|
||||
Note over mcpResStore: State:<br/>serverResources (Map)<br/>cachedResources (Map)<br/>subscriptions (Map)<br/>attachments[]
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,ExtMCP: 🚀 INITIALIZATION (App Startup)
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>mcpStore: ensureInitialized()
|
||||
activate mcpStore
|
||||
|
||||
mcpStore->>LS: get(MCP_SERVERS_LOCALSTORAGE_KEY)
|
||||
LS-->>mcpStore: MCPServerSettingsEntry[]
|
||||
|
||||
mcpStore->>mcpStore: parseServerSettings(servers)
|
||||
Note right of mcpStore: Filter enabled servers<br/>Build MCPServerConfig objects<br/>Per-chat overrides checked via convStore
|
||||
|
||||
loop For each enabled server
|
||||
mcpStore->>mcpStore: runHealthCheck(serverId)
|
||||
mcpStore->>mcpStore: updateHealthCheck(id, CONNECTING)
|
||||
|
||||
mcpStore->>MCPSvc: connect(serverName, config, clientInfo, capabilities, onPhase)
|
||||
activate MCPSvc
|
||||
|
||||
MCPSvc->>MCPSvc: createTransport(config)
|
||||
Note right of MCPSvc: WebSocket / StreamableHTTP / SSE<br/>with optional CORS proxy
|
||||
|
||||
MCPSvc->>ExtMCP: Transport handshake
|
||||
ExtMCP-->>MCPSvc: Connection established
|
||||
|
||||
MCPSvc->>ExtMCP: Initialize request
|
||||
Note right of ExtMCP: Exchange capabilities<br/>Server info, protocol version
|
||||
|
||||
ExtMCP-->>MCPSvc: InitializeResult (serverInfo, capabilities)
|
||||
|
||||
MCPSvc->>ExtMCP: listTools()
|
||||
ExtMCP-->>MCPSvc: Tool[]
|
||||
|
||||
MCPSvc-->>mcpStore: MCPConnection
|
||||
deactivate MCPSvc
|
||||
|
||||
mcpStore->>mcpStore: connections.set(serverName, connection)
|
||||
mcpStore->>mcpStore: indexTools(connection.tools, serverName)
|
||||
Note right of mcpStore: toolsIndex.set(toolName, serverName)<br/>Handle name conflicts with prefixes
|
||||
|
||||
mcpStore->>mcpStore: updateHealthCheck(id, SUCCESS)
|
||||
mcpStore->>mcpStore: _connectedServers.push(serverName)
|
||||
|
||||
alt Server supports resources
|
||||
mcpStore->>MCPSvc: listAllResources(connection)
|
||||
MCPSvc->>ExtMCP: listResources()
|
||||
ExtMCP-->>MCPSvc: MCPResource[]
|
||||
MCPSvc-->>mcpStore: resources
|
||||
|
||||
mcpStore->>MCPSvc: listAllResourceTemplates(connection)
|
||||
MCPSvc->>ExtMCP: listResourceTemplates()
|
||||
ExtMCP-->>MCPSvc: MCPResourceTemplate[]
|
||||
MCPSvc-->>mcpStore: templates
|
||||
|
||||
mcpStore->>mcpResStore: setServerResources(serverName, resources, templates)
|
||||
end
|
||||
end
|
||||
|
||||
mcpStore->>mcpStore: _isInitializing = false
|
||||
deactivate mcpStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,ExtMCP: 🔧 TOOL EXECUTION (Chat with Tools)
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>mcpStore: executeTool(mcpCall: MCPToolCall, signal?)
|
||||
activate mcpStore
|
||||
|
||||
mcpStore->>mcpStore: toolsIndex.get(mcpCall.function.name)
|
||||
Note right of mcpStore: Resolve serverName from toolsIndex<br/>MCPToolCall = {id, type, function: {name, arguments}}
|
||||
|
||||
mcpStore->>mcpStore: acquireConnection()
|
||||
Note right of mcpStore: activeFlowCount++<br/>Prevent shutdown during execution
|
||||
|
||||
mcpStore->>mcpStore: connection = connections.get(serverName)
|
||||
|
||||
mcpStore->>MCPSvc: callTool(connection, {name, arguments}, signal)
|
||||
activate MCPSvc
|
||||
|
||||
MCPSvc->>MCPSvc: throwIfAborted(signal)
|
||||
MCPSvc->>ExtMCP: callTool(name, arguments)
|
||||
|
||||
alt Tool execution success
|
||||
ExtMCP-->>MCPSvc: ToolCallResult (content, isError)
|
||||
MCPSvc->>MCPSvc: formatToolResult(result)
|
||||
Note right of MCPSvc: Handle text, image (base64),<br/>embedded resource content
|
||||
MCPSvc-->>mcpStore: ToolExecutionResult
|
||||
else Tool execution error
|
||||
ExtMCP-->>MCPSvc: Error
|
||||
MCPSvc-->>mcpStore: throw Error
|
||||
else Aborted
|
||||
MCPSvc-->>mcpStore: throw AbortError
|
||||
end
|
||||
|
||||
deactivate MCPSvc
|
||||
|
||||
mcpStore->>mcpStore: releaseConnection()
|
||||
Note right of mcpStore: activeFlowCount--
|
||||
|
||||
mcpStore-->>UI: ToolExecutionResult
|
||||
deactivate mcpStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,ExtMCP: <20> RESOURCE ATTACHMENT CONSUMPTION
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
chatStore->>mcpStore: consumeResourceAttachmentsAsExtras()
|
||||
activate mcpStore
|
||||
mcpStore->>mcpResStore: getAttachments()
|
||||
mcpResStore-->>mcpStore: MCPResourceAttachment[]
|
||||
mcpStore->>mcpStore: Convert attachments to message extras
|
||||
mcpStore->>mcpResStore: clearAttachments()
|
||||
mcpStore-->>chatStore: MessageExtra[] (for user message)
|
||||
deactivate mcpStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,ExtMCP: <20>📝 PROMPT OPERATIONS
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>mcpStore: getAllPrompts()
|
||||
activate mcpStore
|
||||
|
||||
loop For each connected server with prompts capability
|
||||
mcpStore->>MCPSvc: listPrompts(connection)
|
||||
MCPSvc->>ExtMCP: listPrompts()
|
||||
ExtMCP-->>MCPSvc: Prompt[]
|
||||
MCPSvc-->>mcpStore: prompts
|
||||
end
|
||||
|
||||
mcpStore-->>UI: MCPPromptInfo[] (with serverName)
|
||||
deactivate mcpStore
|
||||
|
||||
UI->>mcpStore: getPrompt(serverName, promptName, args?)
|
||||
activate mcpStore
|
||||
|
||||
mcpStore->>MCPSvc: getPrompt(connection, name, args)
|
||||
MCPSvc->>ExtMCP: getPrompt({name, arguments})
|
||||
ExtMCP-->>MCPSvc: GetPromptResult (messages)
|
||||
MCPSvc-->>mcpStore: GetPromptResult
|
||||
|
||||
mcpStore-->>UI: GetPromptResult
|
||||
deactivate mcpStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,ExtMCP: 📁 RESOURCE OPERATIONS
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>mcpResStore: addAttachment(resourceInfo)
|
||||
activate mcpResStore
|
||||
mcpResStore->>mcpResStore: Create MCPResourceAttachment (loading: true)
|
||||
mcpResStore-->>UI: attachment
|
||||
|
||||
UI->>mcpStore: readResource(serverName, uri)
|
||||
activate mcpStore
|
||||
|
||||
mcpStore->>MCPSvc: readResource(connection, uri)
|
||||
MCPSvc->>ExtMCP: readResource({uri})
|
||||
ExtMCP-->>MCPSvc: MCPReadResourceResult (contents)
|
||||
MCPSvc-->>mcpStore: contents
|
||||
|
||||
mcpStore-->>UI: MCPResourceContent[]
|
||||
deactivate mcpStore
|
||||
|
||||
UI->>mcpResStore: updateAttachmentContent(attachmentId, content)
|
||||
mcpResStore->>mcpResStore: cacheResourceContent(resource, content)
|
||||
deactivate mcpResStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,ExtMCP: 🔄 AUTO-RECONNECTION
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Note over mcpStore: On WebSocket close or connection error:
|
||||
mcpStore->>mcpStore: autoReconnect(serverName, attempt)
|
||||
activate mcpStore
|
||||
|
||||
mcpStore->>mcpStore: Calculate backoff delay
|
||||
Note right of mcpStore: delay = min(30s, 1s * 2^attempt)
|
||||
|
||||
mcpStore->>mcpStore: Wait for delay
|
||||
mcpStore->>mcpStore: reconnectServer(serverName)
|
||||
|
||||
alt Reconnection success
|
||||
mcpStore->>mcpStore: updateHealthCheck(id, SUCCESS)
|
||||
else Max attempts reached
|
||||
mcpStore->>mcpStore: updateHealthCheck(id, ERROR)
|
||||
end
|
||||
deactivate mcpStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,ExtMCP: 🛑 SHUTDOWN
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>mcpStore: shutdown()
|
||||
activate mcpStore
|
||||
|
||||
mcpStore->>mcpStore: Wait for activeFlowCount == 0
|
||||
|
||||
loop For each connection
|
||||
mcpStore->>MCPSvc: disconnect(connection)
|
||||
MCPSvc->>MCPSvc: transport.onclose = undefined
|
||||
MCPSvc->>ExtMCP: close()
|
||||
end
|
||||
|
||||
mcpStore->>mcpStore: connections.clear()
|
||||
mcpStore->>mcpStore: toolsIndex.clear()
|
||||
mcpStore->>mcpStore: _connectedServers = []
|
||||
|
||||
mcpStore->>mcpResStore: clear()
|
||||
deactivate mcpStore
|
||||
```
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue