diff --git a/common/chat.cpp b/common/chat.cpp index b98ab21ce1..22e527bab8 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -2065,7 +2065,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp // Trigger on tool calls that appear in the commentary channel data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, - "<\\|channel\\|>(commentary|analysis) to" + "<\\|channel\\|>(?:commentary|analysis) to" }); // Trigger tool calls that appear in the role section, either at the @@ -2398,17 +2398,17 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call)); // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives) data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, // If thinking_forced_open, then we capture the tag in the grammar, // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) - std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)" : "(?:[\\s\\S]*?\\s*)?") + ( + std::string(data.thinking_forced_open ? "(\\s*)" : "") + ( "\\s*(" "(?:" "||||)?" "\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\"" ")" - ")[\\s\\S]*" + ")" ), }); data.preserved_tokens = { diff --git a/common/regex-partial.cpp b/common/regex-partial.cpp index 4bff6b6633..e667a209e9 100644 --- a/common/regex-partial.cpp +++ b/common/regex-partial.cpp @@ -27,7 +27,7 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b return res; } std::match_results srmatch; - if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) { + if (std::regex_search(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial, std::regex_constants::match_continuous)) { auto group = srmatch[1].str(); if (group.length() != 0) { auto it = srmatch[1].second.base(); @@ -55,18 +55,18 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b to see if a string ends with a partial regex match, but but it's not in std::regex yet. Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input. - - /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).* - - /a|b/ -> (a|b).* + - /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:(?:d)?c)?b)?a) + - /a|b/ -> ^(a|b) - /a*?/ -> error, could match "" - - /a*b/ -> ((?:b)?a*+).* (final repetitions become eager) - - /.*?ab/ -> ((?:b)?a).* (merge .*) - - /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches) - - /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).* - - /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).* - - /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).* + - /a*b/ -> ^((?:b)?a*+) (final repetitions become eager) + - /.*?ab/ -> ^((?:b)?a) (omit .*) + - /a.*?b/ -> ^((?:b)?.*?a) (keep reluctant matches) + - /a(bc)d/ -> ^((?:(?:d)?(?:(?:c)?b))?a) + - /a(bc|de)/ -> ^((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a) + - /ab{2,4}c/ -> ^cbbb?b?a -> ^((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a) - The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern - (i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored) + The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern. + All other groups are turned into non-capturing groups, and reluctant quantifiers are ignored. */ std::string regex_to_reversed_partial_regex(const std::string & pattern) { auto it = pattern.begin(); @@ -177,7 +177,7 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) { } } - // /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).* + // /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:d)?c)?b)?a) // if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group // We'll do the outermost capturing group and final .* in the enclosing function. std::vector res_alts; @@ -200,5 +200,5 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) { throw std::runtime_error("Unmatched '(' in pattern"); } - return "(" + res + ")[\\s\\S]*"; + return "^(" + res + ")"; } diff --git a/common/sampling.cpp b/common/sampling.cpp index c66f935c65..68e36e8744 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -179,24 +179,30 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co #endif // LLAMA_USE_LLGUIDANCE } else { std::vector trigger_patterns; - std::vector patterns_anywhere; std::vector trigger_tokens; for (const auto & trigger : params.grammar_triggers) { switch (trigger.type) { case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: { const auto & word = trigger.value; - patterns_anywhere.push_back(regex_escape(word)); + trigger_patterns.push_back(regex_escape(word)); break; } case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: { - patterns_anywhere.push_back(trigger.value); + trigger_patterns.push_back(trigger.value); break; } case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL: { - trigger_patterns.push_back(trigger.value); + const auto & pattern = trigger.value; + std::string anchored = "^$"; + if (!pattern.empty()) { + anchored = (pattern.front() != '^' ? "^" : "") + + pattern + + (pattern.back() != '$' ? "$" : ""); + } + trigger_patterns.push_back(anchored); break; } case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN: @@ -210,10 +216,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co } } - if (!patterns_anywhere.empty()) { - trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*"); - } - std::vector trigger_patterns_c; trigger_patterns_c.reserve(trigger_patterns.size()); for (const auto & regex : trigger_patterns) { diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 75d5d750c3..64ea2fd00a 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -369,6 +369,44 @@ static void print_rule( fprintf(file, "\n"); } +// +// Regex utilities +// + +size_t llama_grammar_trigger_pattern::find(const std::string & input) const { + auto find_start_pos = [](const std::smatch & match) { + // get from the first matched capturing group to the end of the string + size_t start = std::string::npos; + for (auto i = 1u; i < match.size(); i++) { + if (match.length(i) > 0) { + start = match.position(i); + break; + } + } + if (start == std::string::npos) { + start = match.position(0); + } + return start; + }; + + if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') { + // match against the entire input + std::smatch match; + if (std::regex_match(input, match, regex)) { + return find_start_pos(match); + } + } + + // search anywhere + std::smatch match; + if (std::regex_search(input, match, regex)) { + return find_start_pos(match); + } + + return std::string::npos; +} + + // // implementation // @@ -1312,21 +1350,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token grammar.trigger_buffer_positions.push_back(std::make_pair(token, position)); grammar.trigger_buffer += piece; - std::smatch match; for (const auto & trigger_pattern : grammar.trigger_patterns) { - if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) { + auto start = trigger_pattern.find(grammar.trigger_buffer); + if (start != std::string::npos) { grammar.awaiting_trigger = false; - // get from the first matched capturing group to the end of the string - size_t start = std::string::npos; - for (auto i = 1u; i < match.size(); i++) { - if (match.length(i) > 0) { - start = match.position(i); - break; - } - } - if (start == std::string::npos) { - start = match.position(0); - } // replay tokens that overlap with [start, end) for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) { diff --git a/src/llama-grammar.h b/src/llama-grammar.h index a4c978ac11..b5a0e588e9 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -119,6 +119,8 @@ struct llama_grammar_parser { struct llama_grammar_trigger_pattern { std::string pattern; std::regex regex; + + size_t find(const std::string & input) const; }; struct llama_grammar { diff --git a/tests/test-regex-partial.cpp b/tests/test-regex-partial.cpp index ffad189786..70af6d75a1 100644 --- a/tests/test-regex-partial.cpp +++ b/tests/test-regex-partial.cpp @@ -232,52 +232,52 @@ static void test_regex_to_reversed_partial_regex() { printf("[%s]\n", __func__); assert_equals( - "((?:(?:c)?b)?a)[\\s\\S]*", + "^((?:(?:c)?b)?a)", regex_to_reversed_partial_regex("abc")); assert_equals( - "(a+)[\\s\\S]*", + "^(a+)", regex_to_reversed_partial_regex("a+")); assert_equals( - "(a*)[\\s\\S]*", + "^(a*)", regex_to_reversed_partial_regex("a*")); assert_equals( - "(a?)[\\s\\S]*", + "^(a?)", regex_to_reversed_partial_regex("a?")); assert_equals( - "([a-z])[\\s\\S]*", + "^([a-z])", regex_to_reversed_partial_regex("[a-z]")); assert_equals( - "((?:\\w+)?[a-z])[\\s\\S]*", + "^((?:\\w+)?[a-z])", regex_to_reversed_partial_regex("[a-z]\\w+")); assert_equals( - "((?:a|b))[\\s\\S]*", + "^((?:a|b))", regex_to_reversed_partial_regex("(?:a|b)")); assert_equals( - "((?:(?:(?:d)?c)?b)?a)[\\s\\S]*", + "^((?:(?:(?:d)?c)?b)?a)", regex_to_reversed_partial_regex("abcd")); assert_equals( - "((?:b)?a*)[\\s\\S]*", // TODO: ((?:b)?a*+).* ?? + "^((?:b)?a*)", // TODO: ((?:b)?a*+).* ?? regex_to_reversed_partial_regex("a*b")); assert_equals( - "((?:(?:b)?a)?.*)[\\s\\S]*", + "^((?:(?:b)?a)?.*)", regex_to_reversed_partial_regex(".*?ab")); assert_equals( - "((?:(?:b)?.*)?a)[\\s\\S]*", + "^((?:(?:b)?.*)?a)", regex_to_reversed_partial_regex("a.*?b")); assert_equals( - "((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*", + "^((?:(?:d)?(?:(?:c)?b))?a)", regex_to_reversed_partial_regex("a(bc)d")); assert_equals( - "((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*", + "^((?:(?:(?:c)?b|(?:e)?d))?a)", regex_to_reversed_partial_regex("a(bc|de)")); assert_equals( - "((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*", + "^((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)", regex_to_reversed_partial_regex("ab{2,4}c")); }