diff --git a/common/chat-auto-parser-generator.cpp b/common/chat-auto-parser-generator.cpp index 1b431ed15b..00834ba0f3 100644 --- a/common/chat-auto-parser-generator.cpp +++ b/common/chat-auto-parser-generator.cpp @@ -8,109 +8,11 @@ #include "nlohmann/json.hpp" #include "peg-parser.h" -#include #include #include using json = nlohmann::ordered_json; -namespace { - -// Gemma4-specific PEG builder extending the standard chat builder. -// Adds value type parsers that use <|\"|> as string delimiters -// instead of JSON's double quotes, and disables json-to-schema -// conversion for these types. -class common_peg_gemma4_builder { - common_chat_peg_builder & p_; - static constexpr const char * QUOTE = "<|\"|>"; - -public: - explicit common_peg_gemma4_builder(common_chat_peg_builder & p) : p_(p) {} - - common_peg_parser gemma4_string() { - return p_.rule("gemma4-string", [&]() { - return p_.literal(QUOTE) + p_.until(QUOTE) + p_.literal(QUOTE); - }); - } - - common_peg_parser gemma4_number() { - return p_.rule("gemma4-number", [&]() { - auto digit1_9 = p_.chars("[1-9]", 1, 1); - auto digits = p_.chars("[0-9]"); - auto int_part = p_.choice({p_.literal("0"), p_.sequence({digit1_9, p_.chars("[0-9]", 0, -1)})}); - auto frac = p_.sequence({p_.literal("."), digits}); - auto exp = p_.sequence({p_.choice({p_.literal("e"), p_.literal("E")}), - p_.optional(p_.chars("[+-]", 1, 1)), digits}); - auto not_number_continuation = p_.negate(p_.chars("[0-9.eE+-]", 1, 1)); - return p_.sequence({p_.optional(p_.literal("-")), int_part, p_.optional(frac), - p_.optional(exp), not_number_continuation}); - }); - } - - common_peg_parser gemma4_bool() { - return p_.rule("gemma4-bool", [&]() { - return p_.choice({p_.literal("true"), p_.literal("false")}); - }); - } - - common_peg_parser gemma4_null() { - return p_.rule("gemma4-null", [&]() { - return p_.literal("null"); - }); - } - - common_peg_parser gemma4_dict() { - return p_.rule("gemma4-dict", [&]() { - auto ws = p_.space(); - auto key = p_.until(":"); - auto member = p_.sequence({key, p_.literal(":"), ws, gemma4_value()}); - auto members = p_.sequence({member, p_.zero_or_more(p_.sequence({p_.literal(","), ws, member}))}); - return p_.sequence({ - p_.literal("{"), ws, - p_.choice({p_.literal("}"), p_.sequence({members, ws, p_.literal("}")})}) - }); - }); - } - - common_peg_parser gemma4_array() { - return p_.rule("gemma4-array", [&]() { - auto ws = p_.space(); - auto elements = p_.sequence({gemma4_value(), p_.zero_or_more(p_.sequence({p_.literal(","), ws, gemma4_value()}))}); - return p_.sequence({ - p_.literal("["), ws, - p_.choice({p_.literal("]"), p_.sequence({elements, ws, p_.literal("]")})}) - }); - }); - } - - common_peg_parser gemma4_value() { - return p_.rule("gemma4-value", [&]() { - return p_.choice({gemma4_string(), gemma4_dict(), gemma4_array(), - gemma4_number(), gemma4_bool(), gemma4_null()}); - }); - } - - // Select the appropriate value parser based on JSON schema type. - // Does NOT use schema() - the gemma4 types are pure PEG without - // JSON schema metadata, so GBNF is generated directly from the - // PEG structure. - common_peg_parser gemma4_value_for_type(const json & schema) { - if (!schema.contains("type") || !schema.at("type").is_string()) { - return gemma4_value(); - } - std::string type = schema.at("type").get(); - if (type == "string") { return gemma4_string(); } - if (type == "number") { return gemma4_number(); } - if (type == "integer") { return gemma4_number(); } - if (type == "boolean") { return gemma4_bool(); } - if (type == "object") { return gemma4_dict(); } - if (type == "array") { return gemma4_array(); } - return gemma4_value(); - } -}; - -} // anonymous namespace - // Helper to iterate over tools/functions static void foreach_function(const json & tools, const std::function & fn) { for (const auto & tool : tools) { @@ -142,9 +44,7 @@ common_chat_params peg_generator::generate_parser(const common_chat_template & // Create the result structure common_chat_params data; data.prompt = common_chat_template_direct_apply(tmpl, inputs); - data.format = (autoparser.tools.format.mode == tool_format::TAG_WITH_GEMMA4_DICT) - ? COMMON_CHAT_FORMAT_PEG_GEMMA4 - : COMMON_CHAT_FORMAT_PEG_NATIVE; + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.preserved_tokens = autoparser.preserved_tokens; auto parser = autoparser.build_parser(inputs); @@ -271,8 +171,6 @@ common_peg_parser analyze_tools::build_parser(parser_build_context & ctx) const return build_tool_parser_tag_json(ctx); case tool_format::TAG_WITH_TAGGED: return build_tool_parser_tag_tagged(ctx); - case tool_format::TAG_WITH_GEMMA4_DICT: - return build_tool_parser_tag_gemma4_dict(ctx); default: LOG_ERR("[ERROR] Template seems to support tool calls, but failed to determine tool format. Tool calling will not work properly. " "Check for a fixed template for your model in the models/templates directory of your llama.cpp installation or " @@ -586,145 +484,4 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte p.end(); } -common_peg_parser analyze_tools::build_tool_parser_tag_gemma4_dict(parser_build_context & ctx) const { - auto & p = ctx.p; - const auto & inputs = ctx.inputs; - bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED; - - common_peg_gemma4_builder g4(p); - static const std::string QUOTE = "<|\"|>"; - - common_peg_parser tool_choice = p.choice(); - - foreach_function(inputs.tools, [&](const json & tool) { - const auto & func = tool.at("function"); - std::string name = func.at("name"); - const auto & params = func.at("parameters"); - - if (!params.contains("properties") || !params.at("properties").is_object()) { - auto func_parser = p.atomic( - p.tool_open(p.literal(function.name_prefix) + p.tool_name(p.literal(name)) + p.literal("{")) + - p.tool_args(p.eps()) + - p.tool_close(p.literal("}"))); - tool_choice |= p.rule("tool-" + name, func_parser); - return; - } - - const auto & properties = params.at("properties"); - std::set required; - if (params.contains("required") && params.at("required").is_array()) { - params.at("required").get_to(required); - } - - // Build per-argument parsers, sorted alphabetically (matching template's dictsort) - struct arg_entry { - std::string param_name; - common_peg_parser parser; - }; - std::vector arg_entries; - - for (const auto & [param_name, param_schema] : properties.items()) { - std::string type = "object"; - if (param_schema.contains("type")) { - const auto & type_v = param_schema.at("type"); - if (type_v.is_string()) { - type_v.get_to(type); - } else if (type_v.is_array()) { - // Handle nullable types like ["string", "null"] - for (const auto & t : type_v) { - if (t.is_string() && t.get() != "null") { - type = t.get(); - break; - } - } - } - } - // Infer string type from enum values when type is unspecified - if (type == "object" && param_schema.contains("enum")) { - const auto & enum_vals = param_schema.at("enum"); - if (enum_vals.is_array()) { - for (const auto & v : enum_vals) { - if (v.is_string()) { - type = "string"; - break; - } - } - } - } - - common_peg_parser value_parser = p.eps(); - if (type == "string") { - // String values are delimited by <|"|>...<|"|> - value_parser = - p.literal(QUOTE) + - p.tool_arg_string_value(p.schema(p.until(QUOTE), - "tool-" + name + "-arg-" + param_name + "-schema", param_schema, true)) + - p.literal(QUOTE); - } else if (type == "number" || type == "integer") { - value_parser = p.tool_arg_value(g4.gemma4_number()); - } else if (type == "boolean") { - value_parser = p.tool_arg_value(g4.gemma4_bool()); - } else if (type == "null") { - value_parser = p.tool_arg_value(g4.gemma4_null()); - } else if (type == "object") { - value_parser = p.tool_arg_value(g4.gemma4_dict()); - } else if (type == "array") { - value_parser = p.tool_arg_value(g4.gemma4_array()); - } else { - value_parser = p.tool_arg_value(g4.gemma4_value()); - } - - auto arg = p.tool_arg( - p.tool_arg_open(p.tool_arg_name(p.literal(param_name)) + p.literal(":")) + - value_parser + - p.tool_arg_close(p.eps())); - - arg_entries.push_back({param_name, p.rule("tool-" + name + "-arg-" + param_name, arg)}); - } - - // Sort alphabetically to match Jinja's dictsort - std::sort(arg_entries.begin(), arg_entries.end(), [](const auto & a, const auto & b) { - return a.param_name < b.param_name; - }); - - // Build arg sequence: any arg, then zero-or-more comma-separated additional args - common_peg_parser args_seq = p.eps(); - if (!arg_entries.empty()) { - common_peg_parser any_arg = p.choice(); - for (auto & entry : arg_entries) { - any_arg |= entry.parser; - } - args_seq = p.optional( - any_arg + p.repeat(p.literal(",") + any_arg, 0, (int) arg_entries.size() - 1)); - } - - // Full parser: call:name{args} - auto func_parser = p.atomic( - p.tool_open(p.literal(function.name_prefix) + p.tool_name(p.literal(name)) + p.literal("{")) + - p.tool_args(args_seq) + - p.tool_close(p.literal("}"))); - - tool_choice |= p.rule("tool-" + name, func_parser); - }); - - // Wrap each call in <|tool_call>... - auto wrapped_call = p.literal(format.per_call_start) + tool_choice + p.literal(format.per_call_end); - - common_peg_parser tool_calls = p.eps(); - if (inputs.parallel_tool_calls) { - tool_calls = p.trigger_rule("tool-call", wrapped_call + p.zero_or_more(p.space() + wrapped_call)); - } else { - tool_calls = p.trigger_rule("tool-call", wrapped_call); - } - - if (!force_tools) { - tool_calls = p.optional(tool_calls); - } - - auto content_before_tools = p.until_one_of({ format.per_call_start, ctx.reasoning->start }); - return ctx.reasoning_parser + - (force_tools ? p.eps() : p.optional(p.content(content_before_tools) + p.optional(ctx.reasoning_parser))) + - tool_calls + p.end(); -} - } // namespace autoparser diff --git a/common/chat-auto-parser.h b/common/chat-auto-parser.h index 2168bb05ed..99dd9f063c 100644 --- a/common/chat-auto-parser.h +++ b/common/chat-auto-parser.h @@ -145,7 +145,6 @@ enum class tool_format { JSON_NATIVE, // Pure JSON: {"name": "X", "arguments": {...}} TAG_WITH_JSON, // Tag-based with JSON args: {...} TAG_WITH_TAGGED, // Tag-based with tagged args: value - TAG_WITH_GEMMA4_DICT, // Gemma4 custom dict: <|tool_call>call:name{key:<|"|>val<|"|>} }; inline std::ostream & operator<<(std::ostream & os, const tool_format & format) { @@ -158,8 +157,6 @@ inline std::ostream & operator<<(std::ostream & os, const tool_format & format) return os << "TAG_WITH_JSON"; case tool_format::TAG_WITH_TAGGED: return os << "TAG_WITH_TAGGED"; - case tool_format::TAG_WITH_GEMMA4_DICT: - return os << "TAG_WITH_GEMMA4_DICT"; default: return os << "UNKNOWN"; } @@ -363,7 +360,6 @@ struct analyze_tools : analyze_base { const common_peg_parser & call_id_section, bool have_call_id, const common_peg_parser & args, std::optional atomic_peek) const; - common_peg_parser build_tool_parser_tag_gemma4_dict(parser_build_context & ctx) const; }; // ============================================================================ diff --git a/common/chat-diff-analyzer.cpp b/common/chat-diff-analyzer.cpp index 8288296631..fa3e368098 100644 --- a/common/chat-diff-analyzer.cpp +++ b/common/chat-diff-analyzer.cpp @@ -95,34 +95,6 @@ static std::vectorcall:name{key:<|"|>val<|"|>} - [](const common_chat_template & tmpl, autoparser & analysis) -> void { - if (tmpl.src.find("'<|tool_call>call:'") != std::string::npos) { - analysis.tools.format.mode = tool_format::TAG_WITH_GEMMA4_DICT; - analysis.tools.format.per_call_start = "<|tool_call>"; - analysis.tools.format.per_call_end = ""; - analysis.tools.format.section_start = ""; - analysis.tools.format.section_end = ""; - analysis.tools.function.name_prefix = "call:"; - analysis.tools.function.name_suffix = ""; - analysis.tools.arguments.start = "{"; - analysis.tools.arguments.end = "}"; - analysis.tools.arguments.name_prefix = ""; - analysis.tools.arguments.name_suffix = ":"; - analysis.tools.arguments.separator = ","; - analysis.reasoning.mode = reasoning_mode::TAG_BASED; - analysis.reasoning.start = "<|channel>thought"; - analysis.reasoning.end = ""; - analysis.preserved_tokens.clear(); - analysis.preserved_tokens.push_back("<|tool_call>"); - analysis.preserved_tokens.push_back(""); - analysis.preserved_tokens.push_back("<|tool_response>"); - analysis.preserved_tokens.push_back(""); - analysis.preserved_tokens.push_back("<|\"|>"); - analysis.preserved_tokens.push_back("<|turn>"); - LOG_DBG(ANSI_ORANGE "[Patch: Gemma4]\n" ANSI_RESET); - } - }, // DeepSeek-R1-Distill-Qwen [](const common_chat_template & tmpl, autoparser & analysis) -> void { if (tmpl.src.find( diff --git a/common/chat-peg-parser.cpp b/common/chat-peg-parser.cpp index f2ed77c440..624dee22fb 100644 --- a/common/chat-peg-parser.cpp +++ b/common/chat-peg-parser.cpp @@ -75,84 +75,6 @@ static std::string escape_json_string_inner(const std::string & s) { return escaped; } -static const std::string GEMMA4_QUOTE = "<|\"|>"; - -static std::string normalize_gemma4_to_json(const std::string & input) { - std::string result; - result.reserve(input.size() * 2); - - enum Ctx { DICT, ARRAY }; - std::vector ctx; - - auto is_ws = [](char c) { return c == ' ' || c == '\t' || c == '\n' || c == '\r'; }; - auto skip_ws = [&](size_t & pos) { - while (pos < input.size() && is_ws(input[pos])) { - result += input[pos++]; - } - }; - - auto quote_unquoted_key = [&](size_t & pos) { - if (pos < input.size() && input[pos] != '"' && input[pos] != '}') { - result += '"'; - while (pos < input.size() && input[pos] != ':' && !is_ws(input[pos])) { - result += input[pos++]; - } - result += '"'; - skip_ws(pos); - } - }; - - size_t i = 0; - while (i < input.size()) { - if (i + GEMMA4_QUOTE.size() <= input.size() && - input.compare(i, GEMMA4_QUOTE.size(), GEMMA4_QUOTE) == 0) { - result += '"'; - i += GEMMA4_QUOTE.size(); - continue; - } - - char c = input[i]; - - if (c == '{') { - result += c; - ctx.push_back(DICT); - ++i; - skip_ws(i); - quote_unquoted_key(i); - continue; - } - if (c == '}') { - result += c; - if (!ctx.empty()) ctx.pop_back(); - ++i; - continue; - } - if (c == '[') { - result += c; - ctx.push_back(ARRAY); - ++i; - continue; - } - if (c == ']') { - result += c; - if (!ctx.empty()) ctx.pop_back(); - ++i; - continue; - } - if (c == ',' && !ctx.empty() && ctx.back() == DICT) { - result += c; - ++i; - skip_ws(i); - quote_unquoted_key(i); - continue; - } - - result += c; - ++i; - } - return result; -} - // Convert Python-style single-quoted strings to JSON double-quoted strings // Only converts outer string delimiters, properly handling escape sequences: // - {'key': 'value'} -> {"key": "value"} @@ -296,10 +218,6 @@ std::string common_chat_peg_mapper::normalize_container_value(const std::string return normalize_quotes_to_json(input); } -std::string common_chat_peg_gemma4_mapper::normalize_container_value(const std::string & input) { - return normalize_quotes_to_json(normalize_gemma4_to_json(input)); -} - void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & parse_result_arg) { arena.visit(parse_result_arg, [this](const common_peg_ast_node & node) { map(node); }); @@ -947,3 +865,143 @@ common_peg_parser common_chat_peg_builder::standard_json_tools( return force_tool_calls ? section : optional(section); } + +void common_chat_peg_gemma4_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result) { + for (const auto & node : result.nodes) { + visit(arena, node); + } +} + +static std::string gemma4_to_json(const common_peg_ast_arena & arena, common_peg_ast_id id) { + const auto & node = arena.get(id); + + if (node.text.empty()) { + return ""; + } + + if (node.rule == "gemma4-number" || node.rule == "gemma4-bool" || node.rule == "gemma4-null") { + return std::string(node.text); + } + + if (node.rule == "gemma4-string-content") { + return escape_json_string_inner(std::string(node.text)); + } + + if (node.rule == "gemma4-string") { + std::string result = "\""; + if (!node.children.empty()) { + result += gemma4_to_json(arena, node.children[0]); + if (!node.is_partial) { + result += "\""; + } + } + return result; + } + + if (node.rule == "gemma4-array") { + std::string result = "["; + + bool add_comma = false; + for (auto child_id : node.children) { + if (add_comma) { + result += ','; + } + add_comma = true; + result += gemma4_to_json(arena, child_id); + } + + if (!node.is_partial) { + result += ']'; + } + return result; + } + + if (node.rule == "gemma4-dict-key-name") { + return std::string(node.text); + } + + if (node.rule == "gemma4-dict-key") { + std::string result = "\""; + if (!node.children.empty()) { + result += escape_json_string_inner(gemma4_to_json(arena, node.children[0])); + } + if (!node.is_partial) { + result += "\":"; + } + return result; + } + + if (node.rule == "gemma4-dict-kv") { + std::string result; + for (auto child_id : node.children) { + result += gemma4_to_json(arena, child_id); + } + return result; + } + + if (node.rule == "gemma4-dict") { + std::string result = "{"; + + bool add_comma = false; + for (auto child_id : node.children) { + if (add_comma) { + result += ','; + } + add_comma = true; + result += gemma4_to_json(arena, child_id); + } + + if (!node.is_partial) { + result += '}'; + } + return result; + } + + if (node.rule == "gemma4-value") { + if (!node.children.empty()) { + return gemma4_to_json(arena, node.children[0]); + } + return ""; + } + + return ""; +} + +void common_chat_peg_gemma4_mapper::visit(const common_peg_ast_arena & arena, common_peg_ast_id id) { + const auto & node = arena.get(id); + + if (node.tag == "reasoning") { + result.reasoning_content += std::string(node.text); + return; + } + + if (node.tag == "content") { + result.content += std::string(node.text); + return; + } + + if (node.tag == "tool") { + auto name_id = arena.find_by_tag(node, "tool-name"); + auto args_id = arena.find_by_tag(node, "tool-args"); + + if (name_id != COMMON_PEG_INVALID_AST_ID && args_id != COMMON_PEG_INVALID_AST_ID) { + const auto & name_node = arena.get(name_id); + const auto & args_node = arena.get(args_id); + + if (!name_node.is_partial) { + common_chat_tool_call call; + call.name = std::string(name_node.text); + if (!args_node.children.empty()) { + call.arguments = gemma4_to_json(arena, args_node.children[0]); + } + result.tool_calls.push_back(call); + } + } + + return; + } + + for (auto child_id : node.children) { + visit(arena, child_id); + } +} diff --git a/common/chat-peg-parser.h b/common/chat-peg-parser.h index dd1388ec14..1ea3eb7eb8 100644 --- a/common/chat-peg-parser.h +++ b/common/chat-peg-parser.h @@ -35,8 +35,9 @@ class common_chat_peg_mapper { class common_chat_peg_gemma4_mapper : public common_chat_peg_mapper { public: common_chat_peg_gemma4_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {} - protected: - std::string normalize_container_value(const std::string & input) override; + virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result); + private: + void visit(const common_peg_ast_arena & arena, common_peg_ast_id id); }; struct content_structure; diff --git a/common/chat.cpp b/common/chat.cpp index e93ee6b230..5b93c58873 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1077,6 +1077,131 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp return data; } +static common_chat_params common_chat_params_init_gemma4(const common_chat_template & tmpl, + const autoparser::generation_params & inputs) { + common_chat_params data; + + data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4; + data.supports_thinking = true; + + data.preserved_tokens = { + "<|channel>", + "", + "<|tool_call>", + "", + "<|turn>", + }; + + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object(); + auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE); + auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; + + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + auto start = p.rule("start", p.prefix(inputs.generation_prompt, "<|channel>")); + + if (extract_reasoning) { + p.rule("thought", p.literal("<|channel>thought\n") + p.reasoning(p.until("")) + p.literal("")); + } else { + p.rule("thought", p.content(p.literal("<|channel>thought\n") + p.until("") + p.literal(""))); + } + + auto thought = (p.peek(p.literal("<|channel>")) + p.ref("thought")) | p.negate(p.literal("<|channel>")); + + if (has_response_format) { + auto response_format = p.literal("```json") << + p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)) << + p.literal("```"); + return start + p.optional(thought) + response_format; + } + + if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { + // Gemma4 tool calling syntax + // Rules should match traversal logic in gemma4_to_json() + p.rule("gemma4-string-content", p.until("<|\"|>")); + p.rule("gemma4-string", p.literal("<|\"|>") + p.ref("gemma4-string-content") + p.literal("<|\"|>")); + p.rule("gemma4-bool", p.json_bool()); + p.rule("gemma4-null", p.json_null()); + p.rule("gemma4-number", p.json_number()); + p.rule("gemma4-dict-key", p.rule("gemma4-dict-key-name", p.until(":")) + p.literal(":")); + p.rule("gemma4-dict-kv", p.ref("gemma4-dict-key") + p.space() + p.ref("gemma4-value")); + p.rule("gemma4-dict", [&]() { + auto ws = p.space(); + auto member = p.ref("gemma4-dict-kv"); + auto members = p.sequence({member, p.zero_or_more(p.sequence({p.literal(","), ws, member}))}); + return p.sequence({ + p.literal("{"), ws, + p.choice({p.literal("}"), p.sequence({members, ws, p.literal("}")})}) + }); + }); + p.rule("gemma4-array", [&]() { + auto ws = p.space(); + auto value = p.ref("gemma4-value"); + auto elements = p.sequence({value, p.zero_or_more(p.sequence({p.literal(","), ws, value}))}); + return p.sequence({ + p.literal("["), ws, + p.choice({p.literal("]"), p.sequence({elements, ws, p.literal("]")})}) + }); + }); + p.rule("gemma4-value", [&]() { + return p.choice({ + p.ref("gemma4-string"), p.ref("gemma4-dict"), p.ref("gemma4-array"), + p.ref("gemma4-number"), p.ref("gemma4-bool"), p.ref("gemma4-null") + }); + }); + + auto tool_choice = p.choice(); + + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + // TODO @aldehir : need to extend json-schema-to-grammar to produce more than JSON rules + // const auto & params = function.at("parameters"); + + tool_choice |= p.rule("tool-" + name, p.tool(p.sequence({ + p.tool_open(p.tool_name(p.literal(name)) + p.peek(p.literal("{"))), + p.tool_args(p.ref("gemma4-dict")), + }))); + }); + + auto tool_call = p.trigger_rule("tool-call", p.repeat( + "<|tool_call>call:" + tool_choice + "", + /* min = */ inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0, + /* max = */ inputs.parallel_tool_calls ? -1 : 1 + )); + + auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<|tool_call>"}))); + auto message = p.rule("message", thought + content); + return start + p.zero_or_more(message) + tool_call; + } + + auto content = p.rule("content", p.content(p.until("<|channel>"))); + auto message = p.rule("message", thought + content); + return start + p.one_or_more(message); + }); + + data.parser = parser.save(); + + if (include_grammar) { + data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED)); + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + auto schema = function.at("parameters"); + builder.resolve_refs(schema); + }); + parser.build_grammar(builder, data.grammar_lazy); + }); + + data.grammar_triggers = { + { COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_call>" }, + }; + } + + return data; +} + // Functionary v3.2 - uses recipient-based format: >>>recipient\n{content} static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const autoparser::generation_params & inputs) { @@ -1556,46 +1681,146 @@ static void requires_non_null_content(json & messages) { } // Gemma4 uses a custom tool_responses field instead of role:tool messages. -// Convert consecutive role:tool messages into a single user message with tool_responses. +// +// This will transform a sequence of messages: +// assistant(tool_call+) -> tool+ -> assistant(content) +// +// Into a single assistant message containing a tool_responses field: +// assistant(content + tool_call + tool_responses) +// +// This is necessary for the Gemma4 chat template to properly format the prompt. +// See https://ai.google.dev/gemma/docs/core/prompt-formatting-gemma4 +struct gemma4_model_turn_builder { + json & messages; + size_t pos; + json tool_calls = json::array(); + json tool_responses = json::array(); + json content; + json reasoning_content; + + gemma4_model_turn_builder(json & msgs, size_t pos) : messages(msgs), pos(pos) {} + + void collect() { + // Collect the first assistant message + auto & msg = messages[pos]; + if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) { + // According to the prompt formatting guide, we need to preserve reasoning_content + // between function calls. The current chat templates do not support this, but we will do it anyway. + reasoning_content = msg.at("reasoning_content"); + } + for (auto & tc : msg.at("tool_calls")) { + tool_calls.push_back(tc); + } + pos++; + + // Collect tool call results + while (pos < messages.size() && messages[pos].value("role", "") == "tool") { + collect_result(messages[pos]); + pos++; + } + + // Check if the next assistant message is the final message + if (pos < messages.size() && messages[pos].value("role", "") == "assistant") { + auto & next = messages[pos]; + if (!has_tool_calls(next) && has_content(next)) { + content = next.at("content"); + pos++; + } + } + } + + void collect_result(const json & curr) { + json response; + if (curr.contains("content")) { + const auto & content = curr.at("content"); + if (content.is_string()) { + // Try to parse the content as JSON; fall back to raw string + try { + response = json::parse(content.get()); + } catch (...) { + response = content; + } + } else { + response = content; + } + } + + std::string name; + + // Match name with corresponding tool call + size_t idx = tool_responses.size(); + if (idx < tool_calls.size()) { + auto & tc = tool_calls[idx]; + if (tc.contains("function")) { + name = tc.at("function").value("name", ""); + } + } + + // Fallback to the tool call id + if (name.empty()) { + name = curr.value("tool_call_id", ""); + } + + tool_responses.push_back({{"name", name}, {"response", response}}); + } + + json build() { + collect(); + + json msg = { + {"role", "assistant"}, + {"tool_calls", tool_calls}, + }; + if (!tool_responses.empty()) { + msg["tool_responses"] = tool_responses; + } + if (!content.is_null()) { + msg["content"] = content; + } + if (!reasoning_content.is_null()) { + msg["reasoning_content"] = reasoning_content; + } + return msg; + } + + static bool has_content(const json & msg) { + if (!msg.contains("content") || msg.at("content").is_null()) { + return false; + } + const auto & content = msg.at("content"); + if (content.is_string() && !content.get().empty()) { + return true; + } + if (content.is_array() && !content.empty()) { + return true; + } + return false; + } + + static bool has_tool_calls(const json & msg) { + return msg.contains("tool_calls") && msg.at("tool_calls").is_array() && !msg.at("tool_calls").empty(); + } +}; + static void convert_tool_responses_gemma4(json & messages) { json result = json::array(); size_t i = 0; + while (i < messages.size()) { - if (messages[i].contains("role") && messages[i].at("role") == "tool") { - json tool_responses = json::array(); - while (i < messages.size() && - messages[i].contains("role") && - messages[i].at("role") == "tool") { - const auto & tool_msg = messages[i]; - std::string name; - if (tool_msg.contains("tool_call_id") && tool_msg.at("tool_call_id").is_string()) { - name = tool_msg.at("tool_call_id"); - } else if (tool_msg.contains("name") && tool_msg.at("name").is_string()) { - name = tool_msg.at("name"); - } - json response; - if (tool_msg.contains("content")) { - const auto & content = tool_msg.at("content"); - if (content.is_string()) { - // Try to parse the content as JSON; fall back to raw string - try { - response = json::parse(content.get()); - } catch (...) { - response = content; - } - } else { - response = content; - } - } - tool_responses.push_back({{"name", name}, {"response", response}}); - i++; - } - result.push_back({{"role", "user"}, {"tool_responses", tool_responses}}); - } else { - result.push_back(messages[i]); + auto & msg = messages[i]; + + if (msg.value("role", "") != "assistant" || !msg.contains("tool_calls") || + !msg.at("tool_calls").is_array() || msg.at("tool_calls").empty()) { + result.push_back(msg); i++; + continue; } + + gemma4_model_turn_builder builder(messages, i); + result.push_back(builder.build()); + i = builder.pos; } + messages = result; } @@ -1634,7 +1859,7 @@ static json common_chat_extra_context() { std::optional common_chat_try_specialized_template( const common_chat_template & tmpl, const std::string & src, - const autoparser::generation_params & params) { + autoparser::generation_params & params) { // Ministral/Mistral Large 3 - uses special reasoning structure fixes, can't use autoparser // Note: Mistral Small 3.2 uses [CALL_ID] which Ministral doesn't have, so we can distinguish them if (src.find("[SYSTEM_PROMPT]") != std::string::npos && src.find("[TOOL_CALLS]") != std::string::npos && @@ -1687,6 +1912,12 @@ std::optional common_chat_try_specialized_template( return common_chat_params_init_gigachat_v3(tmpl, params); } + // Gemma4 format detection + if (src.find("'<|tool_call>call:'") != std::string::npos) { + workaround::convert_tool_responses_gemma4(params.messages); + return common_chat_params_init_gemma4(tmpl, params); + } + return std::nullopt; } @@ -1727,10 +1958,6 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ workaround::func_args_not_string(params.messages); } - if (src.find("'<|tool_call>call:'") != std::string::npos) { - workaround::convert_tool_responses_gemma4(params.messages); - } - params.add_generation_prompt = false; std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params); params.add_generation_prompt = true; diff --git a/common/chat.h b/common/chat.h index d5328379cc..b06ca37fd7 100644 --- a/common/chat.h +++ b/common/chat.h @@ -274,4 +274,4 @@ std::string common_chat_template_direct_apply( std::optional common_chat_try_specialized_template( const common_chat_template & tmpl, const std::string & src, - const autoparser::generation_params & params); + autoparser::generation_params & params); diff --git a/common/peg-parser.cpp b/common/peg-parser.cpp index 86faacd61f..59fa4c5c55 100644 --- a/common/peg-parser.cpp +++ b/common/peg-parser.cpp @@ -256,6 +256,38 @@ static std::pair, bool> parse_c return {ranges, negated}; } +common_peg_ast_id common_peg_ast_arena::find_by_tag(const common_peg_ast_node & parent, const std::string & tag, int max_depth) const { + for (auto child_id : parent.children) { + const auto & child = get(child_id); + if (child.tag == tag) { + return child_id; + } + if (max_depth > 1) { + auto result = find_by_tag(child, tag, max_depth - 1); + if (result != COMMON_PEG_INVALID_AST_ID) { + return result; + } + } + } + return COMMON_PEG_INVALID_AST_ID; +} + +common_peg_ast_id common_peg_ast_arena::find_by_rule(const common_peg_ast_node & parent, const std::string & rule, int max_depth) const { + for (auto child_id : parent.children) { + const auto & child = get(child_id); + if (child.rule == rule) { + return child_id; + } + if (max_depth > 1) { + auto result = find_by_rule(child, rule, max_depth - 1); + if (result != COMMON_PEG_INVALID_AST_ID) { + return result; + } + } + } + return COMMON_PEG_INVALID_AST_ID; +} + void common_peg_ast_arena::visit(common_peg_ast_id id, const common_peg_ast_visitor & visitor) const { if (id == COMMON_PEG_INVALID_AST_ID) { return; diff --git a/common/peg-parser.h b/common/peg-parser.h index 31cdf9ec2d..f242fc4211 100644 --- a/common/peg-parser.h +++ b/common/peg-parser.h @@ -106,6 +106,9 @@ class common_peg_ast_arena { const common_peg_ast_node & get(common_peg_ast_id id) const { return nodes_.at(id); } + common_peg_ast_id find_by_tag(const common_peg_ast_node & parent, const std::string & tag, int max_depth = 3) const; + common_peg_ast_id find_by_rule(const common_peg_ast_node & parent, const std::string & tag, int max_depth = 3) const; + size_t size() const { return nodes_.size(); } void clear() { nodes_.clear(); } diff --git a/models/templates/google-gemma-4-31B-it-interleaved.jinja b/models/templates/google-gemma-4-31B-it-interleaved.jinja new file mode 100644 index 0000000000..422f6da2b3 --- /dev/null +++ b/models/templates/google-gemma-4-31B-it-interleaved.jinja @@ -0,0 +1,282 @@ +{%- macro format_parameters(properties, required) -%} + {%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%} + {%- set ns = namespace(found_first=false) -%} + {%- for key, value in properties | dictsort -%} + {%- set add_comma = false -%} + {%- if key not in standard_keys -%} + {%- if ns.found_first %},{% endif -%} + {%- set ns.found_first = true -%} + {{ key }}:{ + {%- if value['description'] -%} + description:<|"|>{{ value['description'] }}<|"|> + {%- set add_comma = true -%} + {%- endif -%} + {%- if value['nullable'] %} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + nullable:true + {%- endif -%} + {%- if value['type'] | upper == 'STRING' -%} + {%- if value['enum'] -%} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + enum:{{ format_argument(value['enum']) }} + {%- endif -%} + {%- elif value['type'] | upper == 'OBJECT' -%} + ,properties:{ + {%- if value['properties'] is defined and value['properties'] is mapping -%} + {{- format_parameters(value['properties'], value['required'] | default([])) -}} + {%- elif value is mapping -%} + {{- format_parameters(value, value['required'] | default([])) -}} + {%- endif -%} + } + {%- if value['required'] -%} + ,required:[ + {%- for item in value['required'] | default([]) -%} + <|"|>{{- item -}}<|"|> + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + ] + {%- endif -%} + {%- elif value['type'] | upper == 'ARRAY' -%} + {%- if value['items'] is mapping and value['items'] -%} + ,items:{ + {%- set ns_items = namespace(found_first=false) -%} + {%- for item_key, item_value in value['items'] | dictsort -%} + {%- if item_value is not none -%} + {%- if ns_items.found_first %},{% endif -%} + {%- set ns_items.found_first = true -%} + {%- if item_key == 'properties' -%} + properties:{ + {%- if item_value is mapping -%} + {{- format_parameters(item_value, value['items']['required'] | default([])) -}} + {%- endif -%} + } + {%- elif item_key == 'required' -%} + required:[ + {%- for req_item in item_value -%} + <|"|>{{- req_item -}}<|"|> + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + ] + {%- elif item_key == 'type' -%} + {%- if item_value is string -%} + type:{{ format_argument(item_value | upper) }} + {%- else -%} + type:{{ format_argument(item_value | map('upper') | list) }} + {%- endif -%} + {%- else -%} + {{ item_key }}:{{ format_argument(item_value) }} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + } + {%- endif -%} + {%- endif -%} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + type:<|"|>{{ value['type'] | upper }}<|"|>} + {%- endif -%} + {%- endfor -%} +{%- endmacro -%} +{%- macro format_function_declaration(tool_data) -%} + declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|> + {%- set params = tool_data['function']['parameters'] -%} + {%- if params -%} + ,parameters:{ + {%- if params['properties'] -%} + properties:{ {{- format_parameters(params['properties'], params['required']) -}} }, + {%- endif -%} + {%- if params['required'] -%} + required:[ + {%- for item in params['required'] -%} + <|"|>{{- item -}}<|"|> + {{- ',' if not loop.last -}} + {%- endfor -%} + ], + {%- endif -%} + {%- if params['type'] -%} + type:<|"|>{{- params['type'] | upper -}}<|"|>} + {%- endif -%} + {%- endif -%} + {%- if 'response' in tool_data['function'] -%} + {%- set response_declaration = tool_data['function']['response'] -%} + ,response:{ + {%- if response_declaration['description'] -%} + description:<|"|>{{- response_declaration['description'] -}}<|"|>, + {%- endif -%} + {%- if response_declaration['type'] | upper == 'OBJECT' -%} + type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>} + {%- endif -%} + {%- endif -%} + } +{%- endmacro -%} +{%- macro format_argument(argument, escape_keys=True) -%} + {%- if argument is string -%} + {{- '<|"|>' + argument + '<|"|>' -}} + {%- elif argument is boolean -%} + {{- 'true' if argument else 'false' -}} + {%- elif argument is mapping -%} + {{- '{' -}} + {%- set ns = namespace(found_first=false) -%} + {%- for key, value in argument | dictsort -%} + {%- if ns.found_first %},{% endif -%} + {%- set ns.found_first = true -%} + {%- if escape_keys -%} + {{- '<|"|>' + key + '<|"|>' -}} + {%- else -%} + {{- key -}} + {%- endif -%} + :{{- format_argument(value, escape_keys=escape_keys) -}} + {%- endfor -%} + {{- '}' -}} + {%- elif argument is sequence -%} + {{- '[' -}} + {%- for item in argument -%} + {{- format_argument(item, escape_keys=escape_keys) -}} + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + {{- ']' -}} + {%- else -%} + {{- argument -}} + {%- endif -%} +{%- endmacro -%} +{%- macro strip_thinking(text) -%} + {%- set ns = namespace(result='') -%} + {%- for part in text.split('') -%} + {%- if '<|channel>' in part -%} + {%- set ns.result = ns.result + part.split('<|channel>')[0] -%} + {%- else -%} + {%- set ns.result = ns.result + part -%} + {%- endif -%} + {%- endfor -%} + {{- ns.result | trim -}} +{%- endmacro -%} + +{%- set ns = namespace(prev_message_type=None, last_user_message=-1) -%} +{%- set loop_messages = messages -%} +{{ bos_token }} +{#- Handle System/Tool Definitions Block -#} +{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%} + {{- '<|turn>system\n' -}} + + {#- Inject Thinking token at the very top of the FIRST system turn -#} + {%- if enable_thinking is defined and enable_thinking -%} + {{- '<|think|>' -}} + {%- set ns.prev_message_type = 'think' -%} + {%- endif -%} + + {%- if messages[0]['role'] in ['system', 'developer'] -%} + {{- messages[0]['content'] | trim -}} + {%- set loop_messages = messages[1:] -%} + {%- endif -%} + + {%- if tools -%} + {%- for tool in tools %} + {{- '<|tool>' -}} + {{- format_function_declaration(tool) | trim -}} + {{- '' -}} + {%- endfor %} + {%- set ns.prev_message_type = 'tool' -%} + {%- endif -%} + + {{- '\n' -}} +{%- endif %} + +{#- Find last user message -#} +{%- for message in loop_messages -%} + {%- if message['role'] == 'user' -%} + {%- set ns.last_user_message = loop.index0 -%} + {%- endif -%} +{%- endfor -%} + +{#- Loop through messages -#} +{%- for message in loop_messages -%} + {%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%} + {%- if not (ns.prev_message_type == 'tool_response' and message['tool_calls']) -%} + {{- '<|turn>' + role + '\n' }} + {%- endif -%} + + {%- set ns.prev_message_type = None -%} + + {%- if message['tool_calls'] -%} + {#- Preserve reasoning between tool calls for model turns that come after the last user turn -#} + {%- if message['reasoning_content'] and loop.index0 > ns.last_user_message -%} + {{- '<|channel>thought\n' -}} + {{- message['reasoning_content'] -}} + {{- '' -}} + {%- endif -%} + {%- for tool_call in message['tool_calls'] -%} + {%- set function = tool_call['function'] -%} + {{- '<|tool_call>call:' + function['name'] + '{' -}} + {%- if function['arguments'] is mapping -%} + {%- set ns_args = namespace(found_first=false) -%} + {%- for key, value in function['arguments'] | dictsort -%} + {%- if ns_args.found_first %},{% endif -%} + {%- set ns_args.found_first = true -%} + {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + {%- endfor -%} + {%- elif function['arguments'] is string -%} + {{- function['arguments'] -}} + {%- endif -%} + {{- '}' -}} + {%- endfor -%} + {%- set ns.prev_message_type = 'tool_call' -%} + {%- endif -%} + + {%- if message['tool_responses'] -%} + {#- Tool Response handling -#} + {%- for tool_response in message['tool_responses'] -%} + {{- '<|tool_response>' -}} + {%- if tool_response['response'] is mapping -%} + {{- 'response:' + tool_response['name'] | default('unknown') + '{' -}} + {%- for key, value in tool_response['response'] | dictsort -%} + {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + {{- '}' -}} + {%- else -%} + {{- 'response:' + tool_response['name'] | default('unknown') + '{value:' + format_argument(tool_response['response'], escape_keys=False) + '}' -}} + {%- endif -%} + {{- '' -}} + {%- endfor -%} + {%- set ns.prev_message_type = 'tool_response' -%} + {%- endif -%} + + {%- if message['content'] is string -%} + {%- if role == 'model' -%} + {{- strip_thinking(message['content']) -}} + {%- else -%} + {{- message['content'] | trim -}} + {%- endif -%} + {%- elif message['content'] is sequence -%} + {%- for item in message['content'] -%} + {%- if item['type'] == 'text' -%} + {%- if role == 'model' -%} + {{- strip_thinking(item['text']) -}} + {%- else -%} + {{- item['text'] | trim -}} + {%- endif -%} + {%- elif item['type'] == 'image' -%} + {{- '\n\n<|image|>\n\n' -}} + {%- set ns.prev_message_type = 'image' -%} + {%- elif item['type'] == 'audio' -%} + {{- '<|audio|>' -}} + {%- set ns.prev_message_type = 'audio' -%} + {%- elif item['type'] == 'video' -%} + {{- '\n\n<|video|>\n\n' -}} + {%- set ns.prev_message_type = 'video' -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + + {%- if not (message['tool_responses'] and not message['content']) -%} + {{- '\n' -}} + {%- endif -%} +{%- endfor -%} + +{%- if add_generation_prompt -%} + {%- if ns.prev_message_type != 'tool_response' -%} + {{- '<|turn>model\n' -}} + {%- endif -%} + {%- if not enable_thinking | default(false) -%} + {{- '<|channel>thought\n' -}} + {%- endif -%} +{%- endif -%} diff --git a/models/templates/gemma4.jinja b/models/templates/google-gemma-4-31B-it.jinja similarity index 100% rename from models/templates/gemma4.jinja rename to models/templates/google-gemma-4-31B-it.jinja diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index cbd361b4b9..cb55b46b72 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -2551,6 +2551,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|end_of_text|>" || t.first == "" // smoldocling || t.first == "" // gemma4 + || t.first == "<|tool_response>" // gemma4 || t.first == "<|end▁of▁sentence|>" // deepseek-ocr ) { special_eog_ids.insert(t.second); diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index d42bc8a102..72deeeab3c 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -1976,10 +1976,24 @@ static void test_template_output_peg_parsers(bool detailed_debug) { { // Google Gemma 4 (tool calling with Gemma4 dict format) - auto tst = peg_tester("models/templates/gemma4.jinja"); + auto tst = peg_tester("models/templates/google-gemma-4-31B-it.jinja"); tst.test("Hello, world!").expect(simple_assist_msg("Hello, world!")).run(); + // Reasoning and content + tst.test( + "<|channel>thought\nI'm\nthinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .expect(message_assist_thoughts) + .run(); + + // Reasoning and content with reasoning_format = none + tst.test( + "<|channel>thought\nI'm\nthinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_NONE) + .expect_content("<|channel>thought\nI'm\nthinkingHello, world!\nWhat's up?") + .run(); + // Simple tool call with string argument tst.test( "<|tool_call>call:get_time{city:<|\"|>London<|\"|>}")