more readable
This commit is contained in:
parent
3b195d301a
commit
c0b9903a1a
|
|
@ -349,9 +349,8 @@ const char * llama_grammar_parser::parse_sequence(
|
||||||
|
|
||||||
// use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
|
// use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
|
||||||
// (though it's technically the same as -1 now)
|
// (though it's technically the same as -1 now)
|
||||||
// ref: https://github.com/ggml-org/llama.cpp/pull/17381
|
|
||||||
auto handle_repetitions = [&](uint64_t min_times, uint64_t max_times) {
|
auto handle_repetitions = [&](uint64_t min_times, uint64_t max_times) {
|
||||||
|
bool no_max = max_times == UINT64_MAX;
|
||||||
if (last_sym_start == rule.size()) {
|
if (last_sym_start == rule.size()) {
|
||||||
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
|
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
|
||||||
}
|
}
|
||||||
|
|
@ -384,14 +383,14 @@ const char * llama_grammar_parser::parse_sequence(
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t last_rec_rule_id = 0;
|
uint32_t last_rec_rule_id = 0;
|
||||||
auto n_opt = max_times == UINT64_MAX ? 1 : max_times - min_times;
|
auto n_opt = no_max ? 1 : max_times - min_times;
|
||||||
|
|
||||||
llama_grammar_rule rec_rule(prev_rule);
|
llama_grammar_rule rec_rule(prev_rule);
|
||||||
for (uint64_t i = 0; i < n_opt; i++) {
|
for (uint64_t i = 0; i < n_opt; i++) {
|
||||||
rec_rule.resize(prev_rule.size());
|
rec_rule.resize(prev_rule.size());
|
||||||
uint32_t rec_rule_id = generate_symbol_id( rule_name);
|
uint32_t rec_rule_id = generate_symbol_id( rule_name);
|
||||||
if (i > 0 || max_times == UINT64_MAX) {
|
if (i > 0 || no_max) {
|
||||||
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times == UINT64_MAX ? rec_rule_id : last_rec_rule_id});
|
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, no_max ? rec_rule_id : last_rec_rule_id});
|
||||||
}
|
}
|
||||||
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
||||||
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
|
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
|
||||||
|
|
@ -486,7 +485,7 @@ const char * llama_grammar_parser::parse_sequence(
|
||||||
uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
|
uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
|
||||||
pos = parse_space(int_end, is_nested);
|
pos = parse_space(int_end, is_nested);
|
||||||
|
|
||||||
uint64_t max_times = UINT64_MAX;
|
uint64_t max_times = UINT64_MAX; // default: no max limit
|
||||||
|
|
||||||
if (*pos == '}') {
|
if (*pos == '}') {
|
||||||
max_times = min_times;
|
max_times = min_times;
|
||||||
|
|
@ -507,7 +506,8 @@ const char * llama_grammar_parser::parse_sequence(
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error(std::string("expecting ',' at ") + pos);
|
throw std::runtime_error(std::string("expecting ',' at ") + pos);
|
||||||
}
|
}
|
||||||
if (min_times > MAX_REPETITION_THRESHOLD || max_times > MAX_REPETITION_THRESHOLD) {
|
bool has_max = max_times != UINT64_MAX;
|
||||||
|
if (min_times > MAX_REPETITION_THRESHOLD || (has_max && max_times > MAX_REPETITION_THRESHOLD)) {
|
||||||
throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
|
throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
|
||||||
}
|
}
|
||||||
handle_repetitions(min_times, max_times);
|
handle_repetitions(min_times, max_times);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue