diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index b177f33e6b..3f1c99ba6c 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -455,6 +455,7 @@ const char * llama_grammar_parser::parse_sequence( bool is_nested) { size_t last_sym_start = rule.size(); const char * pos = src; + uint64_t n_prev_rules = 1; // use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used // (though it's technically the same as -1 now) @@ -482,6 +483,18 @@ const char * llama_grammar_parser::parse_sequence( // S' ::= S | llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end()); + // Calculate the total number of rules that will be generated by this repetition + uint64_t total_rules = 1; // Start with 1 for the original rule + if (!no_max && max_times > 0) { + total_rules = max_times; + } else if (min_times > 0) { + total_rules = min_times; + } + + if (n_prev_rules * total_rules >= MAX_REPETITION_THRESHOLD) { + throw std::runtime_error("number of rules that are going to be repeated multiplied by the new repetition exceeds sane defaults, please reduce the number of repetitions or rule complexity"); + } + if (min_times == 0) { rule.resize(last_sym_start); } else { @@ -509,12 +522,15 @@ const char * llama_grammar_parser::parse_sequence( if (n_opt > 0) { rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); } + n_prev_rules *= total_rules; + GGML_ASSERT(n_prev_rules >= 1); }; while (*pos) { if (*pos == '"') { // literal string pos++; last_sym_start = rule.size(); + n_prev_rules = 1; while (*pos != '"') { if (!*pos) { throw std::runtime_error("unexpected end of input"); @@ -532,6 +548,7 @@ const char * llama_grammar_parser::parse_sequence( start_type = LLAMA_GRETYPE_CHAR_NOT; } last_sym_start = rule.size(); + n_prev_rules = 1; while (*pos != ']') { if (!*pos) { throw std::runtime_error("unexpected end of input"); @@ -562,6 +579,7 @@ const char * llama_grammar_parser::parse_sequence( auto token_pair = parse_token(vocab, pos); const char * token_end = token_pair.second; last_sym_start = rule.size(); + n_prev_rules = 1; rule.push_back({type, token_pair.first}); pos = parse_space(token_end, is_nested); } else if (is_word_char(*pos)) { // rule reference @@ -569,12 +587,15 @@ const char * llama_grammar_parser::parse_sequence( uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos); pos = parse_space(name_end, is_nested); last_sym_start = rule.size(); + n_prev_rules = 1; rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); } else if (*pos == '(') { // grouping // parse nested alternates into synthesized rule pos = parse_space(pos + 1, true); + uint32_t n_rules_before = symbol_ids.size(); uint32_t sub_rule_id = generate_symbol_id(rule_name); pos = parse_alternates(pos, rule_name, sub_rule_id, true); + n_prev_rules = std::max(1u, (uint32_t)symbol_ids.size() - n_rules_before); last_sym_start = rule.size(); // output reference to synthesized rule rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); @@ -584,6 +605,7 @@ const char * llama_grammar_parser::parse_sequence( pos = parse_space(pos + 1, is_nested); } else if (*pos == '.') { // any char last_sym_start = rule.size(); + n_prev_rules = 1; rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); pos = parse_space(pos + 1, is_nested); } else if (*pos == '*') { diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 22d2799dd1..2728d32ab4 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -802,19 +802,6 @@ static void test_quantifiers() { "yy" } ); - test_grammar( - "hang", - // Grammar - R"""( - root ::= (((((([^x]*){0,99}){0,99}){0,99}){0,99}){0,99}){0,99} - )""", - // Passing strings - { - }, - // Failing strings - { - } - ); } static void test_failure_missing_root() { diff --git a/tests/test-grammar-parser.cpp b/tests/test-grammar-parser.cpp index 03ae78ff73..6abc43461b 100644 --- a/tests/test-grammar-parser.cpp +++ b/tests/test-grammar-parser.cpp @@ -145,6 +145,10 @@ int main() root ::= "a"{,}" )"""); + verify_failure(R"""( + root ::= (((((([^x]*){0,99}){0,99}){0,99}){0,99}){0,99}){0,99} + )"""); + verify_failure(R"""( root ::= "a"{,10}" )""");