diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
index 2d55070cec..9fea914a22 100644
--- a/src/llama-grammar.cpp
+++ b/src/llama-grammar.cpp
@@ -968,7 +968,7 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
 }
 
 static void llama_grammar_accept_chr(
-        struct llama_grammar & grammar,
+        const llama_grammar_rules & rules,
         const llama_grammar_stack & stack,
         uint32_t chr,
         llama_grammar_stacks & new_stacks) {
@@ -989,7 +989,7 @@ static void llama_grammar_accept_chr(
         if (!llama_grammar_is_end_of_sequence(match.second)) {
             new_stack.push_back(match.second);
         }
-        llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks);
+        llama_grammar_advance_stack(rules, new_stack, new_stacks);
     }
 }
 
@@ -998,7 +998,7 @@ void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
     stacks_new.reserve(grammar->stacks.size());
 
     for (const auto & stack : grammar->stacks) {
-        llama_grammar_accept_chr(*grammar, stack, chr, stacks_new);
+        llama_grammar_accept_chr(grammar->rules, stack, chr, stacks_new);
     }
 
     grammar->stacks = std::move(stacks_new);
@@ -1333,6 +1333,77 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
     }
 }
 
+// compute new stacks after accepting a token, without mutating the grammar
+// returns true if the token is valid (non-empty stacks)
+static bool llama_grammar_try_accept_token(
+        const llama_grammar_rules & rules,
+        llama_token token,
+        const std::string & piece,
+        const llama_grammar_stacks & stacks_cur,
+        const llama_partial_utf8 & partial_utf8_cur,
+        llama_grammar_stacks & new_stacks,
+        llama_partial_utf8 & new_partial_utf8) {
+    const auto decoded = decode_utf8(piece, partial_utf8_cur);
+    const auto & code_points = decoded.first;
+
+    new_stacks.clear();
+    new_stacks.reserve(stacks_cur.size());
+
+    for (const auto & stack : stacks_cur) {
+        if (stack.empty()) {
+            continue;
+        }
+
+        const llama_grammar_element * pos = stack.back();
+
+        if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
+            if (llama_grammar_match_token(pos, token)) {
+                llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
+                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
+                    new_stack.push_back(pos + 1);
+                }
+                llama_grammar_advance_stack(rules, new_stack, new_stacks);
+            }
+        } else {
+            llama_grammar_stacks current_stacks = {stack};
+
+            for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+                llama_grammar_stacks next_stacks;
+
+                for (const auto & cur_stack : current_stacks) {
+                    llama_grammar_accept_chr(rules, cur_stack, *it, next_stacks);
+                }
+
+                current_stacks = std::move(next_stacks);
+                if (current_stacks.empty()) {
+                    break;
+                }
+            }
+
+            for (auto & surviving_stack : current_stacks) {
+                if (std::find(new_stacks.begin(), new_stacks.end(), surviving_stack) == new_stacks.end()) {
+                    new_stacks.emplace_back(surviving_stack);
+                }
+            }
+        }
+    }
+
+    new_partial_utf8 = decoded.second;
+    return !new_stacks.empty();
+}
+
+void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) {
+    llama_grammar_stacks stacks_new;
+    llama_partial_utf8 partial_utf8_new;
+
+    if (!llama_grammar_try_accept_token(grammar.rules, token, piece, grammar.stacks, grammar.partial_utf8, stacks_new, partial_utf8_new)) {
+        throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")");
+    }
+
+    grammar.stacks = std::move(stacks_new);
+    grammar.partial_utf8 = partial_utf8_new;
+}
+
 void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
     GGML_ASSERT(grammar.vocab != nullptr);
 
@@ -1343,7 +1414,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
             grammar.awaiting_trigger = false;
             grammar.trigger_buffer.clear();
             llama_grammar_accept_token(grammar, token, piece);
-            LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
+            LLAMA_LOG_DEBUG("%s: Grammar triggered on token %u (`%s`)\n", __func__, token, piece.c_str());
             return;
         } else {
             auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size());
@@ -1353,7 +1424,12 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
             for (const auto & trigger_pattern : grammar.trigger_patterns) {
                 auto start = trigger_pattern.find(grammar.trigger_buffer);
                 if (start != std::string::npos) {
-                    grammar.awaiting_trigger = false;
+                    // trial-accept replay tokens on copies of the grammar state;
+                    // the completing token may contain text beyond the trigger
+                    // that violates the grammar
+                    auto trial_stacks = grammar.stacks;
+                    auto trial_partial_utf8 = grammar.partial_utf8;
+                    bool replay_ok = true;
 
                     // replay tokens that overlap with [start, end)
                     for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
@@ -1365,13 +1441,31 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
                         size_t piece_start = (tok_start < start) ? start : tok_start; // allow for partial token pieces
                         size_t piece_len = tok_end - piece_start;
                         auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len);
-                        llama_grammar_accept_token(grammar, tok, tok_piece);
+
+                        llama_grammar_stacks next_stacks;
+                        llama_partial_utf8 next_partial;
+                        if (!llama_grammar_try_accept_token(grammar.rules, tok, tok_piece, trial_stacks, trial_partial_utf8, next_stacks, next_partial)) {
+                            LLAMA_LOG_WARN("%s: trigger replay failed on token %d (`%s`), treating as false trigger\n",
+                                    __func__, tok, tok_piece.c_str());
+                            replay_ok = false;
+                            break;
+                        }
+                        trial_stacks = std::move(next_stacks);
+                        trial_partial_utf8 = next_partial;
                     }
 
-                    auto constrained_str = grammar.trigger_buffer.substr(start);
+                    if (!replay_ok) {
+                        continue; // try next pattern
+                    }
+
+                    // replay succeeded, commit the new state
+                    grammar.stacks = std::move(trial_stacks);
+                    grammar.partial_utf8 = trial_partial_utf8;
+                    grammar.awaiting_trigger = false;
+
+                    LLAMA_LOG_DEBUG("%s: Grammar triggered on regex: '%s'\n", __func__, grammar.trigger_buffer.substr(start).c_str());
                     grammar.trigger_buffer.clear();
                     grammar.trigger_buffer_positions.clear();
-                    LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
 
                     return;
                 }
@@ -1406,59 +1500,3 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
         throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
     }
 }
-
-void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) {
-    // Note terminating 0 in decoded string
-    const auto decoded = decode_utf8(piece, grammar.partial_utf8);
-    const auto & code_points = decoded.first;
-
-    llama_grammar_stacks stacks_new;
-    stacks_new.reserve(grammar.stacks.size());
-
-    for (const auto & stack : grammar.stacks) {
-        if (stack.empty()) {
-            continue;
-        }
-
-        const llama_grammar_element * pos = stack.back();
-
-        if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
-            if (llama_grammar_match_token(pos, token)) {
-                llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
-                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
-                    new_stack.push_back(pos + 1);
-                }
-                llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new);
-            }
-        } else {
-            llama_grammar_stacks current_stacks = {stack};
-
-            for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-                llama_grammar_stacks next_stacks;
-
-                for (const auto & cur_stack : current_stacks) {
-                    llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks);
-                }
-
-                current_stacks = std::move(next_stacks);
-                if (current_stacks.empty()) {
-                    break;
-                }
-            }
-
-            for (auto & surviving_stack : current_stacks) {
-                if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) {
-                    stacks_new.emplace_back(surviving_stack);
-                }
-            }
-        }
-    }
-
-    grammar.stacks = std::move(stacks_new);
-    grammar.partial_utf8 = decoded.second;
-
-    if (grammar.stacks.empty()) {
-        throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")");
-    }
-}
-
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index ceafcac179..8cac58110f 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -2747,7 +2747,13 @@ private:
 
             slot.i_batch = -1;
 
-            common_sampler_accept(slot.smpl.get(), id, true);
+            try {
+                common_sampler_accept(slot.smpl.get(), id, true);
+            } catch (const std::exception & e) {
+                send_error(slot, std::string("Grammar error: ") + e.what());
+                slot.release();
+                continue;
+            }
 
             // here we have synchronized the llama_context (due to the sampling above), so we can do time measurement
             const int64_t t_current = ggml_time_us();
@@ -2791,7 +2797,19 @@ private:
             const size_t n_draft = slot.drafted.size();
 
             // the accepted tokens from the speculation
-            const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
+            std::vector<llama_token> ids;
+            try {
+                ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
+            } catch (const std::exception & e) {
+                send_error(slot, std::string("Grammar error: ") + e.what());
+                slot.i_batch_dft.clear();
+                slot.drafted.clear();
+                // n_draft is still valid: rollback speculative tokens and clean up KV cache
+                slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
+                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
+                slot.release();
+                continue;
+            }
 
             slot.i_batch_dft.clear();
             slot.drafted.clear();
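
A standalone sketch of the trial-accept/commit pattern the patch introduces, using hypothetical stand-in types and names (State, try_accept, replay) rather than the llama.cpp internals: the successor state is computed into scratch variables and committed only when every replayed step succeeds, so a failed replay leaves the caller's state untouched and the next trigger pattern can still be tried.

// Hypothetical illustration only -- stand-in types, not llama.cpp code.
#include <utility>
#include <vector>

struct State {
    std::vector<int> stacks; // stand-in for llama_grammar_stacks
};

// stand-in for llama_grammar_try_accept_token: computes the successor state
// into `next` and reports acceptance, without mutating `cur`
bool try_accept(const State & cur, int token, State & next) {
    next = cur;
    if (token < 0) {
        return false; // rejected: the caller's state is left untouched
    }
    next.stacks.push_back(token);
    return true;
}

// replay a token sequence on a copy; commit to `committed` only on full success
bool replay(State & committed, const std::vector<int> & tokens) {
    State trial = committed;
    for (int tok : tokens) {
        State next;
        if (!try_accept(trial, tok, next)) {
            return false; // false trigger: `committed` is unchanged
        }
        trial = std::move(next);
    }
    committed = std::move(trial); // replay succeeded, commit the new state
    return true;
}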