diff --git a/common/sampling.cpp b/common/sampling.cpp
index c66f935c65..68e36e8744 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -179,24 +179,30 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 #endif // LLAMA_USE_LLGUIDANCE
     } else {
         std::vector<std::string> trigger_patterns;
-        std::vector<std::string> patterns_anywhere;
         std::vector<llama_token> trigger_tokens;
         for (const auto & trigger : params.grammar_triggers) {
             switch (trigger.type) {
                 case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
                 {
                     const auto & word = trigger.value;
-                    patterns_anywhere.push_back(regex_escape(word));
+                    trigger_patterns.push_back(regex_escape(word));
                     break;
                 }
                 case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
                 {
-                    patterns_anywhere.push_back(trigger.value);
+                    trigger_patterns.push_back(trigger.value);
                     break;
                 }
                 case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
                 {
-                    trigger_patterns.push_back(trigger.value);
+                    const auto & pattern = trigger.value;
+                    std::string anchored = "^$";
+                    if (!pattern.empty()) {
+                        anchored = (pattern.front() != '^' ? "^" : "")
+                            + pattern
+                            + (pattern.back() != '$' ? "$" : "");
+                    }
+                    trigger_patterns.push_back(anchored);
                     break;
                 }
                 case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
@@ -210,10 +216,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             }
         }
 
-        if (!patterns_anywhere.empty()) {
-            trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
-        }
-
         std::vector<const char *> trigger_patterns_c;
         trigger_patterns_c.reserve(trigger_patterns.size());
         for (const auto & regex : trigger_patterns) {
diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
index 75d5d750c3..d62733b5d6 100644
--- a/src/llama-grammar.cpp
+++ b/src/llama-grammar.cpp
@@ -369,6 +369,54 @@ static void print_rule(
     fprintf(file, "\n");
 }
 
+//
+// Regex utilities
+//
+static llama_grammar_trigger_pattern llama_grammar_trigger_pattern_compile(const std::string & pattern) {
+    llama_grammar_trigger_pattern_type type = LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_SEARCH;
+    if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') {
+        // If anchored on both ends, consider it a match
+        type = LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_MATCH;
+    }
+    return {type, pattern, std::regex(pattern)};
+}
+
+size_t llama_grammar_trigger_pattern::find(const std::string & input) const {
+    auto find_start_pos = [](const std::smatch & match) {
+        // get from the first matched capturing group to the end of the string
+        size_t start = std::string::npos;
+        for (auto i = 1u; i < match.size(); i++) {
+            if (match.length(i) > 0) {
+                start = match.position(i);
+                break;
+            }
+        }
+        if (start == std::string::npos) {
+            start = match.position(0);
+        }
+        return start;
+    };
+
+    if (type == LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_MATCH) {
+        // match against the entire input
+        std::smatch match;
+        if (std::regex_match(input, match, regex)) {
+            return find_start_pos(match);
+        }
+    }
+
+    if (type == LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_SEARCH) {
+        // search anywhere
+        std::smatch match;
+        if (std::regex_search(input, match, regex)) {
+            return find_start_pos(match);
+        }
+    }
+
+    return std::string::npos;
+}
+
+
 //
 // implementation
 //
@@ -1192,9 +1240,7 @@ struct llama_grammar * llama_grammar_init_impl(
     }
     for (size_t i = 0; i < num_trigger_patterns; i++) {
         GGML_ASSERT(trigger_patterns != nullptr);
-        auto & trigger = vec_trigger_patterns.emplace_back();
-        trigger.pattern = trigger_patterns[i];
-        trigger.regex = std::regex(trigger.pattern);
+        vec_trigger_patterns.emplace_back(llama_grammar_trigger_pattern_compile(trigger_patterns[i]));
     }
 
     // Important: vec_rules has to be moved here, not copied, because stacks contains
@@ -1312,21 +1358,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
         grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
         grammar.trigger_buffer += piece;
 
-        std::smatch match;
         for (const auto & trigger_pattern : grammar.trigger_patterns) {
-            if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
+            auto start = trigger_pattern.find(grammar.trigger_buffer);
+            if (start != std::string::npos) {
                 grammar.awaiting_trigger = false;
-                // get from the first matched capturing group to the end of the string
-                size_t start = std::string::npos;
-                for (auto i = 1u; i < match.size(); i++) {
-                    if (match.length(i) > 0) {
-                        start = match.position(i);
-                        break;
-                    }
-                }
-                if (start == std::string::npos) {
-                    start = match.position(0);
-                }
 
                 // replay tokens that overlap with [start, end)
                 for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
diff --git a/src/llama-grammar.h b/src/llama-grammar.h
index a4c978ac11..2cd03bff1f 100644
--- a/src/llama-grammar.h
+++ b/src/llama-grammar.h
@@ -116,9 +116,17 @@ struct llama_grammar_parser {
     void print(FILE * file);
 };
 
+enum llama_grammar_trigger_pattern_type {
+    LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_MATCH  = 0,
+    LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_SEARCH = 1,
+};
+
 struct llama_grammar_trigger_pattern {
+    llama_grammar_trigger_pattern_type type;
     std::string pattern;
     std::regex regex;
+
+    size_t find(const std::string & input) const;
 };
 
 struct llama_grammar {
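// ---------------------------------------------------------------------------
// Not part of the patch: a minimal standalone sketch, assuming nothing beyond
// <regex>, of the behaviour introduced above. Patterns anchored on both ends
// (^...$, the PATTERN_FULL case after anchoring in common_sampler_init) are
// matched against the entire buffer, while word/pattern triggers are searched
// anywhere; the reported trigger start is the first non-empty capturing group,
// falling back to the whole match, mirroring what
// llama_grammar_trigger_pattern::find() returns for token replay. The
// find_trigger helper is hypothetical, for illustration only.
// ---------------------------------------------------------------------------
#include <cstdio>
#include <regex>
#include <string>

static size_t find_trigger(const std::string & input, const std::string & pattern) {
    const bool anchored = !pattern.empty() && pattern.front() == '^' && pattern.back() == '$';
    const std::regex re(pattern);

    std::smatch m;
    const bool found = anchored ? std::regex_match(input, m, re)   // whole-buffer match
                                : std::regex_search(input, m, re); // search anywhere
    if (!found) {
        return std::string::npos;
    }
    // start of the first non-empty capturing group, else start of the full match
    for (auto i = 1u; i < m.size(); i++) {
        if (m.length(i) > 0) {
            return (size_t) m.position(i);
        }
    }
    return (size_t) m.position(0);
}

int main() {
    // word/pattern trigger: found anywhere, trigger starts at the match itself
    printf("%zu\n", find_trigger("some text <tool_call>", "<tool_call>"));                            // 10
    // full-pattern trigger: anchored, trigger starts at the first capturing group
    printf("%zu\n", find_trigger("I'll think first. <think>plan", "^[\\s\\S]*?(<think>)[\\s\\S]*$")); // 18
    return 0;
}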