Merge 3637bf51cd into 05fa625eac
This commit is contained in:
commit ceaf729db0
@@ -968,7 +968,7 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
 }
 
 static void llama_grammar_accept_chr(
-        struct llama_grammar & grammar,
+        const llama_grammar_rules & rules,
         const llama_grammar_stack & stack,
         uint32_t chr,
         llama_grammar_stacks & new_stacks) {
@@ -989,7 +989,7 @@ static void llama_grammar_accept_chr(
         if (!llama_grammar_is_end_of_sequence(match.second)) {
             new_stack.push_back(match.second);
         }
-        llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks);
+        llama_grammar_advance_stack(rules, new_stack, new_stacks);
     }
 }
 
@@ -998,7 +998,7 @@ void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
     stacks_new.reserve(grammar->stacks.size());
 
     for (const auto & stack : grammar->stacks) {
-        llama_grammar_accept_chr(*grammar, stack, chr, stacks_new);
+        llama_grammar_accept_chr(grammar->rules, stack, chr, stacks_new);
     }
 
     grammar->stacks = std::move(stacks_new);
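The three hunks above are one refactor: `llama_grammar_accept_chr` now reads only the immutable `rules` instead of the whole mutable `llama_grammar`, so the same helper can be driven against copies of the stacks without touching grammar state. A minimal standalone sketch of why that matters (the `Stack`/`Stacks` types and `accept_sym` are illustrative stand-ins, not llama.cpp API):

```cpp
#include <vector>

using Stack  = std::vector<int>;
using Stacks = std::vector<Stack>;

// stand-in for llama_grammar_accept_chr(rules, stack, chr, new_stacks):
// reads one stack, appends successor stacks to `out`, mutates nothing else
static void accept_sym(const Stack & stack, int sym, Stacks & out) {
    if (!stack.empty() && stack.back() == sym) {
        out.emplace_back(stack.begin(), stack.end() - 1); // pop the matched symbol
    }
}

int main() {
    Stacks cur = {{2, 1}}; // one stack, expecting symbol 1 then 2
    Stacks next;
    for (const auto & s : cur) {
        accept_sym(s, 1, next); // `cur` is a read-only input, copy-safe
    }
    // the caller decides whether to commit: cur = std::move(next);
    return next.empty() ? 1 : 0;
}
```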
@@ -1333,6 +1333,77 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token
     }
 }
 
+// compute new stacks after accepting a token, without mutating the grammar
+// returns true if the token is valid (non-empty stacks)
+static bool llama_grammar_try_accept_token(
+        const llama_grammar_rules & rules,
+        llama_token token,
+        const std::string & piece,
+        const llama_grammar_stacks & stacks_cur,
+        const llama_partial_utf8 & partial_utf8_cur,
+        llama_grammar_stacks & new_stacks,
+        llama_partial_utf8 & new_partial_utf8) {
+    const auto decoded = decode_utf8(piece, partial_utf8_cur);
+    const auto & code_points = decoded.first;
+
+    new_stacks.clear();
+    new_stacks.reserve(stacks_cur.size());
+
+    for (const auto & stack : stacks_cur) {
+        if (stack.empty()) {
+            continue;
+        }
+
+        const llama_grammar_element * pos = stack.back();
+
+        if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
+            if (llama_grammar_match_token(pos, token)) {
+                llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
+                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
+                    new_stack.push_back(pos + 1);
+                }
+                llama_grammar_advance_stack(rules, new_stack, new_stacks);
+            }
+        } else {
+            llama_grammar_stacks current_stacks = {stack};
+
+            for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+                llama_grammar_stacks next_stacks;
+
+                for (const auto & cur_stack : current_stacks) {
+                    llama_grammar_accept_chr(rules, cur_stack, *it, next_stacks);
+                }
+
+                current_stacks = std::move(next_stacks);
+                if (current_stacks.empty()) {
+                    break;
+                }
+            }
+
+            for (auto & surviving_stack : current_stacks) {
+                if (std::find(new_stacks.begin(), new_stacks.end(), surviving_stack) == new_stacks.end()) {
+                    new_stacks.emplace_back(surviving_stack);
+                }
+            }
+        }
+    }
+
+    new_partial_utf8 = decoded.second;
+    return !new_stacks.empty();
+}
+
+void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) {
+    llama_grammar_stacks stacks_new;
+    llama_partial_utf8   partial_utf8_new;
+
+    if (!llama_grammar_try_accept_token(grammar.rules, token, piece, grammar.stacks, grammar.partial_utf8, stacks_new, partial_utf8_new)) {
+        throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")");
+    }
+
+    grammar.stacks       = std::move(stacks_new);
+    grammar.partial_utf8 = partial_utf8_new;
+}
+
 void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
     GGML_ASSERT(grammar.vocab != nullptr);
 
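The hunk above separates "would this token be valid" from "mutate the grammar": `llama_grammar_try_accept_token` writes into caller-provided scratch state and only reports validity. A sketch of the resulting try-then-commit call shape — a fragment, not a standalone program, reusing only names introduced in this hunk and assuming llama.cpp's internal grammar types:

```cpp
// trial-accept into scratch buffers; the grammar itself is untouched on failure
llama_grammar_stacks stacks_new;
llama_partial_utf8   partial_utf8_new;

if (llama_grammar_try_accept_token(grammar.rules, token, piece,
                                   grammar.stacks, grammar.partial_utf8,
                                   stacks_new, partial_utf8_new)) {
    grammar.stacks       = std::move(stacks_new); // commit only on success
    grammar.partial_utf8 = partial_utf8_new;
} else {
    // the caller chooses the failure policy: throw (llama_grammar_accept_token)
    // or fall back to another trigger pattern (the replay loop below)
}
```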
@@ -1343,7 +1414,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
             grammar.awaiting_trigger = false;
             grammar.trigger_buffer.clear();
             llama_grammar_accept_token(grammar, token, piece);
-            LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
+            LLAMA_LOG_DEBUG("%s: Grammar triggered on token %u (`%s`)\n", __func__, token, piece.c_str());
             return;
         } else {
             auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size());
@@ -1353,7 +1424,12 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
             for (const auto & trigger_pattern : grammar.trigger_patterns) {
                 auto start = trigger_pattern.find(grammar.trigger_buffer);
                 if (start != std::string::npos) {
-                    grammar.awaiting_trigger = false;
+                    // trial-accept replay tokens on copies of the grammar state;
+                    // the completing token may contain text beyond the trigger
+                    // that violates the grammar
+                    auto trial_stacks       = grammar.stacks;
+                    auto trial_partial_utf8 = grammar.partial_utf8;
+                    bool replay_ok = true;
 
                     // replay tokens that overlap with [start, end)
                     for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
@@ -1365,13 +1441,31 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
                         size_t piece_start = (tok_start < start) ? start : tok_start; // allow for partial token pieces
                         size_t piece_len = tok_end - piece_start;
                         auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len);
-                        llama_grammar_accept_token(grammar, tok, tok_piece);
+
+                        llama_grammar_stacks next_stacks;
+                        llama_partial_utf8   next_partial;
+                        if (!llama_grammar_try_accept_token(grammar.rules, tok, tok_piece, trial_stacks, trial_partial_utf8, next_stacks, next_partial)) {
+                            LLAMA_LOG_WARN("%s: trigger replay failed on token %d (`%s`), treating as false trigger\n",
+                                           __func__, tok, tok_piece.c_str());
+                            replay_ok = false;
+                            break;
+                        }
+                        trial_stacks       = std::move(next_stacks);
+                        trial_partial_utf8 = next_partial;
                     }
 
-                    auto constrained_str = grammar.trigger_buffer.substr(start);
+                    if (!replay_ok) {
+                        continue; // try next pattern
+                    }
+
+                    // replay succeeded, commit the new state
+                    grammar.stacks           = std::move(trial_stacks);
+                    grammar.partial_utf8     = trial_partial_utf8;
+                    grammar.awaiting_trigger = false;
+
+                    LLAMA_LOG_DEBUG("%s: Grammar triggered on regex: '%s'\n", __func__, grammar.trigger_buffer.substr(start).c_str());
                     grammar.trigger_buffer.clear();
                     grammar.trigger_buffer_positions.clear();
-                    LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
                     return;
                 }
             }
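The replay loop clamps each buffered token's piece to the triggered region, because the first replayed token may begin before the trigger match starts. A standalone worked example of that slice, using the same arithmetic as the hunk above (buffer contents and offsets are hypothetical):

```cpp
#include <cassert>
#include <string>

int main() {
    // hypothetical state: the trigger pattern matched at offset 3 of the buffer
    const std::string trigger_buffer = "foo<tool";
    const size_t start     = 3; // match offset within the buffer
    const size_t tok_start = 2; // one buffered token covered [2, 6) -> "o<to"
    const size_t tok_end   = 6;

    // same arithmetic as the diff: allow for partial token pieces
    const size_t piece_start = (tok_start < start) ? start : tok_start;
    const size_t piece_len   = tok_end - piece_start;
    const std::string tok_piece = trigger_buffer.substr(piece_start, piece_len);

    assert(tok_piece == "<to"); // only the in-trigger suffix gets replayed
    return 0;
}
```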
@@ -1406,59 +1500,3 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
         throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
     }
 }
-
-void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) {
-    // Note terminating 0 in decoded string
-    const auto decoded = decode_utf8(piece, grammar.partial_utf8);
-    const auto & code_points = decoded.first;
-
-    llama_grammar_stacks stacks_new;
-    stacks_new.reserve(grammar.stacks.size());
-
-    for (const auto & stack : grammar.stacks) {
-        if (stack.empty()) {
-            continue;
-        }
-
-        const llama_grammar_element * pos = stack.back();
-
-        if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
-            if (llama_grammar_match_token(pos, token)) {
-                llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
-                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
-                    new_stack.push_back(pos + 1);
-                }
-                llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new);
-            }
-        } else {
-            llama_grammar_stacks current_stacks = {stack};
-
-            for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-                llama_grammar_stacks next_stacks;
-
-                for (const auto & cur_stack : current_stacks) {
-                    llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks);
-                }
-
-                current_stacks = std::move(next_stacks);
-                if (current_stacks.empty()) {
-                    break;
-                }
-            }
-
-            for (auto & surviving_stack : current_stacks) {
-                if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) {
-                    stacks_new.emplace_back(surviving_stack);
-                }
-            }
-        }
-    }
-
-    grammar.stacks = std::move(stacks_new);
-    grammar.partial_utf8 = decoded.second;
-
-    if (grammar.stacks.empty()) {
-        throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")");
-    }
-}
-
@@ -2747,7 +2747,13 @@ private:
 
                 slot.i_batch = -1;
 
-                common_sampler_accept(slot.smpl.get(), id, true);
+                try {
+                    common_sampler_accept(slot.smpl.get(), id, true);
+                } catch (const std::exception & e) {
+                    send_error(slot, std::string("Grammar error: ") + e.what());
+                    slot.release();
+                    continue;
+                }
 
                 // here we have synchronized the llama_context (due to the sampling above), so we can do time measurement
                 const int64_t t_current = ggml_time_us();
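Since `llama_grammar_accept_token` throws on an empty stack, the server now treats a grammar error as a per-slot failure: the slot reports the error and is released while the main loop keeps serving the other slots. A standalone sketch of that isolation pattern (the `Slot` type and `sample_for` are illustrative stand-ins, not the server's real types):

```cpp
#include <cstdio>
#include <stdexcept>
#include <vector>

struct Slot { int id; bool active = true; }; // illustrative stand-in

static void sample_for(Slot & s) {           // stands in for common_sampler_accept
    if (s.id == 1) {
        throw std::runtime_error("Unexpected empty grammar stack after accepting piece: x (42)");
    }
}

int main() {
    std::vector<Slot> slots = {{0}, {1}, {2}};
    for (auto & slot : slots) {
        try {
            sample_for(slot);
        } catch (const std::exception & e) {
            std::fprintf(stderr, "Grammar error: %s (slot %d)\n", e.what(), slot.id);
            slot.active = false; // analogous to send_error(...) + slot.release()
            continue;            // the other slots keep running
        }
        std::printf("slot %d sampled ok\n", slot.id);
    }
    return 0;
}
```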
@@ -2791,7 +2797,19 @@ private:
                     const size_t n_draft = slot.drafted.size();
 
                     // the accepted tokens from the speculation
-                    const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
+                    std::vector<llama_token> ids;
+                    try {
+                        ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
+                    } catch (const std::exception & e) {
+                        send_error(slot, std::string("Grammar error: ") + e.what());
+                        slot.i_batch_dft.clear();
+                        slot.drafted.clear();
+                        // n_draft is still valid: rollback speculative tokens and clean up KV cache
+                        slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
+                        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
+                        slot.release();
+                        continue;
+                    }
                    slot.i_batch_dft.clear();
                    slot.drafted.clear();
 
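On the speculative path the draft tokens were already appended to the slot's token list and KV cache before sampling, so the catch block also rolls those back. A standalone check of the arithmetic, assuming `keep_first` truncates the list so that `n_tokens()` afterwards reports the kept count (sizes are hypothetical):

```cpp
#include <cassert>
#include <cstddef>

int main() {
    size_t n_tokens = 100;    // slot.prompt.n_tokens() with drafts appended
    const size_t n_draft = 8; // slot.drafted.size(), captured before sampling

    n_tokens -= n_draft;      // slot.prompt.tokens.keep_first(n_tokens - n_draft)
    assert(n_tokens == 92);

    // llama_memory_seq_rm(mem, slot.id, n_tokens, -1) then evicts the cached
    // positions [92, end) for this sequence, matching the trimmed token list
    return 0;
}
```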