Merge 74e6235f5e into b83111815e
This commit is contained in:
commit
d55eb2ac51
|
|
@ -9,7 +9,10 @@
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
#define MAX_REPETITION_THRESHOLD 2000
|
#define MAX_REPETITION_THRESHOLD 2000
|
||||||
|
#define MAX_GRAMMAR_RECURSION_DEPTH 2000
|
||||||
//
|
//
|
||||||
// helpers
|
// helpers
|
||||||
//
|
//
|
||||||
|
|
@ -825,12 +828,32 @@ static bool llama_grammar_match_token(
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Hash function for llama_grammar_stack to enable std::unordered_set usage
|
||||||
|
struct llama_grammar_stack_hash {
|
||||||
|
std::size_t operator()(const llama_grammar_stack & stack) const {
|
||||||
|
std::size_t hash = 0;
|
||||||
|
for (const auto * elem : stack) {
|
||||||
|
hash ^= std::hash<const void*>{}(elem) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
|
||||||
|
}
|
||||||
|
return hash;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// transforms a grammar pushdown stack into N possible stacks, all ending
|
// transforms a grammar pushdown stack into N possible stacks, all ending
|
||||||
// at a character range (terminal element)
|
// at a character range (terminal element)
|
||||||
static void llama_grammar_advance_stack(
|
static void llama_grammar_advance_stack(
|
||||||
const llama_grammar_rules & rules,
|
const llama_grammar_rules & rules,
|
||||||
const llama_grammar_stack & stack,
|
const llama_grammar_stack & stack,
|
||||||
llama_grammar_stacks & new_stacks) {
|
llama_grammar_stacks & new_stacks,
|
||||||
|
std::unordered_set<llama_grammar_stack, llama_grammar_stack_hash> & history) {
|
||||||
|
if (history.size() >= MAX_GRAMMAR_RECURSION_DEPTH) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (history.count(stack) > 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (stack.empty()) {
|
if (stack.empty()) {
|
||||||
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
|
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
|
||||||
new_stacks.emplace_back(stack);
|
new_stacks.emplace_back(stack);
|
||||||
|
|
@ -855,7 +878,9 @@ static void llama_grammar_advance_stack(
|
||||||
// if alternate is nonempty, add to stack
|
// if alternate is nonempty, add to stack
|
||||||
new_stack.push_back(subpos);
|
new_stack.push_back(subpos);
|
||||||
}
|
}
|
||||||
llama_grammar_advance_stack(rules, new_stack, new_stacks);
|
history.insert(stack);
|
||||||
|
llama_grammar_advance_stack(rules, new_stack, new_stacks, history);
|
||||||
|
history.erase(stack);
|
||||||
while (!llama_grammar_is_end_of_sequence(subpos)) {
|
while (!llama_grammar_is_end_of_sequence(subpos)) {
|
||||||
// scan to end of alternate def
|
// scan to end of alternate def
|
||||||
subpos++;
|
subpos++;
|
||||||
|
|
@ -887,6 +912,14 @@ static void llama_grammar_advance_stack(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void llama_grammar_advance_stack(
|
||||||
|
const llama_grammar_rules & rules,
|
||||||
|
const llama_grammar_stack & stack,
|
||||||
|
llama_grammar_stacks & new_stacks) {
|
||||||
|
std::unordered_set<llama_grammar_stack, llama_grammar_stack_hash> history;
|
||||||
|
llama_grammar_advance_stack(rules, stack, new_stacks, history);
|
||||||
|
}
|
||||||
|
|
||||||
static llama_grammar_candidates llama_grammar_reject_candidates(
|
static llama_grammar_candidates llama_grammar_reject_candidates(
|
||||||
const llama_grammar_rules & rules,
|
const llama_grammar_rules & rules,
|
||||||
const llama_grammar_stacks & stacks,
|
const llama_grammar_stacks & stacks,
|
||||||
|
|
|
||||||
|
|
@ -784,6 +784,23 @@ static void test_quantifiers() {
|
||||||
"0xFF 0x12 0xAB 0x00 0x00 0x00",
|
"0xFF 0x12 0xAB 0x00 0x00 0x00",
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
test_grammar(
|
||||||
|
"nested repetition",
|
||||||
|
// Grammar
|
||||||
|
R"""(root ::= ("a"* )*)""",
|
||||||
|
// Passing strings
|
||||||
|
{
|
||||||
|
"",
|
||||||
|
"a",
|
||||||
|
"aa",
|
||||||
|
"aaa",
|
||||||
|
},
|
||||||
|
// Failing strings
|
||||||
|
{
|
||||||
|
"b",
|
||||||
|
"ab",
|
||||||
|
}
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void test_failure_missing_root() {
|
static void test_failure_missing_root() {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue