commit ceaf729db0
Author: Elias Oenal (committed by GitHub)
Date:   2026-02-16 23:55:39 +02:00
2 changed files with 122 additions and 66 deletions


@@ -968,7 +968,7 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
 }
 
 static void llama_grammar_accept_chr(
-        struct llama_grammar & grammar,
+        const llama_grammar_rules & rules,
         const llama_grammar_stack & stack,
         uint32_t chr,
         llama_grammar_stacks & new_stacks) {
@@ -989,7 +989,7 @@ static void llama_grammar_accept_chr(
         if (!llama_grammar_is_end_of_sequence(match.second)) {
             new_stack.push_back(match.second);
         }
-        llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks);
+        llama_grammar_advance_stack(rules, new_stack, new_stacks);
     }
 }
 
@@ -998,7 +998,7 @@ void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
     stacks_new.reserve(grammar->stacks.size());
 
     for (const auto & stack : grammar->stacks) {
-        llama_grammar_accept_chr(*grammar, stack, chr, stacks_new);
+        llama_grammar_accept_chr(grammar->rules, stack, chr, stacks_new);
     }
 
     grammar->stacks = std::move(stacks_new);
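Note: the three hunks above switch llama_grammar_accept_chr (and its caller) from taking the whole llama_grammar to taking just the rule table plus a single stack. That makes the per-character accept step usable on copied stacks without a mutable grammar object, which the new trial-accept code below relies on. A minimal sketch of that usage, reusing the names from the diff (illustrative only, not part of the patch):

// advance a *copy* of the current stacks by one code point, leaving grammar->stacks untouched
llama_grammar_stacks trial = grammar->stacks;
llama_grammar_stacks next;
for (const auto & stack : trial) {
    llama_grammar_accept_chr(grammar->rules, stack, /*chr=*/'a', next);
}
// `next` now holds only the stacks that survive the character 'a'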
@@ -1333,6 +1333,77 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
     }
 }
 
+// compute new stacks after accepting a token, without mutating the grammar
+// returns true if the token is valid (non-empty stacks)
+static bool llama_grammar_try_accept_token(
+        const llama_grammar_rules & rules,
+        llama_token token,
+        const std::string & piece,
+        const llama_grammar_stacks & stacks_cur,
+        const llama_partial_utf8 & partial_utf8_cur,
+        llama_grammar_stacks & new_stacks,
+        llama_partial_utf8 & new_partial_utf8) {
+    const auto decoded = decode_utf8(piece, partial_utf8_cur);
+    const auto & code_points = decoded.first;
+
+    new_stacks.clear();
+    new_stacks.reserve(stacks_cur.size());
+
+    for (const auto & stack : stacks_cur) {
+        if (stack.empty()) {
+            continue;
+        }
+
+        const llama_grammar_element * pos = stack.back();
+
+        if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
+            if (llama_grammar_match_token(pos, token)) {
+                llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
+                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
+                    new_stack.push_back(pos + 1);
+                }
+                llama_grammar_advance_stack(rules, new_stack, new_stacks);
+            }
+        } else {
+            llama_grammar_stacks current_stacks = {stack};
+
+            for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+                llama_grammar_stacks next_stacks;
+
+                for (const auto & cur_stack : current_stacks) {
+                    llama_grammar_accept_chr(rules, cur_stack, *it, next_stacks);
+                }
+
+                current_stacks = std::move(next_stacks);
+
+                if (current_stacks.empty()) {
+                    break;
+                }
+            }
+
+            for (auto & surviving_stack : current_stacks) {
+                if (std::find(new_stacks.begin(), new_stacks.end(), surviving_stack) == new_stacks.end()) {
+                    new_stacks.emplace_back(surviving_stack);
+                }
+            }
+        }
+    }
+
+    new_partial_utf8 = decoded.second;
+
+    return !new_stacks.empty();
+}
+
+void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) {
+    llama_grammar_stacks stacks_new;
+    llama_partial_utf8 partial_utf8_new;
+
+    if (!llama_grammar_try_accept_token(grammar.rules, token, piece, grammar.stacks, grammar.partial_utf8, stacks_new, partial_utf8_new)) {
+        throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")");
+    }
+
+    grammar.stacks = std::move(stacks_new);
+    grammar.partial_utf8 = partial_utf8_new;
+}
+
 void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
     GGML_ASSERT(grammar.vocab != nullptr);
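The helper above computes the successor stacks for an entire token piece without touching the grammar, so callers can probe a token and commit only on success. A sketch of that dry-run pattern with the same names as the diff (it mirrors the new llama_grammar_accept_token; it is not additional patch code):

llama_grammar_stacks stacks_next;
llama_partial_utf8 utf8_next;

// does `token` / `piece` fit the grammar from the current state?
const bool ok = llama_grammar_try_accept_token(
        grammar.rules, token, piece,
        grammar.stacks, grammar.partial_utf8,
        stacks_next, utf8_next);

if (ok) {
    // commit only if the token was accepted
    grammar.stacks = std::move(stacks_next);
    grammar.partial_utf8 = utf8_next;
}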
@@ -1343,7 +1414,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
             grammar.awaiting_trigger = false;
             grammar.trigger_buffer.clear();
             llama_grammar_accept_token(grammar, token, piece);
-            LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
+            LLAMA_LOG_DEBUG("%s: Grammar triggered on token %u (`%s`)\n", __func__, token, piece.c_str());
             return;
         } else {
             auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size());
@@ -1353,7 +1424,12 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
             for (const auto & trigger_pattern : grammar.trigger_patterns) {
                 auto start = trigger_pattern.find(grammar.trigger_buffer);
                 if (start != std::string::npos) {
-                    grammar.awaiting_trigger = false;
+                    // trial-accept replay tokens on copies of the grammar state;
+                    // the completing token may contain text beyond the trigger
+                    // that violates the grammar
+                    auto trial_stacks = grammar.stacks;
+                    auto trial_partial_utf8 = grammar.partial_utf8;
+                    bool replay_ok = true;
 
                     // replay tokens that overlap with [start, end)
                     for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
@@ -1365,13 +1441,31 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
                         size_t piece_start = (tok_start < start) ? start : tok_start; // allow for partial token pieces
                         size_t piece_len = tok_end - piece_start;
                         auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len);
-                        llama_grammar_accept_token(grammar, tok, tok_piece);
+
+                        llama_grammar_stacks next_stacks;
+                        llama_partial_utf8 next_partial;
+                        if (!llama_grammar_try_accept_token(grammar.rules, tok, tok_piece, trial_stacks, trial_partial_utf8, next_stacks, next_partial)) {
+                            LLAMA_LOG_WARN("%s: trigger replay failed on token %d (`%s`), treating as false trigger\n",
+                                    __func__, tok, tok_piece.c_str());
+                            replay_ok = false;
+                            break;
+                        }
+                        trial_stacks = std::move(next_stacks);
+                        trial_partial_utf8 = next_partial;
                     }
 
-                    auto constrained_str = grammar.trigger_buffer.substr(start);
+                    if (!replay_ok) {
+                        continue; // try next pattern
+                    }
+
+                    // replay succeeded, commit the new state
+                    grammar.stacks = std::move(trial_stacks);
+                    grammar.partial_utf8 = trial_partial_utf8;
+                    grammar.awaiting_trigger = false;
+
+                    LLAMA_LOG_DEBUG("%s: Grammar triggered on regex: '%s'\n", __func__, grammar.trigger_buffer.substr(start).c_str());
+
                     grammar.trigger_buffer.clear();
                     grammar.trigger_buffer_positions.clear();
-                    LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
+
                     return;
                 }
             }
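In short, the replay logic above re-feeds the trigger buffer token by token against copies of the grammar state and replaces the real state only once every overlapping token has been accepted; a failed replay is treated as a false trigger and the next pattern is tried. A condensed sketch of the pattern (names taken from the diff; the iteration over trigger_buffer_positions and the piece slicing are simplified into a hypothetical replay_pieces list):

auto trial_stacks = grammar.stacks;          // work on copies of the live state
auto trial_partial_utf8 = grammar.partial_utf8;
bool replay_ok = true;

for (const auto & [tok, tok_piece] : replay_pieces) { // hypothetical pre-sliced (token, piece) pairs
    llama_grammar_stacks next_stacks;
    llama_partial_utf8 next_partial;
    if (!llama_grammar_try_accept_token(grammar.rules, tok, tok_piece,
                                        trial_stacks, trial_partial_utf8,
                                        next_stacks, next_partial)) {
        replay_ok = false; // false trigger: keep waiting for a real one
        break;
    }
    trial_stacks = std::move(next_stacks);
    trial_partial_utf8 = next_partial;
}

if (replay_ok) {
    grammar.stacks = std::move(trial_stacks); // commit only on full success
    grammar.partial_utf8 = trial_partial_utf8;
    grammar.awaiting_trigger = false;
}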
@@ -1406,59 +1500,3 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
         throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
     }
 }
-
-void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) {
-    // Note terminating 0 in decoded string
-    const auto decoded = decode_utf8(piece, grammar.partial_utf8);
-    const auto & code_points = decoded.first;
-
-    llama_grammar_stacks stacks_new;
-    stacks_new.reserve(grammar.stacks.size());
-
-    for (const auto & stack : grammar.stacks) {
-        if (stack.empty()) {
-            continue;
-        }
-
-        const llama_grammar_element * pos = stack.back();
-
-        if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
-            if (llama_grammar_match_token(pos, token)) {
-                llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
-                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
-                    new_stack.push_back(pos + 1);
-                }
-                llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new);
-            }
-        } else {
-            llama_grammar_stacks current_stacks = {stack};
-
-            for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-                llama_grammar_stacks next_stacks;
-
-                for (const auto & cur_stack : current_stacks) {
-                    llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks);
-                }
-
-                current_stacks = std::move(next_stacks);
-
-                if (current_stacks.empty()) {
-                    break;
-                }
-            }
-
-            for (auto & surviving_stack : current_stacks) {
-                if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) {
-                    stacks_new.emplace_back(surviving_stack);
-                }
-            }
-        }
-    }
-
-    grammar.stacks = std::move(stacks_new);
-    grammar.partial_utf8 = decoded.second;
-
-    if (grammar.stacks.empty()) {
-        throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")");
-    }
-}


@@ -2747,7 +2747,13 @@ private:
             slot.i_batch = -1;
 
-            common_sampler_accept(slot.smpl.get(), id, true);
+            try {
+                common_sampler_accept(slot.smpl.get(), id, true);
+            } catch (const std::exception & e) {
+                send_error(slot, std::string("Grammar error: ") + e.what());
+                slot.release();
+                continue;
+            }
 
             // here we have synchronized the llama_context (due to the sampling above), so we can do time measurement
             const int64_t t_current = ggml_time_us();
@@ -2791,7 +2797,19 @@ private:
             const size_t n_draft = slot.drafted.size();
 
             // the accepted tokens from the speculation
-            const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
+            std::vector<llama_token> ids;
+            try {
+                ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
+            } catch (const std::exception & e) {
+                send_error(slot, std::string("Grammar error: ") + e.what());
+
+                slot.i_batch_dft.clear();
+                slot.drafted.clear();
+
+                // n_draft is still valid: rollback speculative tokens and clean up KV cache
+                slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
+                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
+
+                slot.release();
+                continue;
+            }
 
             slot.i_batch_dft.clear();
             slot.drafted.clear();
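When the speculative path hits a grammar error, the draft tokens that were already appended have to be rolled back before the slot is released. The two calls above do that in order: keep_first() shrinks the cached prompt first, so the following n_tokens() already excludes the draft, and llama_memory_seq_rm() then drops the matching KV-cache entries from that position to the end of the sequence (-1). A sketch of the same rollback with the intermediate length made explicit (names as in the diff; integer narrowing to llama_pos ignored for brevity):

// length of the confirmed prompt, i.e. without the n_draft speculative tokens
const size_t n_keep = slot.prompt.n_tokens() - n_draft;

slot.prompt.tokens.keep_first(n_keep);                           // drop the draft tokens
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, n_keep, -1); // and their KV-cache entries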