grammar : add support for std::regex_search() with trigger patterns
This commit is contained in:
parent
5ee4e43f26
commit
8da07610f8
|
|
@ -179,24 +179,30 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||||
#endif // LLAMA_USE_LLGUIDANCE
|
#endif // LLAMA_USE_LLGUIDANCE
|
||||||
} else {
|
} else {
|
||||||
std::vector<std::string> trigger_patterns;
|
std::vector<std::string> trigger_patterns;
|
||||||
std::vector<std::string> patterns_anywhere;
|
|
||||||
std::vector<llama_token> trigger_tokens;
|
std::vector<llama_token> trigger_tokens;
|
||||||
for (const auto & trigger : params.grammar_triggers) {
|
for (const auto & trigger : params.grammar_triggers) {
|
||||||
switch (trigger.type) {
|
switch (trigger.type) {
|
||||||
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
||||||
{
|
{
|
||||||
const auto & word = trigger.value;
|
const auto & word = trigger.value;
|
||||||
patterns_anywhere.push_back(regex_escape(word));
|
trigger_patterns.push_back(regex_escape(word));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||||
{
|
{
|
||||||
patterns_anywhere.push_back(trigger.value);
|
trigger_patterns.push_back(trigger.value);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
|
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
|
||||||
{
|
{
|
||||||
trigger_patterns.push_back(trigger.value);
|
const auto & pattern = trigger.value;
|
||||||
|
std::string anchored = "^$";
|
||||||
|
if (!pattern.empty()) {
|
||||||
|
anchored = (pattern.front() != '^' ? "^" : "")
|
||||||
|
+ pattern
|
||||||
|
+ (pattern.back() != '$' ? "$" : "");
|
||||||
|
}
|
||||||
|
trigger_patterns.push_back(anchored);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
|
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
|
||||||
|
|
@ -210,10 +216,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!patterns_anywhere.empty()) {
|
|
||||||
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<const char *> trigger_patterns_c;
|
std::vector<const char *> trigger_patterns_c;
|
||||||
trigger_patterns_c.reserve(trigger_patterns.size());
|
trigger_patterns_c.reserve(trigger_patterns.size());
|
||||||
for (const auto & regex : trigger_patterns) {
|
for (const auto & regex : trigger_patterns) {
|
||||||
|
|
|
||||||
|
|
@ -369,6 +369,54 @@ static void print_rule(
|
||||||
fprintf(file, "\n");
|
fprintf(file, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Regex utilities
|
||||||
|
//
|
||||||
|
// Compiles a trigger pattern and decides how it should be applied to the trigger buffer.
//
// A pattern that starts with '^' and ends with an unescaped '$' is treated as a full match
// (LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_MATCH); every other pattern is searched for anywhere in
// the input (LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_SEARCH).
//
// Returns the classified pattern together with its source text and the compiled std::regex.
// std::regex's constructor may throw std::regex_error if the pattern is malformed.
static llama_grammar_trigger_pattern llama_grammar_trigger_pattern_compile(const std::string & pattern) {
    llama_grammar_trigger_pattern_type type = LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_SEARCH;
    if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') {
        // A trailing '$' is only an end anchor when it is not escaped: count the run of
        // backslashes right before it. An odd count means the '$' is a literal dollar sign
        // (e.g. "^foo\$"), in which case the pattern is not end-anchored and must be searched.
        size_t n_backslashes = 0;
        for (size_t i = pattern.size() - 1; i > 0 && pattern[i - 1] == '\\'; i--) {
            n_backslashes++;
        }
        if (n_backslashes % 2 == 0) {
            // If anchored on both ends, consider it a match
            type = LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_MATCH;
        }
    }
    return {type, pattern, std::regex(pattern)};
}
|
||||||
|
|
||||||
|
size_t llama_grammar_trigger_pattern::find(const std::string & input) const {
|
||||||
|
auto find_start_pos = [](const std::smatch & match) {
|
||||||
|
// get from the first matched capturing group to the end of the string
|
||||||
|
size_t start = std::string::npos;
|
||||||
|
for (auto i = 1u; i < match.size(); i++) {
|
||||||
|
if (match.length(i) > 0) {
|
||||||
|
start = match.position(i);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (start == std::string::npos) {
|
||||||
|
start = match.position(0);
|
||||||
|
}
|
||||||
|
return start;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (type == LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_MATCH) {
|
||||||
|
// match against the entire input
|
||||||
|
std::smatch match;
|
||||||
|
if (std::regex_match(input, match, regex)) {
|
||||||
|
return find_start_pos(match);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type == LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_SEARCH) {
|
||||||
|
// search anywhere
|
||||||
|
std::smatch match;
|
||||||
|
if (std::regex_search(input, match, regex)) {
|
||||||
|
return find_start_pos(match);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::string::npos;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// implementation
|
// implementation
|
||||||
//
|
//
|
||||||
|
|
@ -1192,9 +1240,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < num_trigger_patterns; i++) {
|
for (size_t i = 0; i < num_trigger_patterns; i++) {
|
||||||
GGML_ASSERT(trigger_patterns != nullptr);
|
GGML_ASSERT(trigger_patterns != nullptr);
|
||||||
auto & trigger = vec_trigger_patterns.emplace_back();
|
vec_trigger_patterns.emplace_back(llama_grammar_trigger_pattern_compile(trigger_patterns[i]));
|
||||||
trigger.pattern = trigger_patterns[i];
|
|
||||||
trigger.regex = std::regex(trigger.pattern);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
||||||
|
|
@ -1312,21 +1358,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||||
grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
|
grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
|
||||||
grammar.trigger_buffer += piece;
|
grammar.trigger_buffer += piece;
|
||||||
|
|
||||||
std::smatch match;
|
|
||||||
for (const auto & trigger_pattern : grammar.trigger_patterns) {
|
for (const auto & trigger_pattern : grammar.trigger_patterns) {
|
||||||
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
|
auto start = trigger_pattern.find(grammar.trigger_buffer);
|
||||||
|
if (start != std::string::npos) {
|
||||||
grammar.awaiting_trigger = false;
|
grammar.awaiting_trigger = false;
|
||||||
// get from the first matched capturing group to the end of the string
|
|
||||||
size_t start = std::string::npos;
|
|
||||||
for (auto i = 1u; i < match.size(); i++) {
|
|
||||||
if (match.length(i) > 0) {
|
|
||||||
start = match.position(i);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (start == std::string::npos) {
|
|
||||||
start = match.position(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// replay tokens that overlap with [start, end)
|
// replay tokens that overlap with [start, end)
|
||||||
for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
|
for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
|
||||||
|
|
|
||||||
|
|
@ -116,9 +116,17 @@ struct llama_grammar_parser {
|
||||||
void print(FILE * file);
|
void print(FILE * file);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// How a trigger pattern is applied to the text it is tested against.
enum llama_grammar_trigger_pattern_type {
    // the regex has to match the entire input (std::regex_match)
    LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_MATCH  = 0,
    // the regex may match anywhere in the input (std::regex_search)
    LLAMA_GRAMMAR_TRIGGER_PATTERN_TYPE_SEARCH = 1,
};
|
||||||
|
|
||||||
struct llama_grammar_trigger_pattern {
|
struct llama_grammar_trigger_pattern {
|
||||||
|
llama_grammar_trigger_pattern_type type;
|
||||||
std::string pattern;
|
std::string pattern;
|
||||||
std::regex regex;
|
std::regex regex;
|
||||||
|
|
||||||
|
size_t find(const std::string & input) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_grammar {
|
struct llama_grammar {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue