grammar : add support for std::regex_search() with trigger patterns

This commit is contained in:
Alde Rojas 2025-12-24 01:16:49 -06:00
parent 5ee4e43f26
commit 8da07610f8
3 changed files with 69 additions and 24 deletions

View File

@ -179,24 +179,30 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
#endif // LLAMA_USE_LLGUIDANCE
} else {
std::vector<std::string> trigger_patterns;
std::vector<std::string> patterns_anywhere;
std::vector<llama_token> trigger_tokens;
for (const auto & trigger : params.grammar_triggers) {
switch (trigger.type) {
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
{
const auto & word = trigger.value;
patterns_anywhere.push_back(regex_escape(word));
trigger_patterns.push_back(regex_escape(word));
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
{
patterns_anywhere.push_back(trigger.value);
trigger_patterns.push_back(trigger.value);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
{
trigger_patterns.push_back(trigger.value);
const auto & pattern = trigger.value;
std::string anchored = "^$";
if (!pattern.empty()) {
anchored = (pattern.front() != '^' ? "^" : "")
+ pattern
+ (pattern.back() != '$' ? "$" : "");
}
trigger_patterns.push_back(anchored);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
@ -210,10 +216,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
}
}
if (!patterns_anywhere.empty()) {
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
}
std::vector<const char *> trigger_patterns_c;
trigger_patterns_c.reserve(trigger_patterns.size());
for (const auto & regex : trigger_patterns) {

View File

@ -369,6 +369,54 @@ static void print_rule(
fprintf(file, "\n");
}
//
// Regex utilities
//
static llama_grammar_trigger_pattern llama_grammar_trigger_pattern_compile(const std::string & pattern) {
llama_grammar_trigger_pattern_type type = LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_SEARCH;
if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') {
// If anchored on both ends, consider it a match
type = LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_MATCH;
}
return {type, pattern, std::regex(pattern)};
}
size_t llama_grammar_trigger_pattern::find(const std::string & input) const {
auto find_start_pos = [](const std::smatch & match) {
// get from the first matched capturing group to the end of the string
size_t start = std::string::npos;
for (auto i = 1u; i < match.size(); i++) {
if (match.length(i) > 0) {
start = match.position(i);
break;
}
}
if (start == std::string::npos) {
start = match.position(0);
}
return start;
};
if (type == LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_MATCH) {
// match against the entire input
std::smatch match;
if (std::regex_match(input, match, regex)) {
return find_start_pos(match);
}
}
if (type == LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_SEARCH) {
// search anywhere
std::smatch match;
if (std::regex_search(input, match, regex)) {
return find_start_pos(match);
}
}
return std::string::npos;
}
//
// implementation
//
@ -1192,9 +1240,7 @@ struct llama_grammar * llama_grammar_init_impl(
}
for (size_t i = 0; i < num_trigger_patterns; i++) {
GGML_ASSERT(trigger_patterns != nullptr);
auto & trigger = vec_trigger_patterns.emplace_back();
trigger.pattern = trigger_patterns[i];
trigger.regex = std::regex(trigger.pattern);
vec_trigger_patterns.emplace_back(llama_grammar_trigger_pattern_compile(trigger_patterns[i]));
}
// Important: vec_rules has to be moved here, not copied, because stacks contains
@ -1312,21 +1358,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
grammar.trigger_buffer += piece;
std::smatch match;
for (const auto & trigger_pattern : grammar.trigger_patterns) {
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
auto start = trigger_pattern.find(grammar.trigger_buffer);
if (start != std::string::npos) {
grammar.awaiting_trigger = false;
// get from the first matched capturing group to the end of the string
size_t start = std::string::npos;
for (auto i = 1u; i < match.size(); i++) {
if (match.length(i) > 0) {
start = match.position(i);
break;
}
}
if (start == std::string::npos) {
start = match.position(0);
}
// replay tokens that overlap with [start, end)
for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {

View File

@ -116,9 +116,17 @@ struct llama_grammar_parser {
void print(FILE * file);
};
enum llama_grammar_trigger_pattern_type {
LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_MATCH = 0,
LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_SEARCH = 1,
};
struct llama_grammar_trigger_pattern {
llama_grammar_trigger_pattern_type type;
std::string pattern;
std::regex regex;
size_t find(const std::string & input) const;
};
struct llama_grammar {