grammar: convert recursive llama_grammar_advance_stack to iterative
This change converts the function to an iterative approach using
explicit stacks, which prevents deep recursion and eliminates the risk
of stack overflow.
rg-edit regexp: llama_grammar_advance_stack
rg-edit extra-args: -A30
rg-edit directive: """Rewrite: fix the following segfault:
[..]
⚫ Testing segfault. Grammar:
root ::= ( [x]* )*
root ::= ( [x]* )*
Segmentation fault build/bin/test-grammar-integration
convert from recursive to interactive"""
gptel-context:
(("~/llama.cpp/src/llama-grammar.cpp")
("~/llama.cpp/tests/test-grammar-integration.cpp")
("~/llama.cpp/grammars/./list.gbnf")
("~/llama.cpp/grammars/./json_arr.gbnf")
("~/llama.cpp/grammars/./json.gbnf")
("~/llama.cpp/grammars/./japanese.gbnf")
("~/llama.cpp/grammars/./english.gbnf")
("~/llama.cpp/grammars/./chess.gbnf")
("~/llama.cpp/grammars/./c.gbnf")
("~/llama.cpp/grammars/./arithmetic.gbnf")
("~/llama.cpp/grammars/./README.md"))
v2: Added a `std::set` to perform tree-based lookups with O(N log N)
complexity. Testing with a parallel run of `test-grammar-integration`
shows a double-digit percentage increase in runtime. An
`unordered_set` with O(1) hashing was also evaluated, but the overhead
of constructing hash keys from pointers made it significantly slower
than the rbtree implementation that only requires an ordering
operator. The performance regression in the test suite appears
justified by the overall reduction in algorithmic complexity.
Co-developed-by: Piotr Wilkin (ilintar) <piotr.wilkin@syndatis.com>
This commit is contained in:
parent
0e94a6c4e6
commit
fee44e673a
|
|
@ -7,6 +7,7 @@
|
|||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
|
||||
#define MAX_REPETITION_THRESHOLD 2000
|
||||
|
|
@ -830,38 +831,54 @@ static bool llama_grammar_match_token(
|
|||
static void llama_grammar_advance_stack(
|
||||
const llama_grammar_rules & rules,
|
||||
const llama_grammar_stack & stack,
|
||||
llama_grammar_stacks & new_stacks,
|
||||
llama_grammar_stacks & seen_stacks) {
|
||||
if (std::find(seen_stacks.begin(), seen_stacks.end(), stack) != seen_stacks.end()) {
|
||||
return;
|
||||
}
|
||||
seen_stacks.push_back(stack);
|
||||
llama_grammar_stacks & new_stacks) {
|
||||
std::vector<llama_grammar_stack> todo;
|
||||
todo.push_back(stack);
|
||||
|
||||
if (stack.empty()) {
|
||||
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
|
||||
new_stacks.emplace_back(stack);
|
||||
auto stack_cmp = [](const llama_grammar_stack & a, const llama_grammar_stack & b) {
|
||||
return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end(),
|
||||
[](const llama_grammar_element * pa, const llama_grammar_element * pb) {
|
||||
return pa < pb; // Compare pointer addresses
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
std::set<llama_grammar_stack, decltype(stack_cmp)> seen(stack_cmp);
|
||||
|
||||
while (!todo.empty()) {
|
||||
llama_grammar_stack curr_stack = std::move(todo.back());
|
||||
todo.pop_back();
|
||||
|
||||
if (seen.find( curr_stack) != seen.end()) {
|
||||
continue;
|
||||
}
|
||||
return;
|
||||
}
|
||||
seen.insert(curr_stack);
|
||||
|
||||
const llama_grammar_element * pos = stack.back();
|
||||
if (curr_stack.empty()) {
|
||||
if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) {
|
||||
new_stacks.emplace_back(std::move(curr_stack));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
switch (pos->type) {
|
||||
const llama_grammar_element * pos = curr_stack.back();
|
||||
|
||||
switch (pos->type) {
|
||||
case LLAMA_GRETYPE_RULE_REF: {
|
||||
const size_t rule_id = static_cast<size_t>(pos->value);
|
||||
const llama_grammar_element * subpos = rules[rule_id].data();
|
||||
do {
|
||||
// init new stack without the top (pos)
|
||||
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
|
||||
llama_grammar_stack next_stack(curr_stack.begin(), curr_stack.end() - 1);
|
||||
if (!llama_grammar_is_end_of_sequence(pos + 1)) {
|
||||
// if this rule ref is followed by another element, add that to stack
|
||||
new_stack.push_back(pos + 1);
|
||||
next_stack.push_back(pos + 1);
|
||||
}
|
||||
if (!llama_grammar_is_end_of_sequence(subpos)) {
|
||||
// if alternate is nonempty, add to stack
|
||||
new_stack.push_back(subpos);
|
||||
next_stack.push_back(subpos);
|
||||
}
|
||||
llama_grammar_advance_stack(rules, new_stack, new_stacks, seen_stacks);
|
||||
todo.push_back(std::move(next_stack));
|
||||
while (!llama_grammar_is_end_of_sequence(subpos)) {
|
||||
// scan to end of alternate def
|
||||
subpos++;
|
||||
|
|
@ -880,9 +897,9 @@ static void llama_grammar_advance_stack(
|
|||
case LLAMA_GRETYPE_CHAR_ANY:
|
||||
case LLAMA_GRETYPE_TOKEN:
|
||||
case LLAMA_GRETYPE_TOKEN_NOT:
|
||||
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
|
||||
if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) {
|
||||
// only add the stack if it's not a duplicate of one we already have
|
||||
new_stacks.emplace_back(stack);
|
||||
new_stacks.emplace_back(std::move(curr_stack));
|
||||
}
|
||||
break;
|
||||
default:
|
||||
|
|
@ -890,6 +907,7 @@ static void llama_grammar_advance_stack(
|
|||
// (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
|
||||
// those
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -995,8 +1013,7 @@ static void llama_grammar_accept_chr(
|
|||
if (!llama_grammar_is_end_of_sequence(match.second)) {
|
||||
new_stack.push_back(match.second);
|
||||
}
|
||||
llama_grammar_stacks seen_stacks;
|
||||
llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks, seen_stacks);
|
||||
llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1072,8 +1089,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
|||
stack_after.push_back(stack_pos_after);
|
||||
}
|
||||
llama_grammar_stacks next_stacks;
|
||||
llama_grammar_stacks seen_stacks;
|
||||
llama_grammar_advance_stack(rules, stack_after, next_stacks, seen_stacks);
|
||||
llama_grammar_advance_stack(rules, stack_after, next_stacks);
|
||||
|
||||
auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
|
||||
for (const auto & tok : next_rejects) {
|
||||
|
|
@ -1124,8 +1140,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
|||
// if alternate is nonempty, add to stack
|
||||
stack.push_back(pos);
|
||||
}
|
||||
llama_grammar_stacks seen_stacks;
|
||||
llama_grammar_advance_stack(vec_rules, stack, stacks, seen_stacks);
|
||||
llama_grammar_advance_stack(vec_rules, stack, stacks);
|
||||
while (!llama_grammar_is_end_of_sequence(pos)) {
|
||||
// scan to end of alternate def
|
||||
pos++;
|
||||
|
|
@ -1218,8 +1233,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
|||
// if alternate is nonempty, add to stack
|
||||
stack.push_back(pos);
|
||||
}
|
||||
llama_grammar_stacks seen_stacks;
|
||||
llama_grammar_advance_stack(vec_rules, stack, stacks, seen_stacks);
|
||||
llama_grammar_advance_stack(vec_rules, stack, stacks);
|
||||
while (!llama_grammar_is_end_of_sequence(pos)) {
|
||||
// scan to end of alternate def
|
||||
pos++;
|
||||
|
|
@ -1438,8 +1452,7 @@ void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token toke
|
|||
if (!llama_grammar_is_end_of_sequence(pos + 1)) {
|
||||
new_stack.push_back(pos + 1);
|
||||
}
|
||||
llama_grammar_stacks seen_stacks;
|
||||
llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new, seen_stacks);
|
||||
llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new);
|
||||
}
|
||||
} else {
|
||||
llama_grammar_stacks current_stacks = {stack};
|
||||
|
|
|
|||
|
|
@ -123,25 +123,27 @@ int main()
|
|||
|
||||
std::vector<std::vector<llama_grammar_element>> expected_stacks = {
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_CHAR, 40},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_CHAR, 97},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
|
|
@ -149,26 +151,24 @@ int main()
|
|||
{LLAMA_GRETYPE_CHAR, 40},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_CHAR, 97},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_CHAR, 40},
|
||||
}};
|
||||
|
||||
auto index = 0;
|
||||
|
|
@ -195,9 +195,9 @@ int main()
|
|||
}
|
||||
|
||||
std::vector<llama_grammar_candidate> next_candidates;
|
||||
next_candidates.resize(24);
|
||||
next_candidates.resize(23);
|
||||
|
||||
for (size_t i = 0; i < 24; ++i)
|
||||
for (size_t i = 0; i < 23; ++i)
|
||||
{
|
||||
uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
|
||||
cp[0] = 37 + i;
|
||||
|
|
@ -210,7 +210,6 @@ int main()
|
|||
{0, 37},
|
||||
{1, 38},
|
||||
{2, 39},
|
||||
{3, 40},
|
||||
{4, 41},
|
||||
{5, 42},
|
||||
{6, 43},
|
||||
|
|
@ -268,6 +267,7 @@ int main()
|
|||
{0, 37},
|
||||
{1, 38},
|
||||
{2, 39},
|
||||
{3, 40},
|
||||
{4, 41},
|
||||
{5, 42},
|
||||
{6, 43},
|
||||
|
|
@ -287,13 +287,11 @@ int main()
|
|||
{20, 57},
|
||||
{21, 58},
|
||||
{22, 59},
|
||||
{23, 60},
|
||||
},
|
||||
{
|
||||
{0, 37},
|
||||
{1, 38},
|
||||
{2, 39},
|
||||
{3, 40},
|
||||
{4, 41},
|
||||
{5, 42},
|
||||
{6, 43},
|
||||
|
|
@ -351,6 +349,7 @@ int main()
|
|||
{0, 37},
|
||||
{1, 38},
|
||||
{2, 39},
|
||||
{3, 40},
|
||||
{4, 41},
|
||||
{5, 42},
|
||||
{6, 43},
|
||||
|
|
@ -370,7 +369,6 @@ int main()
|
|||
{20, 57},
|
||||
{21, 58},
|
||||
{22, 59},
|
||||
{23, 60},
|
||||
},
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue