grammar : check pattern directly instead of adding a type
This commit is contained in:
parent
f907124097
commit
3bdcc4f773
|
|
@ -372,14 +372,6 @@ static void print_rule(
|
||||||
//
|
//
|
||||||
// Regex utilities
|
// Regex utilities
|
||||||
//
|
//
|
||||||
static llama_grammar_trigger_pattern llama_grammar_trigger_pattern_compile(const std::string & pattern) {
|
|
||||||
llama_grammar_trigger_pattern_type type = LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_SEARCH;
|
|
||||||
if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') {
|
|
||||||
// If anchored on both ends, consider it a match
|
|
||||||
type = LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_MATCH;
|
|
||||||
}
|
|
||||||
return {type, pattern, std::regex(pattern)};
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t llama_grammar_trigger_pattern::find(const std::string & input) const {
|
size_t llama_grammar_trigger_pattern::find(const std::string & input) const {
|
||||||
auto find_start_pos = [](const std::smatch & match) {
|
auto find_start_pos = [](const std::smatch & match) {
|
||||||
|
|
@ -397,7 +389,7 @@ size_t llama_grammar_trigger_pattern::find(const std::string & input) const {
|
||||||
return start;
|
return start;
|
||||||
};
|
};
|
||||||
|
|
||||||
if (type == LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_MATCH) {
|
if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') {
|
||||||
// match against the entire input
|
// match against the entire input
|
||||||
std::smatch match;
|
std::smatch match;
|
||||||
if (std::regex_match(input, match, regex)) {
|
if (std::regex_match(input, match, regex)) {
|
||||||
|
|
@ -405,12 +397,10 @@ size_t llama_grammar_trigger_pattern::find(const std::string & input) const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (type == LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_SEARCH) {
|
// search anywhere
|
||||||
// search anywhere
|
std::smatch match;
|
||||||
std::smatch match;
|
if (std::regex_search(input, match, regex)) {
|
||||||
if (std::regex_search(input, match, regex)) {
|
return find_start_pos(match);
|
||||||
return find_start_pos(match);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::string::npos;
|
return std::string::npos;
|
||||||
|
|
@ -1240,7 +1230,9 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < num_trigger_patterns; i++) {
|
for (size_t i = 0; i < num_trigger_patterns; i++) {
|
||||||
GGML_ASSERT(trigger_patterns != nullptr);
|
GGML_ASSERT(trigger_patterns != nullptr);
|
||||||
vec_trigger_patterns.emplace_back(llama_grammar_trigger_pattern_compile(trigger_patterns[i]));
|
auto & trigger = vec_trigger_patterns.emplace_back();
|
||||||
|
trigger.pattern = trigger_patterns[i];
|
||||||
|
trigger.regex = std::regex(trigger.pattern);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
||||||
|
|
|
||||||
|
|
@ -116,13 +116,7 @@ struct llama_grammar_parser {
|
||||||
void print(FILE * file);
|
void print(FILE * file);
|
||||||
};
|
};
|
||||||
|
|
||||||
enum llama_grammar_trigger_pattern_type {
|
|
||||||
LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_MATCH = 0,
|
|
||||||
LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_SEARCH = 1,
|
|
||||||
};
|
|
||||||
|
|
||||||
struct llama_grammar_trigger_pattern {
|
struct llama_grammar_trigger_pattern {
|
||||||
llama_grammar_trigger_pattern_type type;
|
|
||||||
std::string pattern;
|
std::string pattern;
|
||||||
std::regex regex;
|
std::regex regex;
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue