diff --git a/common/chat.cpp b/common/chat.cpp index eae46df7d3..e27b6c3413 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1091,6 +1091,14 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ common_chat_params data; data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + + if (inputs.add_generation_prompt && string_ends_with(data.prompt, "\n")) { + // This may happen if the model generates content + tool_call, the + // template does not add the model's next turn and confuses the model + // from emitting its proper reasoning token sequence. + data.prompt += "<|turn>model\n"; + } + data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4; data.supports_thinking = true; data.thinking_start_tag = "<|channel>thought"; @@ -1118,7 +1126,8 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ p.rule("thought", p.content(p.literal("<|channel>thought") + p.space() + p.until("") + p.literal(""))); } - auto thought = (p.peek(p.literal("<|channel>")) + p.ref("thought")) | p.negate(p.literal("<|channel>")); + auto consume_empty_channels = p.gbnf(p.zero_or_more(p.literal("<|channel>") + p.negate(p.literal("thought"))), ""); + auto thought = (p.peek(p.literal("<|channel>")) + consume_empty_channels + p.ref("thought")) | p.negate(p.literal("<|channel>")); if (has_response_format) { auto response_format = p.literal("```json") << @@ -1182,12 +1191,16 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ /* max = */ inputs.parallel_tool_calls ? -1 : 1 )); - auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<|tool_call>"}))); + auto scan_to_toolcall = p.rule("scan-to-toolcall", p.until("<|tool_call>")); + auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "", "<|tool_call>"}))); auto message = p.rule("message", thought + content); - return start + p.zero_or_more(message) + tool_call; + return start + p.zero_or_more(message) + scan_to_toolcall + tool_call; } - auto content = p.rule("content", p.content(p.until("<|channel>"))); + // Gemma 4 may emit an extra <|channel>thought\n at the end of the content. It may + // also emit a single trailing token. Consume all complete reasoning blocks and + // then stop at the first unmatched token. + auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", ""}))); auto message = p.rule("message", thought + content); return start + p.one_or_more(message); }); diff --git a/common/peg-parser.cpp b/common/peg-parser.cpp index 59fa4c5c55..e37c1ce80e 100644 --- a/common/peg-parser.cpp +++ b/common/peg-parser.cpp @@ -890,6 +890,10 @@ struct parser_executor { } return result; } + + common_peg_parse_result operator()(const common_peg_gbnf_parser & p) { + return arena.parse(p.child, ctx, start_pos); + } }; common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const { @@ -957,7 +961,8 @@ void common_peg_arena::resolve_refs() { std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v) { + std::is_same_v || + std::is_same_v) { p.child = resolve_ref(p.child); } else if constexpr (std::is_same_v) { p.child = resolve_ref(p.child); @@ -1036,6 +1041,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id return "Not(" + dump_impl(p.child, visited) + ")"; } else if constexpr (std::is_same_v) { return "Atomic(" + dump_impl(p.child, visited) + ")"; + } else if constexpr (std::is_same_v) { + return "Gbnf(" + p.grammar + ", " + dump_impl(p.child, visited) + ")"; } else if constexpr (std::is_same_v) { return "Any"; } else if constexpr (std::is_same_v) { @@ -1565,6 +1572,7 @@ static std::unordered_set collect_reachable_rules( std::is_same_v || std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { visit(p.child); } else if constexpr (std::is_same_v) { @@ -1651,10 +1659,13 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo } else if constexpr (std::is_same_v) { std::string s; for (const auto & child : p.children) { + auto child_gbnf = to_gbnf(child); + if (child_gbnf.empty()) { + continue; + } if (!s.empty()) { s += " "; } - auto child_gbnf = to_gbnf(child); const auto & child_parser = effective_parser(child); if (std::holds_alternative(child_parser) || std::holds_alternative(child_parser)) { @@ -1754,6 +1765,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo return to_gbnf(p.child); } else if constexpr (std::is_same_v) { return to_gbnf(p.child); + } else if constexpr (std::is_same_v) { + return p.grammar; } else { static_assert(is_always_false_v); } @@ -1888,6 +1901,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant & {"child", p.child}, {"tag", p.tag} }; + } else if constexpr (std::is_same_v) { + return json{{"type", "gbnf"}, {"child", p.child}, {"grammar", p.grammar}}; } }, variant); } @@ -2050,6 +2065,16 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json }; } + if (type == "gbnf") { + if (!j.contains("child") || !j.contains("grammar")) { + throw std::runtime_error("gbnf parser missing required fields"); + } + return common_peg_gbnf_parser{ + j["child"].get(), + j["grammar"].get(), + }; + } + throw std::runtime_error("Unknown parser type: " + type); } diff --git a/common/peg-parser.h b/common/peg-parser.h index f242fc4211..b6bb05214b 100644 --- a/common/peg-parser.h +++ b/common/peg-parser.h @@ -270,6 +270,11 @@ struct common_peg_tag_parser { std::string tag; }; +struct common_peg_gbnf_parser { + common_peg_parser_id child; + std::string grammar; +}; + // Variant holding all parser types using common_peg_parser_variant = std::variant< common_peg_epsilon_parser, @@ -290,7 +295,8 @@ using common_peg_parser_variant = std::variant< common_peg_rule_parser, common_peg_ref_parser, common_peg_atomic_parser, - common_peg_tag_parser + common_peg_tag_parser, + common_peg_gbnf_parser >; class common_peg_arena { @@ -504,6 +510,10 @@ class common_peg_parser_builder { // Unlike rules, you can tag multiple nodes with the same tag. common_peg_parser tag(const std::string & tag, const common_peg_parser & p) { return add(common_peg_tag_parser{p.id(), tag}); } + // Wraps a child parser but emits a custom GBNF grammar string instead of + // the child's grammar. Parsing delegates entirely to the child. + common_peg_parser gbnf(const common_peg_parser & p, const std::string & grammar) { return add(common_peg_gbnf_parser{p, grammar}); } + void set_root(const common_peg_parser & p); common_peg_arena build(); diff --git a/tests/peg-parser/test-gbnf-generation.cpp b/tests/peg-parser/test-gbnf-generation.cpp index 1ab9a7ede3..fe4bbbdd16 100644 --- a/tests/peg-parser/test-gbnf-generation.cpp +++ b/tests/peg-parser/test-gbnf-generation.cpp @@ -258,6 +258,66 @@ void test_gbnf_generation(testing &t) { )""", gbnf); }); + t.test("silent parser emits nothing in gbnf", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello") + p.gbnf(p.literal("world"), ""); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "hello" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("silent choice inside sequence emits nothing", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("a") + p.gbnf(p.literal("b") | p.literal("c"), "") + p.literal("d"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "a" "d" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("silent wrapped in tag emits nothing", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("a") + p.tag("t", p.gbnf(p.literal("b"), "")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "a" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("gbnf parser emits custom grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("a") + p.gbnf(p.literal("b"), "[a-z]+"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "a" [a-z]+ + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + t.test("nested transparent wrappers get parenthesized", [](testing &t) { auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("x") + p.tag("outer", p.atomic(p.literal("a") | p.literal("b"))); diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 74e71ea53e..8438a5eaff 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -2118,6 +2118,31 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .tools({ amount_tool }) .expect(message_with_tool_calls("amount", R"({"orig": 1.5e10})")) .run(); + + // Edge cases + tst.test( + "<|channel>thought\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .expect(message_assist) + .run(); + + tst.test( + "<|channel>thought\nHello, world!\nWhat's up?<|channel>thought\n") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .expect(message_assist) + .run(); + + tst.test( + "<|channel>thought\nHello, world!\nWhat's up?<|channel>thought\n") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .expect(message_assist) + .run(); + + tst.test( + "<|channel><|channel>thought\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .expect(message_assist) + .run(); } {