grammar: convert recursive llama_grammar_advance_stack to iterative
This change converts the function to an iterative approach using
explicit stacks, which prevents deep recursion and eliminates the risk
of stack overflow.
rg-edit regexp: llama_grammar_advance_stack
rg-edit extra-args: -A30
rg-edit directive: """Rewrite: fix the following segfault:
[..]
⚫ Testing segfault. Grammar:
root ::= ( [x]* )*
root ::= ( [x]* )*
Segmentation fault build/bin/test-grammar-integration
convert from recursive to interactive"""
gptel-context Value:
(("~/devel/ai/llama.cpp/src/llama-grammar.cpp")
("~/devel/ai/llama.cpp/tests/test-grammar-integration.cpp")
("~/devel/ai/llama.cpp/grammars/./list.gbnf")
("~/devel/ai/llama.cpp/grammars/./json_arr.gbnf")
("~/devel/ai/llama.cpp/grammars/./json.gbnf")
("~/devel/ai/llama.cpp/grammars/./japanese.gbnf")
("~/devel/ai/llama.cpp/grammars/./english.gbnf")
("~/devel/ai/llama.cpp/grammars/./chess.gbnf")
("~/devel/ai/llama.cpp/grammars/./c.gbnf")
("~/devel/ai/llama.cpp/grammars/./arithmetic.gbnf")
("~/devel/ai/llama.cpp/grammars/./README.md"))
This commit is contained in:
parent
e289f380bf
commit
b689ff4779
|
|
@ -830,66 +830,75 @@ static bool llama_grammar_match_token(
|
|||
static void llama_grammar_advance_stack(
|
||||
const llama_grammar_rules & rules,
|
||||
const llama_grammar_stack & stack,
|
||||
llama_grammar_stacks & new_stacks,
|
||||
llama_grammar_stacks & seen_stacks) {
|
||||
if (std::find(seen_stacks.begin(), seen_stacks.end(), stack) != seen_stacks.end()) {
|
||||
return;
|
||||
}
|
||||
seen_stacks.push_back(stack);
|
||||
llama_grammar_stacks & new_stacks) {
|
||||
std::vector<llama_grammar_stack> todo;
|
||||
todo.push_back(stack);
|
||||
|
||||
if (stack.empty()) {
|
||||
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
|
||||
new_stacks.emplace_back(stack);
|
||||
std::vector<llama_grammar_stack> seen;
|
||||
|
||||
while (!todo.empty()) {
|
||||
llama_grammar_stack curr_stack = std::move(todo.back());
|
||||
todo.pop_back();
|
||||
|
||||
if (std::find(seen.begin(), seen.end(), curr_stack) != seen.end()) {
|
||||
continue;
|
||||
}
|
||||
return;
|
||||
}
|
||||
seen.push_back(curr_stack);
|
||||
|
||||
const llama_grammar_element * pos = stack.back();
|
||||
|
||||
switch (pos->type) {
|
||||
case LLAMA_GRETYPE_RULE_REF: {
|
||||
const size_t rule_id = static_cast<size_t>(pos->value);
|
||||
const llama_grammar_element * subpos = rules[rule_id].data();
|
||||
do {
|
||||
// init new stack without the top (pos)
|
||||
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
|
||||
if (!llama_grammar_is_end_of_sequence(pos + 1)) {
|
||||
// if this rule ref is followed by another element, add that to stack
|
||||
new_stack.push_back(pos + 1);
|
||||
}
|
||||
if (!llama_grammar_is_end_of_sequence(subpos)) {
|
||||
// if alternate is nonempty, add to stack
|
||||
new_stack.push_back(subpos);
|
||||
}
|
||||
llama_grammar_advance_stack(rules, new_stack, new_stacks, seen_stacks);
|
||||
while (!llama_grammar_is_end_of_sequence(subpos)) {
|
||||
// scan to end of alternate def
|
||||
subpos++;
|
||||
}
|
||||
if (subpos->type == LLAMA_GRETYPE_ALT) {
|
||||
// there's another alternate def of this rule to process
|
||||
subpos++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
} while (true);
|
||||
break;
|
||||
}
|
||||
case LLAMA_GRETYPE_CHAR:
|
||||
case LLAMA_GRETYPE_CHAR_NOT:
|
||||
case LLAMA_GRETYPE_CHAR_ANY:
|
||||
case LLAMA_GRETYPE_TOKEN:
|
||||
case LLAMA_GRETYPE_TOKEN_NOT:
|
||||
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
|
||||
// only add the stack if it's not a duplicate of one we already have
|
||||
new_stacks.emplace_back(stack);
|
||||
if (curr_stack.empty()) {
|
||||
if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) {
|
||||
new_stacks.emplace_back(std::move(curr_stack));
|
||||
}
|
||||
break;
|
||||
default:
|
||||
// end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
|
||||
// (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
|
||||
// those
|
||||
GGML_ABORT("fatal error");
|
||||
continue;
|
||||
}
|
||||
|
||||
const llama_grammar_element * pos = curr_stack.back();
|
||||
|
||||
switch (pos->type) {
|
||||
case LLAMA_GRETYPE_RULE_REF: {
|
||||
const size_t rule_id = static_cast<size_t>(pos->value);
|
||||
const llama_grammar_element * subpos = rules[rule_id].data();
|
||||
do {
|
||||
// init new stack without the top (pos)
|
||||
llama_grammar_stack next_stack(curr_stack.begin(), curr_stack.end() - 1);
|
||||
if (!llama_grammar_is_end_of_sequence(pos + 1)) {
|
||||
// if this rule ref is followed by another element, add that to stack
|
||||
next_stack.push_back(pos + 1);
|
||||
}
|
||||
if (!llama_grammar_is_end_of_sequence(subpos)) {
|
||||
// if alternate is nonempty, add to stack
|
||||
next_stack.push_back(subpos);
|
||||
}
|
||||
todo.push_back(std::move(next_stack));
|
||||
while (!llama_grammar_is_end_of_sequence(subpos)) {
|
||||
// scan to end of alternate def
|
||||
subpos++;
|
||||
}
|
||||
if (subpos->type == LLAMA_GRETYPE_ALT) {
|
||||
// there's another alternate def of this rule to process
|
||||
subpos++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
} while (true);
|
||||
break;
|
||||
}
|
||||
case LLAMA_GRETYPE_CHAR:
|
||||
case LLAMA_GRETYPE_CHAR_NOT:
|
||||
case LLAMA_GRETYPE_CHAR_ANY:
|
||||
case LLAMA_GRETYPE_TOKEN:
|
||||
case LLAMA_GRETYPE_TOKEN_NOT:
|
||||
if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) {
|
||||
// only add the stack if it's not a duplicate of one we already have
|
||||
new_stacks.emplace_back(std::move(curr_stack));
|
||||
}
|
||||
break;
|
||||
default:
|
||||
// end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
|
||||
// (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
|
||||
// those
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -995,8 +1004,7 @@ static void llama_grammar_accept_chr(
|
|||
if (!llama_grammar_is_end_of_sequence(match.second)) {
|
||||
new_stack.push_back(match.second);
|
||||
}
|
||||
llama_grammar_stacks seen_stacks;
|
||||
llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks, seen_stacks);
|
||||
llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1072,8 +1080,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
|||
stack_after.push_back(stack_pos_after);
|
||||
}
|
||||
llama_grammar_stacks next_stacks;
|
||||
llama_grammar_stacks seen_stacks;
|
||||
llama_grammar_advance_stack(rules, stack_after, next_stacks, seen_stacks);
|
||||
llama_grammar_advance_stack(rules, stack_after, next_stacks);
|
||||
|
||||
auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
|
||||
for (const auto & tok : next_rejects) {
|
||||
|
|
@ -1124,8 +1131,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
|||
// if alternate is nonempty, add to stack
|
||||
stack.push_back(pos);
|
||||
}
|
||||
llama_grammar_stacks seen_stacks;
|
||||
llama_grammar_advance_stack(vec_rules, stack, stacks, seen_stacks);
|
||||
llama_grammar_advance_stack(vec_rules, stack, stacks);
|
||||
while (!llama_grammar_is_end_of_sequence(pos)) {
|
||||
// scan to end of alternate def
|
||||
pos++;
|
||||
|
|
@ -1218,8 +1224,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
|||
// if alternate is nonempty, add to stack
|
||||
stack.push_back(pos);
|
||||
}
|
||||
llama_grammar_stacks seen_stacks;
|
||||
llama_grammar_advance_stack(vec_rules, stack, stacks, seen_stacks);
|
||||
llama_grammar_advance_stack(vec_rules, stack, stacks);
|
||||
while (!llama_grammar_is_end_of_sequence(pos)) {
|
||||
// scan to end of alternate def
|
||||
pos++;
|
||||
|
|
@ -1438,8 +1443,7 @@ void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token toke
|
|||
if (!llama_grammar_is_end_of_sequence(pos + 1)) {
|
||||
new_stack.push_back(pos + 1);
|
||||
}
|
||||
llama_grammar_stacks seen_stacks;
|
||||
llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new, seen_stacks);
|
||||
llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new);
|
||||
}
|
||||
} else {
|
||||
llama_grammar_stacks current_stacks = {stack};
|
||||
|
|
|
|||
|
|
@ -123,25 +123,27 @@ int main()
|
|||
|
||||
std::vector<std::vector<llama_grammar_element>> expected_stacks = {
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_CHAR, 40},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_CHAR, 97},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
|
|
@ -149,26 +151,24 @@ int main()
|
|||
{LLAMA_GRETYPE_CHAR, 40},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_CHAR, 97},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_CHAR, 40},
|
||||
}};
|
||||
|
||||
auto index = 0;
|
||||
|
|
@ -195,9 +195,9 @@ int main()
|
|||
}
|
||||
|
||||
std::vector<llama_grammar_candidate> next_candidates;
|
||||
next_candidates.resize(24);
|
||||
next_candidates.resize(23);
|
||||
|
||||
for (size_t i = 0; i < 24; ++i)
|
||||
for (size_t i = 0; i < 23; ++i)
|
||||
{
|
||||
uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
|
||||
cp[0] = 37 + i;
|
||||
|
|
@ -210,7 +210,6 @@ int main()
|
|||
{0, 37},
|
||||
{1, 38},
|
||||
{2, 39},
|
||||
{3, 40},
|
||||
{4, 41},
|
||||
{5, 42},
|
||||
{6, 43},
|
||||
|
|
@ -268,6 +267,7 @@ int main()
|
|||
{0, 37},
|
||||
{1, 38},
|
||||
{2, 39},
|
||||
{3, 40},
|
||||
{4, 41},
|
||||
{5, 42},
|
||||
{6, 43},
|
||||
|
|
@ -287,13 +287,11 @@ int main()
|
|||
{20, 57},
|
||||
{21, 58},
|
||||
{22, 59},
|
||||
{23, 60},
|
||||
},
|
||||
{
|
||||
{0, 37},
|
||||
{1, 38},
|
||||
{2, 39},
|
||||
{3, 40},
|
||||
{4, 41},
|
||||
{5, 42},
|
||||
{6, 43},
|
||||
|
|
@ -351,6 +349,7 @@ int main()
|
|||
{0, 37},
|
||||
{1, 38},
|
||||
{2, 39},
|
||||
{3, 40},
|
||||
{4, 41},
|
||||
{5, 42},
|
||||
{6, 43},
|
||||
|
|
@ -370,7 +369,6 @@ int main()
|
|||
{20, 57},
|
||||
{21, 58},
|
||||
{22, 59},
|
||||
{23, 60},
|
||||
},
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue