diff --git a/common/chat.cpp b/common/chat.cpp
index b98ab21ce1..22e527bab8 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -2065,7 +2065,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
// Trigger on tool calls that appear in the commentary channel
data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
- "<\\|channel\\|>(commentary|analysis) to"
+ "<\\|channel\\|>(?:commentary|analysis) to"
});
// Trigger tool calls that appear in the role section, either at the
@@ -2398,17 +2398,17 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
(inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call));
// Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
data.grammar_triggers.push_back({
- COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
+ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
// If thinking_forced_open, then we capture the tag in the grammar,
// (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
- std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)" : "(?:[\\s\\S]*?\\s*)?") + (
+ std::string(data.thinking_forced_open ? "(\\s*)" : "") + (
"\\s*("
"(?:"
"||||)?"
"\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\""
")"
- ")[\\s\\S]*"
+ ")"
),
});
data.preserved_tokens = {
diff --git a/common/regex-partial.cpp b/common/regex-partial.cpp
index 4bff6b6633..e667a209e9 100644
--- a/common/regex-partial.cpp
+++ b/common/regex-partial.cpp
@@ -27,7 +27,7 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b
return res;
}
std::match_results srmatch;
- if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) {
+ if (std::regex_search(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial, std::regex_constants::match_continuous)) {
auto group = srmatch[1].str();
if (group.length() != 0) {
auto it = srmatch[1].second.base();
@@ -55,18 +55,18 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b
to see if a string ends with a partial regex match, but but it's not in std::regex yet.
Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.
- - /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).*
- - /a|b/ -> (a|b).*
+ - /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:(?:d)?c)?b)?a)
+ - /a|b/ -> ^(a|b)
- /a*?/ -> error, could match ""
- - /a*b/ -> ((?:b)?a*+).* (final repetitions become eager)
- - /.*?ab/ -> ((?:b)?a).* (merge .*)
- - /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches)
- - /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).*
- - /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).*
- - /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).*
+ - /a*b/ -> ^((?:b)?a*+) (final repetitions become eager)
+ - /.*?ab/ -> ^((?:b)?a) (omit .*)
+ - /a.*?b/ -> ^((?:b)?.*?a) (keep reluctant matches)
+ - /a(bc)d/ -> ^((?:(?:d)?(?:(?:c)?b))?a)
+ - /a(bc|de)/ -> ^((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a)
+ - /ab{2,4}c/ -> ^cbbb?b?a -> ^((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a)
- The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern
- (i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored)
+ The regex is anchored at the start of the reversed string (i.e. at the end of the original input) instead of matching it fully, and the end of the first (and only) capturing group will indicate the reversed start of the original partial pattern.
+ All other groups are turned into non-capturing groups, and reluctant quantifiers are ignored.
*/
std::string regex_to_reversed_partial_regex(const std::string & pattern) {
auto it = pattern.begin();
@@ -177,7 +177,7 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) {
}
}
- // /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
+ // /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:d)?c)?b)?a)
// if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
// We'll do the outermost capturing group and final .* in the enclosing function.
std::vector res_alts;
@@ -200,5 +200,5 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) {
throw std::runtime_error("Unmatched '(' in pattern");
}
- return "(" + res + ")[\\s\\S]*";
+ return "^(" + res + ")";
}
diff --git a/common/sampling.cpp b/common/sampling.cpp
index c66f935c65..68e36e8744 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -179,24 +179,30 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
#endif // LLAMA_USE_LLGUIDANCE
} else {
std::vector trigger_patterns;
- std::vector patterns_anywhere;
std::vector trigger_tokens;
for (const auto & trigger : params.grammar_triggers) {
switch (trigger.type) {
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
{
const auto & word = trigger.value;
- patterns_anywhere.push_back(regex_escape(word));
+ trigger_patterns.push_back(regex_escape(word));
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
{
- patterns_anywhere.push_back(trigger.value);
+ trigger_patterns.push_back(trigger.value);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
{
- trigger_patterns.push_back(trigger.value);
+ const auto & pattern = trigger.value;
+ std::string anchored = "^$";
+ if (!pattern.empty()) {
+ anchored = (pattern.front() != '^' ? "^" : "")
+ + pattern
+ + (pattern.back() != '$' ? "$" : "");
+ }
+ trigger_patterns.push_back(anchored);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
@@ -210,10 +216,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
}
}
- if (!patterns_anywhere.empty()) {
- trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
- }
-
std::vector trigger_patterns_c;
trigger_patterns_c.reserve(trigger_patterns.size());
for (const auto & regex : trigger_patterns) {
diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
index 75d5d750c3..64ea2fd00a 100644
--- a/src/llama-grammar.cpp
+++ b/src/llama-grammar.cpp
@@ -369,6 +369,44 @@ static void print_rule(
fprintf(file, "\n");
}
+//
+// Regex utilities
+//
+
+size_t llama_grammar_trigger_pattern::find(const std::string & input) const {
+ auto find_start_pos = [](const std::smatch & match) {
+ // get from the first non-empty capturing group (or, if none matched, from the start of the whole match) to the end of the string
+ size_t start = std::string::npos;
+ for (auto i = 1u; i < match.size(); i++) {
+ if (match.length(i) > 0) {
+ start = match.position(i);
+ break;
+ }
+ }
+ if (start == std::string::npos) {
+ start = match.position(0);
+ }
+ return start;
+ };
+
+ if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') {
+ // pattern is anchored on both ends (^...$): try matching the entire input first, falling through to the search below on failure
+ std::smatch match;
+ if (std::regex_match(input, match, regex)) {
+ return find_start_pos(match);
+ }
+ }
+
+ // search anywhere
+ std::smatch match;
+ if (std::regex_search(input, match, regex)) {
+ return find_start_pos(match);
+ }
+
+ return std::string::npos;
+}
+
+
//
// implementation
//
@@ -1312,21 +1350,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
grammar.trigger_buffer += piece;
- std::smatch match;
for (const auto & trigger_pattern : grammar.trigger_patterns) {
- if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
+ auto start = trigger_pattern.find(grammar.trigger_buffer);
+ if (start != std::string::npos) {
grammar.awaiting_trigger = false;
- // get from the first matched capturing group to the end of the string
- size_t start = std::string::npos;
- for (auto i = 1u; i < match.size(); i++) {
- if (match.length(i) > 0) {
- start = match.position(i);
- break;
- }
- }
- if (start == std::string::npos) {
- start = match.position(0);
- }
// replay tokens that overlap with [start, end)
for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
diff --git a/src/llama-grammar.h b/src/llama-grammar.h
index a4c978ac11..b5a0e588e9 100644
--- a/src/llama-grammar.h
+++ b/src/llama-grammar.h
@@ -119,6 +119,8 @@ struct llama_grammar_parser {
struct llama_grammar_trigger_pattern {
std::string pattern;
std::regex regex;
+
+ size_t find(const std::string & input) const;
};
struct llama_grammar {
diff --git a/tests/test-regex-partial.cpp b/tests/test-regex-partial.cpp
index ffad189786..70af6d75a1 100644
--- a/tests/test-regex-partial.cpp
+++ b/tests/test-regex-partial.cpp
@@ -232,52 +232,52 @@ static void test_regex_to_reversed_partial_regex() {
printf("[%s]\n", __func__);
assert_equals(
- "((?:(?:c)?b)?a)[\\s\\S]*",
+ "^((?:(?:c)?b)?a)",
regex_to_reversed_partial_regex("abc"));
assert_equals(
- "(a+)[\\s\\S]*",
+ "^(a+)",
regex_to_reversed_partial_regex("a+"));
assert_equals(
- "(a*)[\\s\\S]*",
+ "^(a*)",
regex_to_reversed_partial_regex("a*"));
assert_equals(
- "(a?)[\\s\\S]*",
+ "^(a?)",
regex_to_reversed_partial_regex("a?"));
assert_equals(
- "([a-z])[\\s\\S]*",
+ "^([a-z])",
regex_to_reversed_partial_regex("[a-z]"));
assert_equals(
- "((?:\\w+)?[a-z])[\\s\\S]*",
+ "^((?:\\w+)?[a-z])",
regex_to_reversed_partial_regex("[a-z]\\w+"));
assert_equals(
- "((?:a|b))[\\s\\S]*",
+ "^((?:a|b))",
regex_to_reversed_partial_regex("(?:a|b)"));
assert_equals(
- "((?:(?:(?:d)?c)?b)?a)[\\s\\S]*",
+ "^((?:(?:(?:d)?c)?b)?a)",
regex_to_reversed_partial_regex("abcd"));
assert_equals(
- "((?:b)?a*)[\\s\\S]*", // TODO: ((?:b)?a*+).* ??
+ "^((?:b)?a*)", // TODO: ((?:b)?a*+).* ??
regex_to_reversed_partial_regex("a*b"));
assert_equals(
- "((?:(?:b)?a)?.*)[\\s\\S]*",
+ "^((?:(?:b)?a)?.*)",
regex_to_reversed_partial_regex(".*?ab"));
assert_equals(
- "((?:(?:b)?.*)?a)[\\s\\S]*",
+ "^((?:(?:b)?.*)?a)",
regex_to_reversed_partial_regex("a.*?b"));
assert_equals(
- "((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*",
+ "^((?:(?:d)?(?:(?:c)?b))?a)",
regex_to_reversed_partial_regex("a(bc)d"));
assert_equals(
- "((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*",
+ "^((?:(?:(?:c)?b|(?:e)?d))?a)",
regex_to_reversed_partial_regex("a(bc|de)"));
assert_equals(
- "((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*",
+ "^((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)",
regex_to_reversed_partial_regex("ab{2,4}c"));
}