diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 64ea2fd00a..321d26f3df 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -10,6 +10,7 @@ #include #define MAX_REPETITION_THRESHOLD 2000 +static constexpr uint32_t MAX_GRAMMAR_RECURSION_DEPTH = 2000; // // helpers // @@ -830,7 +831,16 @@ static bool llama_grammar_match_token( static void llama_grammar_advance_stack( const llama_grammar_rules & rules, const llama_grammar_stack & stack, - llama_grammar_stacks & new_stacks) { + llama_grammar_stacks & new_stacks, + std::vector & history) { + if (history.size() >= MAX_GRAMMAR_RECURSION_DEPTH) { + return; + } + + if (std::find(history.begin(), history.end(), stack) != history.end()) { + return; + } + if (stack.empty()) { if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { new_stacks.emplace_back(stack); @@ -855,7 +865,9 @@ static void llama_grammar_advance_stack( // if alternate is nonempty, add to stack new_stack.push_back(subpos); } - llama_grammar_advance_stack(rules, new_stack, new_stacks); + history.push_back(stack); + llama_grammar_advance_stack(rules, new_stack, new_stacks, history); + history.pop_back(); while (!llama_grammar_is_end_of_sequence(subpos)) { // scan to end of alternate def subpos++; @@ -887,6 +899,14 @@ static void llama_grammar_advance_stack( } } +static void llama_grammar_advance_stack( + const llama_grammar_rules & rules, + const llama_grammar_stack & stack, + llama_grammar_stacks & new_stacks) { + std::vector history; + llama_grammar_advance_stack(rules, stack, new_stacks, history); +} + static llama_grammar_candidates llama_grammar_reject_candidates( const llama_grammar_rules & rules, const llama_grammar_stacks & stacks,