From b689ff4779f9e5ca839bc9d1d61172d299b48762 Mon Sep 17 00:00:00 2001 From: Andrea Arcangeli Date: Fri, 2 Jan 2026 17:42:31 +0100 Subject: [PATCH] grammar: convert recursive llama_grammar_advance_stack to iterative MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This change converts the function to an iterative approach using explicit stacks, which prevents deep recursion and eliminates the risk of stack overflow. rg-edit regexp: llama_grammar_advance_stack rg-edit extra-args: -A30 rg-edit directive: """Rewrite: fix the following segfault: [..] ⚫ Testing segfault. Grammar: root ::= ( [x]* )* root ::= ( [x]* )* Segmentation fault build/bin/test-grammar-integration convert from recursive to interactive""" gptel-context Value: (("~/devel/ai/llama.cpp/src/llama-grammar.cpp") ("~/devel/ai/llama.cpp/tests/test-grammar-integration.cpp") ("~/devel/ai/llama.cpp/grammars/./list.gbnf") ("~/devel/ai/llama.cpp/grammars/./json_arr.gbnf") ("~/devel/ai/llama.cpp/grammars/./json.gbnf") ("~/devel/ai/llama.cpp/grammars/./japanese.gbnf") ("~/devel/ai/llama.cpp/grammars/./english.gbnf") ("~/devel/ai/llama.cpp/grammars/./chess.gbnf") ("~/devel/ai/llama.cpp/grammars/./c.gbnf") ("~/devel/ai/llama.cpp/grammars/./arithmetic.gbnf") ("~/devel/ai/llama.cpp/grammars/./README.md")) --- src/llama-grammar.cpp | 136 ++++++++++++++++++----------------- tests/test-llama-grammar.cpp | 74 ++++++++++--------- 2 files changed, 106 insertions(+), 104 deletions(-) diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 1af066e9f6..ec8bb46527 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -830,66 +830,75 @@ static bool llama_grammar_match_token( static void llama_grammar_advance_stack( const llama_grammar_rules & rules, const llama_grammar_stack & stack, - llama_grammar_stacks & new_stacks, - llama_grammar_stacks & seen_stacks) { - if (std::find(seen_stacks.begin(), seen_stacks.end(), stack) != seen_stacks.end()) { - return; - } - seen_stacks.push_back(stack); + llama_grammar_stacks & new_stacks) { + std::vector todo; + todo.push_back(stack); - if (stack.empty()) { - if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { - new_stacks.emplace_back(stack); + std::vector seen; + + while (!todo.empty()) { + llama_grammar_stack curr_stack = std::move(todo.back()); + todo.pop_back(); + + if (std::find(seen.begin(), seen.end(), curr_stack) != seen.end()) { + continue; } - return; - } + seen.push_back(curr_stack); - const llama_grammar_element * pos = stack.back(); - - switch (pos->type) { - case LLAMA_GRETYPE_RULE_REF: { - const size_t rule_id = static_cast(pos->value); - const llama_grammar_element * subpos = rules[rule_id].data(); - do { - // init new stack without the top (pos) - llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); - if (!llama_grammar_is_end_of_sequence(pos + 1)) { - // if this rule ref is followed by another element, add that to stack - new_stack.push_back(pos + 1); - } - if (!llama_grammar_is_end_of_sequence(subpos)) { - // if alternate is nonempty, add to stack - new_stack.push_back(subpos); - } - llama_grammar_advance_stack(rules, new_stack, new_stacks, seen_stacks); - while (!llama_grammar_is_end_of_sequence(subpos)) { - // scan to end of alternate def - subpos++; - } - if (subpos->type == LLAMA_GRETYPE_ALT) { - // there's another alternate def of this rule to process - subpos++; - } else { - break; - } - } while (true); - break; - } - case LLAMA_GRETYPE_CHAR: - case LLAMA_GRETYPE_CHAR_NOT: - case LLAMA_GRETYPE_CHAR_ANY: - case LLAMA_GRETYPE_TOKEN: - case LLAMA_GRETYPE_TOKEN_NOT: - if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { - // only add the stack if it's not a duplicate of one we already have - new_stacks.emplace_back(stack); + if (curr_stack.empty()) { + if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) { + new_stacks.emplace_back(std::move(curr_stack)); } - break; - default: - // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range - // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on - // those - GGML_ABORT("fatal error"); + continue; + } + + const llama_grammar_element * pos = curr_stack.back(); + + switch (pos->type) { + case LLAMA_GRETYPE_RULE_REF: { + const size_t rule_id = static_cast(pos->value); + const llama_grammar_element * subpos = rules[rule_id].data(); + do { + // init new stack without the top (pos) + llama_grammar_stack next_stack(curr_stack.begin(), curr_stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos + 1)) { + // if this rule ref is followed by another element, add that to stack + next_stack.push_back(pos + 1); + } + if (!llama_grammar_is_end_of_sequence(subpos)) { + // if alternate is nonempty, add to stack + next_stack.push_back(subpos); + } + todo.push_back(std::move(next_stack)); + while (!llama_grammar_is_end_of_sequence(subpos)) { + // scan to end of alternate def + subpos++; + } + if (subpos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + subpos++; + } else { + break; + } + } while (true); + break; + } + case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: + case LLAMA_GRETYPE_CHAR_ANY: + case LLAMA_GRETYPE_TOKEN: + case LLAMA_GRETYPE_TOKEN_NOT: + if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) { + // only add the stack if it's not a duplicate of one we already have + new_stacks.emplace_back(std::move(curr_stack)); + } + break; + default: + // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range + // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on + // those + GGML_ABORT("fatal error"); + } } } @@ -995,8 +1004,7 @@ static void llama_grammar_accept_chr( if (!llama_grammar_is_end_of_sequence(match.second)) { new_stack.push_back(match.second); } - llama_grammar_stacks seen_stacks; - llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks, seen_stacks); + llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks); } } @@ -1072,8 +1080,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( stack_after.push_back(stack_pos_after); } llama_grammar_stacks next_stacks; - llama_grammar_stacks seen_stacks; - llama_grammar_advance_stack(rules, stack_after, next_stacks, seen_stacks); + llama_grammar_advance_stack(rules, stack_after, next_stacks); auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); for (const auto & tok : next_rejects) { @@ -1124,8 +1131,7 @@ struct llama_grammar * llama_grammar_init_impl( // if alternate is nonempty, add to stack stack.push_back(pos); } - llama_grammar_stacks seen_stacks; - llama_grammar_advance_stack(vec_rules, stack, stacks, seen_stacks); + llama_grammar_advance_stack(vec_rules, stack, stacks); while (!llama_grammar_is_end_of_sequence(pos)) { // scan to end of alternate def pos++; @@ -1218,8 +1224,7 @@ struct llama_grammar * llama_grammar_init_impl( // if alternate is nonempty, add to stack stack.push_back(pos); } - llama_grammar_stacks seen_stacks; - llama_grammar_advance_stack(vec_rules, stack, stacks, seen_stacks); + llama_grammar_advance_stack(vec_rules, stack, stacks); while (!llama_grammar_is_end_of_sequence(pos)) { // scan to end of alternate def pos++; @@ -1438,8 +1443,7 @@ void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token toke if (!llama_grammar_is_end_of_sequence(pos + 1)) { new_stack.push_back(pos + 1); } - llama_grammar_stacks seen_stacks; - llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new, seen_stacks); + llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new); } } else { llama_grammar_stacks current_stacks = {stack}; diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp index fd45d5ada8..25f432a2f5 100644 --- a/tests/test-llama-grammar.cpp +++ b/tests/test-llama-grammar.cpp @@ -123,25 +123,27 @@ int main() std::vector> expected_stacks = { { - {LLAMA_GRETYPE_RULE_REF, 5}, + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_CHAR, 40}, + }, + { + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_CHAR, 48}, + }, + { + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_CHAR, 48}, + }, + { {LLAMA_GRETYPE_CHAR, 61}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_CHAR, 97}, }, - { - {LLAMA_GRETYPE_RULE_REF, 5}, - {LLAMA_GRETYPE_CHAR, 61}, - {LLAMA_GRETYPE_RULE_REF, 7}, - {LLAMA_GRETYPE_RULE_REF, 3}, - {LLAMA_GRETYPE_CHAR, 48}, - }, - { - {LLAMA_GRETYPE_RULE_REF, 5}, - {LLAMA_GRETYPE_CHAR, 61}, - {LLAMA_GRETYPE_RULE_REF, 7}, - {LLAMA_GRETYPE_RULE_REF, 3}, - {LLAMA_GRETYPE_CHAR, 48}, - }, { {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_CHAR, 61}, @@ -149,26 +151,24 @@ int main() {LLAMA_GRETYPE_CHAR, 40}, }, { + {LLAMA_GRETYPE_RULE_REF, 5}, + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_CHAR, 48}, + }, + { + {LLAMA_GRETYPE_RULE_REF, 5}, + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_CHAR, 48}, + }, + { + {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_CHAR, 61}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_CHAR, 97}, - }, - { - {LLAMA_GRETYPE_CHAR, 61}, - {LLAMA_GRETYPE_RULE_REF, 7}, - {LLAMA_GRETYPE_RULE_REF, 3}, - {LLAMA_GRETYPE_CHAR, 48}, - }, - { - {LLAMA_GRETYPE_CHAR, 61}, - {LLAMA_GRETYPE_RULE_REF, 7}, - {LLAMA_GRETYPE_RULE_REF, 3}, - {LLAMA_GRETYPE_CHAR, 48}, - }, - { - {LLAMA_GRETYPE_CHAR, 61}, - {LLAMA_GRETYPE_RULE_REF, 7}, - {LLAMA_GRETYPE_CHAR, 40}, }}; auto index = 0; @@ -195,9 +195,9 @@ int main() } std::vector next_candidates; - next_candidates.resize(24); + next_candidates.resize(23); - for (size_t i = 0; i < 24; ++i) + for (size_t i = 0; i < 23; ++i) { uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point cp[0] = 37 + i; @@ -210,7 +210,6 @@ int main() {0, 37}, {1, 38}, {2, 39}, - {3, 40}, {4, 41}, {5, 42}, {6, 43}, @@ -268,6 +267,7 @@ int main() {0, 37}, {1, 38}, {2, 39}, + {3, 40}, {4, 41}, {5, 42}, {6, 43}, @@ -287,13 +287,11 @@ int main() {20, 57}, {21, 58}, {22, 59}, - {23, 60}, }, { {0, 37}, {1, 38}, {2, 39}, - {3, 40}, {4, 41}, {5, 42}, {6, 43}, @@ -351,6 +349,7 @@ int main() {0, 37}, {1, 38}, {2, 39}, + {3, 40}, {4, 41}, {5, 42}, {6, 43}, @@ -370,7 +369,6 @@ int main() {20, 57}, {21, 58}, {22, 59}, - {23, 60}, }, };