grammar: fix regression caused by #17381

This commit is contained in:
Xuan Son Nguyen 2025-11-20 17:31:55 +01:00
parent 4c91f2633f
commit 3b195d301a
1 changed file with 8 additions and 7 deletions

View File

@ -347,9 +347,10 @@ const char * llama_grammar_parser::parse_sequence(
size_t last_sym_start = rule.size(); size_t last_sym_start = rule.size();
const char * pos = src; const char * pos = src;
// use UINT64_MAX as the empty value because we aligned to the proper unsigned long type so -1 can't be used // use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
// (though it's technically the same as -1 now) // (though it's technically the same as -1 now)
auto handle_repetitions = [&](unsigned long min_times, unsigned long max_times) { // ref: https://github.com/ggml-org/llama.cpp/pull/17381
auto handle_repetitions = [&](uint64_t min_times, uint64_t max_times) {
if (last_sym_start == rule.size()) { if (last_sym_start == rule.size()) {
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
@ -377,7 +378,7 @@ const char * llama_grammar_parser::parse_sequence(
rule.resize(last_sym_start); rule.resize(last_sym_start);
} else { } else {
// Repeat the previous elements (min_times - 1) times // Repeat the previous elements (min_times - 1) times
for (unsigned long i = 1; i < min_times; i++) { for (uint64_t i = 1; i < min_times; i++) {
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end()); rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
} }
} }
@ -386,7 +387,7 @@ const char * llama_grammar_parser::parse_sequence(
auto n_opt = max_times == UINT64_MAX ? 1 : max_times - min_times; auto n_opt = max_times == UINT64_MAX ? 1 : max_times - min_times;
llama_grammar_rule rec_rule(prev_rule); llama_grammar_rule rec_rule(prev_rule);
for (unsigned long i = 0; i < n_opt; i++) { for (uint64_t i = 0; i < n_opt; i++) {
rec_rule.resize(prev_rule.size()); rec_rule.resize(prev_rule.size());
uint32_t rec_rule_id = generate_symbol_id( rule_name); uint32_t rec_rule_id = generate_symbol_id( rule_name);
if (i > 0 || max_times == UINT64_MAX) { if (i > 0 || max_times == UINT64_MAX) {
@ -482,10 +483,10 @@ const char * llama_grammar_parser::parse_sequence(
throw std::runtime_error(std::string("expecting an int at ") + pos); throw std::runtime_error(std::string("expecting an int at ") + pos);
} }
const char * int_end = parse_int(pos); const char * int_end = parse_int(pos);
unsigned long min_times = std::stoul(std::string(pos, int_end - pos)); uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested); pos = parse_space(int_end, is_nested);
unsigned long max_times = UINT64_MAX; uint64_t max_times = UINT64_MAX;
if (*pos == '}') { if (*pos == '}') {
max_times = min_times; max_times = min_times;
@ -506,7 +507,7 @@ const char * llama_grammar_parser::parse_sequence(
} else { } else {
throw std::runtime_error(std::string("expecting ',' at ") + pos); throw std::runtime_error(std::string("expecting ',' at ") + pos);
} }
if (min_times > MAX_REPETITION_THRESHOLD || (max_times != UINT64_MAX && max_times > MAX_REPETITION_THRESHOLD)) { if (min_times > MAX_REPETITION_THRESHOLD || max_times > MAX_REPETITION_THRESHOLD) {
throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions")); throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
} }
handle_repetitions(min_times, max_times); handle_repetitions(min_times, max_times);