diff --git a/common/chat.cpp b/common/chat.cpp index 5ccb238f20..df13e0db09 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -980,15 +980,19 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp auto channel = p.literal("<|channel|>") + (p.literal("commentary") | p.literal("analysis")); auto constrain_type = p.chars("[A-Za-z0-9_-]", 1, -1); + // Occasionally, gpt-oss-20b will prefix channels with this commentary + auto stray_commentary = p.optional(p.literal("<|channel|>commentary") + p.optional(p.literal(" to=assistant"))); + auto start_analysis = stray_commentary + p.literal("<|channel|>analysis<|message|>"); + if (extract_reasoning) { - p.rule("analysis", p.literal("<|channel|>analysis<|message|>") + p.reasoning(content) + end); + p.rule("analysis", start_analysis + p.reasoning(content) + end); } else { - p.rule("analysis", p.content(p.literal("<|channel|>analysis<|message|>") + content + end)); + p.rule("analysis", p.content(start_analysis + content + end)); } auto analysis = p.ref("analysis"); auto preamble = p.rule("preamble", p.literal("<|channel|>commentary<|message|>") + p.content(content) + end); - auto final_msg = p.rule("final", p.literal("<|channel|>final<|message|>") + p.content(content)); + auto final_msg = p.rule("final", stray_commentary + p.literal("<|channel|>final<|message|>") + p.content(content)); // Consume any unsolicited tool calls, e.g. builtin functions auto unsolicited = p.rule("unsolicited", p.atomic(p.optional(channel) + p.literal(" to=") + content + end)); @@ -996,7 +1000,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp auto any = p.rule("any", preamble | analysis); if (has_response_format) { - auto constraint = p.optional(p.space() + p.literal("<|constrain|>") + constrain_type); + auto constraint = p.optional(p.space() + p.optional(p.literal("<|constrain|>")) + constrain_type); auto response_format = p.rule("response-format", p.literal("<|channel|>final") + constraint + p.literal("<|message|>") + p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema))); @@ -1013,7 +1017,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp const auto & params = function.at("parameters"); auto func_name = p.literal(" to=functions.") + p.tool_name(p.literal(name)); - auto constraint = p.optional(p.space() + p.literal("<|constrain|>") + constrain_type); + auto constraint = p.optional(p.space() + p.optional(p.literal("<|constrain|>")) + constrain_type); auto args = p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", params)); // recipient in role header @@ -1054,6 +1058,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp data.grammar_triggers = { { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "^\\s+to$" }, + { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "^<\\|channel\\|>(?:commentary|analysis)\\s+to=functions$" }, { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "<\\|start\\|>assistant(\\s+to)" }, { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "<\\|start\\|>assistant(<\\|channel\\|>(?:commentary|analysis)\\s+to)" } }; diff --git a/common/peg-parser.cpp b/common/peg-parser.cpp index a6d9a4c27c..694f9b850a 100644 --- a/common/peg-parser.cpp +++ b/common/peg-parser.cpp @@ -1557,6 +1557,36 @@ static std::unordered_set collect_reachable_rules( // GBNF generation implementation void common_peg_arena::build_grammar(const common_grammar_builder & builder, bool lazy) const { + auto schema_delegates = [](const common_peg_schema_parser & s) -> bool { + if (!s.schema) { + return true; + } + if (s.raw && s.schema->contains("type") && s.schema->at("type").is_string() && s.schema->at("type") == "string") { + return true; + } + return false; + }; + + // Unwrap the parser so we can properly check if it's a sequence or choice + auto effective_parser = [&](common_peg_parser_id id) -> const common_peg_parser_variant & { + while (true) { + const auto & p = parsers_.at(id); + if (const auto * tag = std::get_if(&p)) { + id = tag->child; + } else if (const auto * atomic = std::get_if(&p)) { + id = atomic->child; + } else if (const auto * schema = std::get_if(&p)) { + if (schema_delegates(*schema)) { + id = schema->child; + } else { + return p; + } + } else { + return p; + } + } + }; + // Generate GBNF for a parser std::function to_gbnf = [&](common_peg_parser_id id) -> std::string { const auto & parser = parsers_.at(id); @@ -1577,7 +1607,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo s += " "; } auto child_gbnf = to_gbnf(child); - const auto & child_parser = parsers_.at(child); + const auto & child_parser = effective_parser(child); if (std::holds_alternative(child_parser) || std::holds_alternative(child_parser)) { s += "(" + child_gbnf + ")"; @@ -1593,7 +1623,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo s += " | "; } auto child_gbnf = to_gbnf(child); - const auto & child_parser = parsers_.at(child); + const auto & child_parser = effective_parser(child); if (std::holds_alternative(child_parser)) { s += "(" + child_gbnf + ")"; } else { @@ -1603,7 +1633,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo return s; } else if constexpr (std::is_same_v) { auto child_gbnf = to_gbnf(p.child); - const auto & child_parser = parsers_.at(p.child); + const auto & child_parser = effective_parser(p.child); if (std::holds_alternative(child_parser) || std::holds_alternative(child_parser)) { child_gbnf = "(" + child_gbnf + ")"; @@ -1663,15 +1693,10 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo } return gbnf_excluding_pattern(p.delimiters); } else if constexpr (std::is_same_v) { - if (p.schema) { - if (p.raw && p.schema->contains("type") && p.schema->at("type").is_string() && p.schema->at("type") == "string") { - // TODO: Implement more comprehensive grammar generation for raw strings. - // For now, use the grammar emitted from the underlying parser. - return to_gbnf(p.child); - } - return builder.add_schema(p.name, *p.schema); + if (schema_delegates(p)) { + return to_gbnf(p.child); } - return to_gbnf(p.child); + return builder.add_schema(p.name, *p.schema); } else if constexpr (std::is_same_v) { return p.name; } else if constexpr (std::is_same_v) { diff --git a/tests/peg-parser/test-gbnf-generation.cpp b/tests/peg-parser/test-gbnf-generation.cpp index 68857a5e88..1ab9a7ede3 100644 --- a/tests/peg-parser/test-gbnf-generation.cpp +++ b/tests/peg-parser/test-gbnf-generation.cpp @@ -213,6 +213,66 @@ void test_gbnf_generation(testing &t) { )""", gbnf); }); + t.test("tagged choice inside sequence gets parenthesized", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("a") + p.tag("t", p.literal("b") | p.literal("c")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "a" ("b" | "c") + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("tagged sequence inside choice gets parenthesized", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.tag("t", p.literal("a") + p.literal("b")) | p.literal("c"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "a" "b" | "c" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("atomic choice inside repetition gets parenthesized", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.one_or_more(p.atomic(p.literal("a") | p.literal("b"))); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= ("a" | "b")+ + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("nested transparent wrappers get parenthesized", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("x") + p.tag("outer", p.atomic(p.literal("a") | p.literal("b"))); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "x" ("a" | "b") + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + t.test("emit only trigger rules (and references)", [](testing &t) { auto parser = build_peg_parser([](common_peg_parser_builder & p) { auto rule1 = p.rule("rule-1", p.literal("a") + p.ref("rule-2")); diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 5a90590b77..a716c35eb4 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -3175,6 +3175,24 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect_reasoning("I will execute python to say hello") .expect_content("") .run(); + + // Edge cases + + // "<|channel|>commentary to=assistant" before reasoning + tst.test( + "<|channel|>commentary to=assistant<|channel|>analysis<|message|>I'm\nthinking<|end|><|start|>assistant<|channel|>final<|message|>Hello, world!\nWhat's " + "up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .expect(message_assist_thoughts) + .run(); + + // "<|channel|>commentary to=assistant" before final message + tst.test( + "<|channel|>analysis<|message|>I'm\nthinking<|end|><|start|>assistant<|channel|>commentary to=assistant<|channel|>final<|message|>Hello, world!\nWhat's " + "up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .expect(message_assist_thoughts) + .run(); } {