diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 2d55070cec..c8eed1c053 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -9,7 +9,10 @@ #include #include +#include + #define MAX_REPETITION_THRESHOLD 2000 +#define MAX_GRAMMAR_RECURSION_DEPTH 2000 // // helpers // @@ -825,12 +828,32 @@ static bool llama_grammar_match_token( return false; } +// Hash function for llama_grammar_stack to enable std::unordered_set usage +struct llama_grammar_stack_hash { + std::size_t operator()(const llama_grammar_stack & stack) const { + std::size_t hash = 0; + for (const auto * elem : stack) { + hash ^= std::hash{}(elem) + 0x9e3779b9 + (hash << 6) + (hash >> 2); + } + return hash; + } +}; + // transforms a grammar pushdown stack into N possible stacks, all ending // at a character range (terminal element) static void llama_grammar_advance_stack( const llama_grammar_rules & rules, const llama_grammar_stack & stack, - llama_grammar_stacks & new_stacks) { + llama_grammar_stacks & new_stacks, + std::unordered_set & history) { + if (history.size() >= MAX_GRAMMAR_RECURSION_DEPTH) { + return; + } + + if (history.count(stack) > 0) { + return; + } + if (stack.empty()) { if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { new_stacks.emplace_back(stack); @@ -855,7 +878,9 @@ static void llama_grammar_advance_stack( // if alternate is nonempty, add to stack new_stack.push_back(subpos); } - llama_grammar_advance_stack(rules, new_stack, new_stacks); + history.insert(stack); + llama_grammar_advance_stack(rules, new_stack, new_stacks, history); + history.erase(stack); while (!llama_grammar_is_end_of_sequence(subpos)) { // scan to end of alternate def subpos++; @@ -887,6 +912,14 @@ static void llama_grammar_advance_stack( } } +static void llama_grammar_advance_stack( + const llama_grammar_rules & rules, + const llama_grammar_stack & stack, + llama_grammar_stacks & new_stacks) { + std::unordered_set history; + llama_grammar_advance_stack(rules, stack, new_stacks, history); +} + static llama_grammar_candidates llama_grammar_reject_candidates( const llama_grammar_rules & rules, const llama_grammar_stacks & stacks, diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 7aa7e58a5c..f8b3a028db 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -784,6 +784,23 @@ static void test_quantifiers() { "0xFF 0x12 0xAB 0x00 0x00 0x00", } ); + test_grammar( + "nested repetition", + // Grammar + R"""(root ::= ("a"* )*)""", + // Passing strings + { + "", + "a", + "aa", + "aaa", + }, + // Failing strings + { + "b", + "ab", + } + ); } static void test_failure_missing_root() {