grammar: add repetition threshold check
The change introduces a maximum repetition threshold to avoid
excessive rule expansion during grammar parsing. When parsing
repetition patterns like {m,n}, the parser now calculates the
potential number of rules that would be generated and throws an error
if the product of previous rules and new rules exceeds the threshold.
A test case was added to verify the threshold is properly enforced for
deeply nested repetition patterns that would otherwise cause hangs.
This commit is contained in:
parent
072efa22a2
commit
13c8d22f3e
|
|
@ -455,6 +455,7 @@ const char * llama_grammar_parser::parse_sequence(
|
||||||
bool is_nested) {
|
bool is_nested) {
|
||||||
size_t last_sym_start = rule.size();
|
size_t last_sym_start = rule.size();
|
||||||
const char * pos = src;
|
const char * pos = src;
|
||||||
|
uint64_t n_prev_rules = 1;
|
||||||
|
|
||||||
// use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
|
// use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
|
||||||
// (though it's technically the same as -1 now)
|
// (though it's technically the same as -1 now)
|
||||||
|
|
@ -482,6 +483,18 @@ const char * llama_grammar_parser::parse_sequence(
|
||||||
// S' ::= S |
|
// S' ::= S |
|
||||||
|
|
||||||
llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
|
llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
|
||||||
|
// Calculate the total number of rules that will be generated by this repetition
|
||||||
|
uint64_t total_rules = 1; // Start with 1 for the original rule
|
||||||
|
if (!no_max && max_times > 0) {
|
||||||
|
total_rules = max_times;
|
||||||
|
} else if (min_times > 0) {
|
||||||
|
total_rules = min_times;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n_prev_rules * total_rules >= MAX_REPETITION_THRESHOLD) {
|
||||||
|
throw std::runtime_error("number of rules that are going to be repeated multiplied by the new repetition exceeds sane defaults, please reduce the number of repetitions or rule complexity");
|
||||||
|
}
|
||||||
|
|
||||||
if (min_times == 0) {
|
if (min_times == 0) {
|
||||||
rule.resize(last_sym_start);
|
rule.resize(last_sym_start);
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -509,12 +522,15 @@ const char * llama_grammar_parser::parse_sequence(
|
||||||
if (n_opt > 0) {
|
if (n_opt > 0) {
|
||||||
rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
|
rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
|
||||||
}
|
}
|
||||||
|
n_prev_rules *= total_rules;
|
||||||
|
GGML_ASSERT(n_prev_rules >= 1);
|
||||||
};
|
};
|
||||||
|
|
||||||
while (*pos) {
|
while (*pos) {
|
||||||
if (*pos == '"') { // literal string
|
if (*pos == '"') { // literal string
|
||||||
pos++;
|
pos++;
|
||||||
last_sym_start = rule.size();
|
last_sym_start = rule.size();
|
||||||
|
n_prev_rules = 1;
|
||||||
while (*pos != '"') {
|
while (*pos != '"') {
|
||||||
if (!*pos) {
|
if (!*pos) {
|
||||||
throw std::runtime_error("unexpected end of input");
|
throw std::runtime_error("unexpected end of input");
|
||||||
|
|
@ -532,6 +548,7 @@ const char * llama_grammar_parser::parse_sequence(
|
||||||
start_type = LLAMA_GRETYPE_CHAR_NOT;
|
start_type = LLAMA_GRETYPE_CHAR_NOT;
|
||||||
}
|
}
|
||||||
last_sym_start = rule.size();
|
last_sym_start = rule.size();
|
||||||
|
n_prev_rules = 1;
|
||||||
while (*pos != ']') {
|
while (*pos != ']') {
|
||||||
if (!*pos) {
|
if (!*pos) {
|
||||||
throw std::runtime_error("unexpected end of input");
|
throw std::runtime_error("unexpected end of input");
|
||||||
|
|
@ -562,6 +579,7 @@ const char * llama_grammar_parser::parse_sequence(
|
||||||
auto token_pair = parse_token(vocab, pos);
|
auto token_pair = parse_token(vocab, pos);
|
||||||
const char * token_end = token_pair.second;
|
const char * token_end = token_pair.second;
|
||||||
last_sym_start = rule.size();
|
last_sym_start = rule.size();
|
||||||
|
n_prev_rules = 1;
|
||||||
rule.push_back({type, token_pair.first});
|
rule.push_back({type, token_pair.first});
|
||||||
pos = parse_space(token_end, is_nested);
|
pos = parse_space(token_end, is_nested);
|
||||||
} else if (is_word_char(*pos)) { // rule reference
|
} else if (is_word_char(*pos)) { // rule reference
|
||||||
|
|
@ -569,12 +587,15 @@ const char * llama_grammar_parser::parse_sequence(
|
||||||
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
|
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
|
||||||
pos = parse_space(name_end, is_nested);
|
pos = parse_space(name_end, is_nested);
|
||||||
last_sym_start = rule.size();
|
last_sym_start = rule.size();
|
||||||
|
n_prev_rules = 1;
|
||||||
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
|
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
|
||||||
} else if (*pos == '(') { // grouping
|
} else if (*pos == '(') { // grouping
|
||||||
// parse nested alternates into synthesized rule
|
// parse nested alternates into synthesized rule
|
||||||
pos = parse_space(pos + 1, true);
|
pos = parse_space(pos + 1, true);
|
||||||
|
uint32_t n_rules_before = symbol_ids.size();
|
||||||
uint32_t sub_rule_id = generate_symbol_id(rule_name);
|
uint32_t sub_rule_id = generate_symbol_id(rule_name);
|
||||||
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
|
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
|
||||||
|
n_prev_rules = std::max(1u, (uint32_t)symbol_ids.size() - n_rules_before);
|
||||||
last_sym_start = rule.size();
|
last_sym_start = rule.size();
|
||||||
// output reference to synthesized rule
|
// output reference to synthesized rule
|
||||||
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
|
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
|
||||||
|
|
@ -584,6 +605,7 @@ const char * llama_grammar_parser::parse_sequence(
|
||||||
pos = parse_space(pos + 1, is_nested);
|
pos = parse_space(pos + 1, is_nested);
|
||||||
} else if (*pos == '.') { // any char
|
} else if (*pos == '.') { // any char
|
||||||
last_sym_start = rule.size();
|
last_sym_start = rule.size();
|
||||||
|
n_prev_rules = 1;
|
||||||
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
|
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
|
||||||
pos = parse_space(pos + 1, is_nested);
|
pos = parse_space(pos + 1, is_nested);
|
||||||
} else if (*pos == '*') {
|
} else if (*pos == '*') {
|
||||||
|
|
|
||||||
|
|
@ -802,19 +802,6 @@ static void test_quantifiers() {
|
||||||
"yy"
|
"yy"
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
test_grammar(
|
|
||||||
"hang",
|
|
||||||
// Grammar
|
|
||||||
R"""(
|
|
||||||
root ::= (((((([^x]*){0,99}){0,99}){0,99}){0,99}){0,99}){0,99}
|
|
||||||
)""",
|
|
||||||
// Passing strings
|
|
||||||
{
|
|
||||||
},
|
|
||||||
// Failing strings
|
|
||||||
{
|
|
||||||
}
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void test_failure_missing_root() {
|
static void test_failure_missing_root() {
|
||||||
|
|
|
||||||
|
|
@ -145,6 +145,10 @@ int main()
|
||||||
root ::= "a"{,}"
|
root ::= "a"{,}"
|
||||||
)""");
|
)""");
|
||||||
|
|
||||||
|
verify_failure(R"""(
|
||||||
|
root ::= (((((([^x]*){0,99}){0,99}){0,99}){0,99}){0,99}){0,99}
|
||||||
|
)""");
|
||||||
|
|
||||||
verify_failure(R"""(
|
verify_failure(R"""(
|
||||||
root ::= "a"{,10}"
|
root ::= "a"{,10}"
|
||||||
)""");
|
)""");
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue