Merge 43d9e59d0a into 06bf3796f4
This commit is contained in:
commit
de36a357a8
|
|
@ -454,6 +454,7 @@ const char * llama_grammar_parser::parse_sequence(
|
|||
bool is_nested) {
|
||||
size_t last_sym_start = rule.size();
|
||||
const char * pos = src;
|
||||
uint64_t n_prev_rules = 1;
|
||||
|
||||
// use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
|
||||
// (though it's technically the same as -1 now)
|
||||
|
|
@ -481,6 +482,18 @@ const char * llama_grammar_parser::parse_sequence(
|
|||
// S' ::= S |
|
||||
|
||||
llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
|
||||
// Calculate the total number of rules that will be generated by this repetition
|
||||
uint64_t total_rules = 1; // Start with 1 for the original rule
|
||||
if (!no_max && max_times > 0) {
|
||||
total_rules = max_times;
|
||||
} else if (min_times > 0) {
|
||||
total_rules = min_times;
|
||||
}
|
||||
|
||||
if (n_prev_rules * total_rules >= MAX_REPETITION_THRESHOLD) {
|
||||
throw std::runtime_error("number of rules that are going to be repeated multiplied by the new repetition exceeds sane defaults, please reduce the number of repetitions or rule complexity");
|
||||
}
|
||||
|
||||
if (min_times == 0) {
|
||||
rule.resize(last_sym_start);
|
||||
} else {
|
||||
|
|
@ -508,12 +521,15 @@ const char * llama_grammar_parser::parse_sequence(
|
|||
if (n_opt > 0) {
|
||||
rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
|
||||
}
|
||||
n_prev_rules *= total_rules;
|
||||
GGML_ASSERT(n_prev_rules >= 1);
|
||||
};
|
||||
|
||||
while (*pos) {
|
||||
if (*pos == '"') { // literal string
|
||||
pos++;
|
||||
last_sym_start = rule.size();
|
||||
n_prev_rules = 1;
|
||||
while (*pos != '"') {
|
||||
if (!*pos) {
|
||||
throw std::runtime_error("unexpected end of input");
|
||||
|
|
@ -531,6 +547,7 @@ const char * llama_grammar_parser::parse_sequence(
|
|||
start_type = LLAMA_GRETYPE_CHAR_NOT;
|
||||
}
|
||||
last_sym_start = rule.size();
|
||||
n_prev_rules = 1;
|
||||
while (*pos != ']') {
|
||||
if (!*pos) {
|
||||
throw std::runtime_error("unexpected end of input");
|
||||
|
|
@ -561,6 +578,7 @@ const char * llama_grammar_parser::parse_sequence(
|
|||
auto token_pair = parse_token(vocab, pos);
|
||||
const char * token_end = token_pair.second;
|
||||
last_sym_start = rule.size();
|
||||
n_prev_rules = 1;
|
||||
rule.push_back({type, token_pair.first});
|
||||
pos = parse_space(token_end, is_nested);
|
||||
} else if (is_word_char(*pos)) { // rule reference
|
||||
|
|
@ -568,12 +586,15 @@ const char * llama_grammar_parser::parse_sequence(
|
|||
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
|
||||
pos = parse_space(name_end, is_nested);
|
||||
last_sym_start = rule.size();
|
||||
n_prev_rules = 1;
|
||||
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
|
||||
} else if (*pos == '(') { // grouping
|
||||
// parse nested alternates into synthesized rule
|
||||
pos = parse_space(pos + 1, true);
|
||||
uint32_t n_rules_before = symbol_ids.size();
|
||||
uint32_t sub_rule_id = generate_symbol_id(rule_name);
|
||||
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
|
||||
n_prev_rules = std::max(1u, (uint32_t)symbol_ids.size() - n_rules_before);
|
||||
last_sym_start = rule.size();
|
||||
// output reference to synthesized rule
|
||||
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
|
||||
|
|
@ -583,6 +604,7 @@ const char * llama_grammar_parser::parse_sequence(
|
|||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (*pos == '.') { // any char
|
||||
last_sym_start = rule.size();
|
||||
n_prev_rules = 1;
|
||||
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (*pos == '*') {
|
||||
|
|
@ -831,59 +853,74 @@ static void llama_grammar_advance_stack(
|
|||
const llama_grammar_rules & rules,
|
||||
const llama_grammar_stack & stack,
|
||||
llama_grammar_stacks & new_stacks) {
|
||||
if (stack.empty()) {
|
||||
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
|
||||
new_stacks.emplace_back(stack);
|
||||
}
|
||||
return;
|
||||
}
|
||||
std::vector<llama_grammar_stack> todo;
|
||||
todo.push_back(stack);
|
||||
|
||||
const llama_grammar_element * pos = stack.back();
|
||||
std::vector<llama_grammar_stack> seen;
|
||||
|
||||
switch (pos->type) {
|
||||
case LLAMA_GRETYPE_RULE_REF: {
|
||||
const size_t rule_id = static_cast<size_t>(pos->value);
|
||||
const llama_grammar_element * subpos = rules[rule_id].data();
|
||||
do {
|
||||
// init new stack without the top (pos)
|
||||
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
|
||||
if (!llama_grammar_is_end_of_sequence(pos + 1)) {
|
||||
// if this rule ref is followed by another element, add that to stack
|
||||
new_stack.push_back(pos + 1);
|
||||
}
|
||||
if (!llama_grammar_is_end_of_sequence(subpos)) {
|
||||
// if alternate is nonempty, add to stack
|
||||
new_stack.push_back(subpos);
|
||||
}
|
||||
llama_grammar_advance_stack(rules, new_stack, new_stacks);
|
||||
while (!llama_grammar_is_end_of_sequence(subpos)) {
|
||||
// scan to end of alternate def
|
||||
subpos++;
|
||||
}
|
||||
if (subpos->type == LLAMA_GRETYPE_ALT) {
|
||||
// there's another alternate def of this rule to process
|
||||
subpos++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
} while (true);
|
||||
break;
|
||||
while (!todo.empty()) {
|
||||
llama_grammar_stack curr_stack = std::move(todo.back());
|
||||
todo.pop_back();
|
||||
|
||||
if (std::find(seen.begin(), seen.end(), curr_stack) != seen.end()) {
|
||||
continue;
|
||||
}
|
||||
case LLAMA_GRETYPE_CHAR:
|
||||
case LLAMA_GRETYPE_CHAR_NOT:
|
||||
case LLAMA_GRETYPE_CHAR_ANY:
|
||||
case LLAMA_GRETYPE_TOKEN:
|
||||
case LLAMA_GRETYPE_TOKEN_NOT:
|
||||
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
|
||||
// only add the stack if it's not a duplicate of one we already have
|
||||
new_stacks.emplace_back(stack);
|
||||
seen.push_back(curr_stack);
|
||||
|
||||
if (curr_stack.empty()) {
|
||||
if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) {
|
||||
new_stacks.emplace_back(std::move(curr_stack));
|
||||
}
|
||||
break;
|
||||
default:
|
||||
// end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
|
||||
// (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
|
||||
// those
|
||||
GGML_ABORT("fatal error");
|
||||
continue;
|
||||
}
|
||||
|
||||
const llama_grammar_element * pos = curr_stack.back();
|
||||
|
||||
switch (pos->type) {
|
||||
case LLAMA_GRETYPE_RULE_REF: {
|
||||
const size_t rule_id = static_cast<size_t>(pos->value);
|
||||
const llama_grammar_element * subpos = rules[rule_id].data();
|
||||
do {
|
||||
// init new stack without the top (pos)
|
||||
llama_grammar_stack next_stack(curr_stack.begin(), curr_stack.end() - 1);
|
||||
if (!llama_grammar_is_end_of_sequence(pos + 1)) {
|
||||
// if this rule ref is followed by another element, add that to stack
|
||||
next_stack.push_back(pos + 1);
|
||||
}
|
||||
if (!llama_grammar_is_end_of_sequence(subpos)) {
|
||||
// if alternate is nonempty, add to stack
|
||||
next_stack.push_back(subpos);
|
||||
}
|
||||
todo.push_back(std::move(next_stack));
|
||||
while (!llama_grammar_is_end_of_sequence(subpos)) {
|
||||
// scan to end of alternate def
|
||||
subpos++;
|
||||
}
|
||||
if (subpos->type == LLAMA_GRETYPE_ALT) {
|
||||
// there's another alternate def of this rule to process
|
||||
subpos++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
} while (true);
|
||||
break;
|
||||
}
|
||||
case LLAMA_GRETYPE_CHAR:
|
||||
case LLAMA_GRETYPE_CHAR_NOT:
|
||||
case LLAMA_GRETYPE_CHAR_ANY:
|
||||
case LLAMA_GRETYPE_TOKEN:
|
||||
case LLAMA_GRETYPE_TOKEN_NOT:
|
||||
if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) {
|
||||
// only add the stack if it's not a duplicate of one we already have
|
||||
new_stacks.emplace_back(std::move(curr_stack));
|
||||
}
|
||||
break;
|
||||
default:
|
||||
// end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
|
||||
// (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
|
||||
// those
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -784,6 +784,24 @@ static void test_quantifiers() {
|
|||
"0xFF 0x12 0xAB 0x00 0x00 0x00",
|
||||
}
|
||||
);
|
||||
test_grammar(
|
||||
"segfault",
|
||||
// Grammar
|
||||
R"""(
|
||||
root ::= ( [x]* )*
|
||||
)""",
|
||||
// Passing strings
|
||||
{
|
||||
"",
|
||||
"x",
|
||||
"xx"
|
||||
},
|
||||
// Failing strings
|
||||
{
|
||||
"y",
|
||||
"yy"
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
static void test_failure_missing_root() {
|
||||
|
|
|
|||
|
|
@ -145,6 +145,10 @@ int main()
|
|||
root ::= "a"{,}"
|
||||
)""");
|
||||
|
||||
verify_failure(R"""(
|
||||
root ::= (((((([^x]*){0,99}){0,99}){0,99}){0,99}){0,99}){0,99}
|
||||
)""");
|
||||
|
||||
verify_failure(R"""(
|
||||
root ::= "a"{,10}"
|
||||
)""");
|
||||
|
|
|
|||
|
|
@ -123,25 +123,27 @@ int main()
|
|||
|
||||
std::vector<std::vector<llama_grammar_element>> expected_stacks = {
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_CHAR, 40},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_CHAR, 97},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
|
|
@ -149,26 +151,24 @@ int main()
|
|||
{LLAMA_GRETYPE_CHAR, 40},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_RULE_REF, 5},
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_CHAR, 97},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_RULE_REF, 3},
|
||||
{LLAMA_GRETYPE_CHAR, 48},
|
||||
},
|
||||
{
|
||||
{LLAMA_GRETYPE_CHAR, 61},
|
||||
{LLAMA_GRETYPE_RULE_REF, 7},
|
||||
{LLAMA_GRETYPE_CHAR, 40},
|
||||
}};
|
||||
|
||||
auto index = 0;
|
||||
|
|
@ -195,9 +195,9 @@ int main()
|
|||
}
|
||||
|
||||
std::vector<llama_grammar_candidate> next_candidates;
|
||||
next_candidates.resize(24);
|
||||
next_candidates.resize(23);
|
||||
|
||||
for (size_t i = 0; i < 24; ++i)
|
||||
for (size_t i = 0; i < 23; ++i)
|
||||
{
|
||||
uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
|
||||
cp[0] = 37 + i;
|
||||
|
|
@ -210,7 +210,6 @@ int main()
|
|||
{0, 37},
|
||||
{1, 38},
|
||||
{2, 39},
|
||||
{3, 40},
|
||||
{4, 41},
|
||||
{5, 42},
|
||||
{6, 43},
|
||||
|
|
@ -268,6 +267,7 @@ int main()
|
|||
{0, 37},
|
||||
{1, 38},
|
||||
{2, 39},
|
||||
{3, 40},
|
||||
{4, 41},
|
||||
{5, 42},
|
||||
{6, 43},
|
||||
|
|
@ -287,13 +287,11 @@ int main()
|
|||
{20, 57},
|
||||
{21, 58},
|
||||
{22, 59},
|
||||
{23, 60},
|
||||
},
|
||||
{
|
||||
{0, 37},
|
||||
{1, 38},
|
||||
{2, 39},
|
||||
{3, 40},
|
||||
{4, 41},
|
||||
{5, 42},
|
||||
{6, 43},
|
||||
|
|
@ -351,6 +349,7 @@ int main()
|
|||
{0, 37},
|
||||
{1, 38},
|
||||
{2, 39},
|
||||
{3, 40},
|
||||
{4, 41},
|
||||
{5, 42},
|
||||
{6, 43},
|
||||
|
|
@ -370,7 +369,6 @@ int main()
|
|||
{20, 57},
|
||||
{21, 58},
|
||||
{22, 59},
|
||||
{23, 60},
|
||||
},
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue