diff --git a/common/chat-diff-analyzer.cpp b/common/chat-diff-analyzer.cpp
index b49ce22ef4..81c8255a11 100644
--- a/common/chat-diff-analyzer.cpp
+++ b/common/chat-diff-analyzer.cpp
@@ -1,7 +1,9 @@
 #include "chat-diff-analyzer.h"
 
 #include "chat-auto-parser-helpers.h"
+#include "chat-peg-parser.h"
 #include "chat.h"
 #include "log.h"
 #include "nlohmann/json.hpp"
+#include "peg-parser.h"
 
@@ -261,7 +263,7 @@ void analyze_reasoning::compare_reasoning_presence() {
         // prefix: ...
         auto suf_seg = prune_whitespace_segments(segmentize_markers(diff.suffix));
         if (trim_whitespace(diff.left).empty() && suf_seg.size() >= 2 && suf_seg[0].type == segment_type::MARKER &&
-            trim_whitespace(suf_seg[1].value).substr(0, 11) == "I can help.") {
+            trim_whitespace(suf_seg[1].value).rfind("I can help.", 0) == 0) {
             auto pre_seg = prune_whitespace_segments(segmentize_markers(diff.prefix));
             if (pre_seg[pre_seg.size() - 1].type == segment_type::MARKER ||
                 (pre_seg.size() > 1 && trim_whitespace(pre_seg[pre_seg.size() - 1].value).empty() &&
@@ -492,19 +494,20 @@ analyze_content::analyze_content(const common_chat_template & tmpl, const analyz
 
     bool found_plain_content = false;
     if (trim_whitespace(diff_tools.left) == response) {
-        auto segments = segmentize_markers(diff_reasoning.left);
-        if (trim_whitespace(diff_reasoning.left) == response ||
-            (segments.size() == 2 && trim_whitespace(segments[0].value) == response)) {
+        // Match the expected response text itself (not the diff against itself), optionally followed by one marker
+        auto parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) {
+            return p.space() + p.literal(response) + p.space() + p.optional(p.marker()) + p.space() + p.end();
+        });
+        if (parser.parse_and_extract(diff_reasoning.left).result.success()) {
             // We only have the content text in the diff (possibly with a stray EOG marker), so no markers
-            mode                = content_mode::PLAIN;
+            mode = content_mode::PLAIN;
             found_plain_content = true;
         } else if (reasoning.mode != reasoning_mode::NONE && !reasoning.end.empty() &&
                    diff_reasoning.left.find(reasoning.end) != std::string::npos) {
             std::string post_closed_reasoning = diff_reasoning.left.substr(
                 diff_reasoning.left.find(reasoning.end) + reasoning.end.length());
             if (trim_whitespace(post_closed_reasoning) == "Response text") {
-                LOG_DBG("C1: No content markers after stripping reasoning close marker\n");
-                mode                = content_mode::PLAIN;
+                mode = content_mode::PLAIN;
                 found_plain_content = true;
             }
         }
diff --git a/common/peg-parser.cpp b/common/peg-parser.cpp
index f1b10b21a5..1545a24210 100644
--- a/common/peg-parser.cpp
+++ b/common/peg-parser.cpp
@@ -1424,6 +1424,13 @@ common_peg_parser common_peg_parser_builder::python_dict() {
     });
 }
 
+// Matches one marker: '<' ... '>' or '[' ... ']'
+common_peg_parser common_peg_parser_builder::marker() {
+    auto sharp_bracket_parser = literal("<") + until(">") + literal(">");
+    auto square_bracket_parser = literal("[") + until("]") + literal("]");
+    return choice({ sharp_bracket_parser, square_bracket_parser });
+}
+
 common_peg_parser common_peg_parser_builder::json_member(const std::string & key, const common_peg_parser & p) {
     auto ws = space();
     return sequence({
diff --git a/common/peg-parser.h b/common/peg-parser.h
index 947c775f10..b5e7ae13cf 100644
--- a/common/peg-parser.h
+++ b/common/peg-parser.h
@@ -456,6 +456,9 @@ class common_peg_parser_builder {
     common_peg_parser python_dict_bool();
     common_peg_parser python_dict_null();
 
+    // A marker, i.e. text delimited by a pair of <> or []
+    common_peg_parser marker();
+
     // Wraps a parser with JSON schema metadata for grammar generation.
     // Used internally to convert JSON schemas to GBNF grammar rules.
     common_peg_parser schema(const common_peg_parser & p, const std::string & name, const nlohmann::ordered_json & schema, bool raw = false);
diff --git a/tests/peg-parser/test-basic.cpp b/tests/peg-parser/test-basic.cpp
index 1bda6f2e69..872f16a78d 100644
--- a/tests/peg-parser/test-basic.cpp
+++ b/tests/peg-parser/test-basic.cpp
@@ -1,3 +1,4 @@
+#include "peg-parser.h"
 #include "tests.h"
 
 void test_basic(testing & t) {
@@ -450,5 +451,21 @@ void test_basic(testing & t) {
 
             t.assert_equal("result_is_fail", true, result.fail());
         });
+
+        // Test markers
+        t.test("marker", [](testing &t) {
+            auto bracket_parser = build_peg_parser([](common_peg_parser_builder & p) {
+                return p.marker();
+            });
+
+            common_peg_parse_context ctx_square("[marker]", false);
+            common_peg_parse_context ctx_sharp("<marker>", false);
+
+            auto result_square = bracket_parser.parse(ctx_square);
+            auto result_sharp = bracket_parser.parse(ctx_sharp);
+
+            t.assert_true("result_square_is_success", result_square.success());
+            t.assert_true("result_sharp_is_success", result_sharp.success());
+        });
     });
 }