diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 64ea2fd00a..1af066e9f6 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -830,7 +830,13 @@ static bool llama_grammar_match_token( static void llama_grammar_advance_stack( const llama_grammar_rules & rules, const llama_grammar_stack & stack, - llama_grammar_stacks & new_stacks) { + llama_grammar_stacks & new_stacks, + llama_grammar_stacks & seen_stacks) { + if (std::find(seen_stacks.begin(), seen_stacks.end(), stack) != seen_stacks.end()) { + return; + } + seen_stacks.push_back(stack); + if (stack.empty()) { if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { new_stacks.emplace_back(stack); @@ -855,7 +861,7 @@ static void llama_grammar_advance_stack( // if alternate is nonempty, add to stack new_stack.push_back(subpos); } - llama_grammar_advance_stack(rules, new_stack, new_stacks); + llama_grammar_advance_stack(rules, new_stack, new_stacks, seen_stacks); while (!llama_grammar_is_end_of_sequence(subpos)) { // scan to end of alternate def subpos++; @@ -989,7 +995,8 @@ static void llama_grammar_accept_chr( if (!llama_grammar_is_end_of_sequence(match.second)) { new_stack.push_back(match.second); } - llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks); + llama_grammar_stacks seen_stacks; + llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks, seen_stacks); } } @@ -1065,7 +1072,8 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( stack_after.push_back(stack_pos_after); } llama_grammar_stacks next_stacks; - llama_grammar_advance_stack(rules, stack_after, next_stacks); + llama_grammar_stacks seen_stacks; + llama_grammar_advance_stack(rules, stack_after, next_stacks, seen_stacks); auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); for (const auto & tok : next_rejects) { @@ -1116,7 +1124,8 @@ struct llama_grammar * llama_grammar_init_impl( // if alternate is nonempty, add to stack stack.push_back(pos); } - llama_grammar_advance_stack(vec_rules, stack, stacks); + llama_grammar_stacks seen_stacks; + llama_grammar_advance_stack(vec_rules, stack, stacks, seen_stacks); while (!llama_grammar_is_end_of_sequence(pos)) { // scan to end of alternate def pos++; @@ -1209,7 +1218,8 @@ struct llama_grammar * llama_grammar_init_impl( // if alternate is nonempty, add to stack stack.push_back(pos); } - llama_grammar_advance_stack(vec_rules, stack, stacks); + llama_grammar_stacks seen_stacks; + llama_grammar_advance_stack(vec_rules, stack, stacks, seen_stacks); while (!llama_grammar_is_end_of_sequence(pos)) { // scan to end of alternate def pos++; @@ -1428,7 +1438,8 @@ void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token toke if (!llama_grammar_is_end_of_sequence(pos + 1)) { new_stack.push_back(pos + 1); } - llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new); + llama_grammar_stacks seen_stacks; + llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new, seen_stacks); } } else { llama_grammar_stacks current_stacks = {stack};