llama : add token matching support to llama-grammar (#17816)
* llama : add token support to llama-grammar * fix inverse token comment * refactor trigger_patterns to replay tokens instead of the entire string * add token documentation * fix test-llama-grammar * improve test cases for tokens
This commit is contained in:
parent
1d2a1ab73d
commit
e39502e74b
|
|
@ -67,6 +67,30 @@ Parentheses `()` can be used to group sequences, which allows for embedding alte
|
||||||
- `{m,n}` repeats the precedent symbol or sequence at between `m` and `n` times (included)
|
- `{m,n}` repeats the precedent symbol or sequence at between `m` and `n` times (included)
|
||||||
- `{0,n}` repeats the precedent symbol or sequence at most `n` times (included)
|
- `{0,n}` repeats the precedent symbol or sequence at most `n` times (included)
|
||||||
|
|
||||||
|
## Tokens
|
||||||
|
|
||||||
|
Tokens allow grammars to match specific tokenizer tokens rather than character sequences. This is useful for constraining outputs based on special tokens (like `<think>` or `</think>`).
|
||||||
|
|
||||||
|
Tokens can be specified in two ways:
|
||||||
|
|
||||||
|
1. **Token ID**: Use angle brackets with the token ID in square brackets: `<[token-id]>`. For example, `<[1000]>` matches the token with ID 1000.
|
||||||
|
|
||||||
|
2. **Token string**: Use angle brackets with the token text directly: `<token>`. For example, `<think>` will match the token whose text is exactly `<think>`. This only works if the string tokenizes to exactly one token in the vocabulary, otherwise the grammar will fail to parse.
|
||||||
|
|
||||||
|
You can negate token matches using the `!` prefix: `!<[1000]>` or `!<think>` matches any token *except* the specified one.
|
||||||
|
|
||||||
|
```
|
||||||
|
# Match a thinking block: <think>...</think>
|
||||||
|
# Using token strings (requires these to be single tokens in the vocab)
|
||||||
|
root ::= <think> thinking </think> .*
|
||||||
|
thinking ::= !</think>*
|
||||||
|
|
||||||
|
# Equivalent grammar using explicit token IDs
|
||||||
|
# Assumes token 1000 = <think>, token 1001 = </think>
|
||||||
|
root ::= <[1000]> thinking <[1001]> .*
|
||||||
|
thinking ::= !<[1001]>*
|
||||||
|
```
|
||||||
|
|
||||||
## Comments and newlines
|
## Comments and newlines
|
||||||
|
|
||||||
Comments can be specified with `#`:
|
Comments can be specified with `#`:
|
||||||
|
|
|
||||||
|
|
@ -181,6 +181,52 @@ static std::pair<uint32_t, const char *> parse_char(const char * src) {
|
||||||
throw std::runtime_error("unexpected end of input");
|
throw std::runtime_error("unexpected end of input");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static std::pair<uint32_t, const char *> parse_token(const llama_vocab * vocab, const char * src) {
|
||||||
|
const char * pos = src;
|
||||||
|
if (*pos != '<') {
|
||||||
|
throw std::runtime_error(std::string("expecting '<' at ") + pos);
|
||||||
|
}
|
||||||
|
pos++;
|
||||||
|
|
||||||
|
// Parse <[id]>
|
||||||
|
if (*pos == '[') {
|
||||||
|
pos++;
|
||||||
|
const char * int_end = parse_int(pos);
|
||||||
|
uint32_t token_id = std::stoul(std::string(pos, int_end - pos));
|
||||||
|
pos = int_end;
|
||||||
|
if (*pos != ']') {
|
||||||
|
throw std::runtime_error(std::string("expecting ']' at ") + pos);
|
||||||
|
}
|
||||||
|
pos++;
|
||||||
|
if (*pos != '>') {
|
||||||
|
throw std::runtime_error(std::string("expecting '>' at ") + pos);
|
||||||
|
}
|
||||||
|
pos++;
|
||||||
|
return std::make_pair(token_id, pos);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (vocab == nullptr) {
|
||||||
|
throw std::runtime_error(std::string("no vocab to parse token at ") + src);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse <token> and tokenize to obtain the token id
|
||||||
|
while (*pos != 0 && *pos != '>') {
|
||||||
|
pos++;
|
||||||
|
}
|
||||||
|
if (*pos != '>') {
|
||||||
|
throw std::runtime_error(std::string("expecting '>' at ") + pos);
|
||||||
|
}
|
||||||
|
pos++;
|
||||||
|
|
||||||
|
llama_token tokens[2];
|
||||||
|
int32_t n_tokens = vocab->tokenize(src, static_cast<int32_t>(pos - src), tokens, 2, false, true);
|
||||||
|
if (n_tokens != 1) {
|
||||||
|
// must tokenize to exactly 1 token
|
||||||
|
throw std::runtime_error("invalid token '" + std::string(src, pos - src) + "'");
|
||||||
|
}
|
||||||
|
return std::make_pair(tokens[0], pos);
|
||||||
|
}
|
||||||
|
|
||||||
static void print_grammar_char(FILE * file, uint32_t c) {
|
static void print_grammar_char(FILE * file, uint32_t c) {
|
||||||
if (0x20 <= c && c <= 0x7f) {
|
if (0x20 <= c && c <= 0x7f) {
|
||||||
fprintf(file, "%c", static_cast<char>(c));
|
fprintf(file, "%c", static_cast<char>(c));
|
||||||
|
|
@ -212,6 +258,8 @@ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
|
||||||
case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
|
case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
|
||||||
case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
|
case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
|
||||||
case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
|
case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
|
||||||
|
case LLAMA_GRETYPE_TOKEN: fprintf(file, "TOKEN"); break;
|
||||||
|
case LLAMA_GRETYPE_TOKEN_NOT: fprintf(file, "TOKEN_NOT"); break;
|
||||||
}
|
}
|
||||||
switch (elem.type) {
|
switch (elem.type) {
|
||||||
case LLAMA_GRETYPE_END:
|
case LLAMA_GRETYPE_END:
|
||||||
|
|
@ -228,6 +276,17 @@ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
|
||||||
print_grammar_char(file, elem.value);
|
print_grammar_char(file, elem.value);
|
||||||
fprintf(file, "\") ");
|
fprintf(file, "\") ");
|
||||||
break;
|
break;
|
||||||
|
case LLAMA_GRETYPE_TOKEN:
|
||||||
|
fprintf(file, "<[");
|
||||||
|
fprintf(file, "%u", elem.value);
|
||||||
|
fprintf(file, "]> ");
|
||||||
|
break;
|
||||||
|
case LLAMA_GRETYPE_TOKEN_NOT:
|
||||||
|
fprintf(file, "!");
|
||||||
|
fprintf(file, "<[");
|
||||||
|
fprintf(file, "%u", elem.value);
|
||||||
|
fprintf(file, "]> ");
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fprintf(file, "\n");
|
fprintf(file, "\n");
|
||||||
|
|
@ -284,6 +343,17 @@ static void print_rule(
|
||||||
case LLAMA_GRETYPE_CHAR_ANY:
|
case LLAMA_GRETYPE_CHAR_ANY:
|
||||||
fprintf(file, ".");
|
fprintf(file, ".");
|
||||||
break;
|
break;
|
||||||
|
case LLAMA_GRETYPE_TOKEN:
|
||||||
|
fprintf(file, "<[");
|
||||||
|
fprintf(file, "%u", elem.value);
|
||||||
|
fprintf(file, "]> ");
|
||||||
|
break;
|
||||||
|
case LLAMA_GRETYPE_TOKEN_NOT:
|
||||||
|
fprintf(file, "!");
|
||||||
|
fprintf(file, "<[");
|
||||||
|
fprintf(file, "%u", elem.value);
|
||||||
|
fprintf(file, "]> ");
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
if (is_char_element(elem)) {
|
if (is_char_element(elem)) {
|
||||||
switch (rule[i + 1].type) {
|
switch (rule[i + 1].type) {
|
||||||
|
|
@ -444,6 +514,17 @@ const char * llama_grammar_parser::parse_sequence(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pos = parse_space(pos + 1, is_nested);
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
} else if (*pos == '<' || *pos == '!') { // token
|
||||||
|
auto type = LLAMA_GRETYPE_TOKEN;
|
||||||
|
if (*pos == '!') { // token inverse
|
||||||
|
type = LLAMA_GRETYPE_TOKEN_NOT;
|
||||||
|
pos++;
|
||||||
|
}
|
||||||
|
auto token_pair = parse_token(vocab, pos);
|
||||||
|
const char * token_end = token_pair.second;
|
||||||
|
last_sym_start = rule.size();
|
||||||
|
rule.push_back({type, token_pair.first});
|
||||||
|
pos = parse_space(token_end, is_nested);
|
||||||
} else if (is_word_char(*pos)) { // rule reference
|
} else if (is_word_char(*pos)) { // rule reference
|
||||||
const char * name_end = parse_name(pos);
|
const char * name_end = parse_name(pos);
|
||||||
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
|
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
|
||||||
|
|
@ -691,6 +772,21 @@ static bool llama_grammar_match_partial_char(
|
||||||
return !is_positive_char;
|
return !is_positive_char;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// returns true iff token matches the rule at pos (regular or inverse)
|
||||||
|
// asserts that pos is pointing to a token element
|
||||||
|
static bool llama_grammar_match_token(
|
||||||
|
const llama_grammar_element * pos,
|
||||||
|
const llama_token token) {
|
||||||
|
GGML_ASSERT(pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT);
|
||||||
|
if (pos->type == LLAMA_GRETYPE_TOKEN) {
|
||||||
|
return pos->value == static_cast<uint32_t>(token);
|
||||||
|
}
|
||||||
|
if (pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
|
||||||
|
return pos->value != static_cast<uint32_t>(token);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// transforms a grammar pushdown stack into N possible stacks, all ending
|
// transforms a grammar pushdown stack into N possible stacks, all ending
|
||||||
// at a character range (terminal element)
|
// at a character range (terminal element)
|
||||||
static void llama_grammar_advance_stack(
|
static void llama_grammar_advance_stack(
|
||||||
|
|
@ -738,6 +834,8 @@ static void llama_grammar_advance_stack(
|
||||||
case LLAMA_GRETYPE_CHAR:
|
case LLAMA_GRETYPE_CHAR:
|
||||||
case LLAMA_GRETYPE_CHAR_NOT:
|
case LLAMA_GRETYPE_CHAR_NOT:
|
||||||
case LLAMA_GRETYPE_CHAR_ANY:
|
case LLAMA_GRETYPE_CHAR_ANY:
|
||||||
|
case LLAMA_GRETYPE_TOKEN:
|
||||||
|
case LLAMA_GRETYPE_TOKEN_NOT:
|
||||||
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
|
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
|
||||||
// only add the stack if it's not a duplicate of one we already have
|
// only add the stack if it's not a duplicate of one we already have
|
||||||
new_stacks.emplace_back(stack);
|
new_stacks.emplace_back(stack);
|
||||||
|
|
@ -831,26 +929,38 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
|
||||||
return grammar->stacks;
|
return grammar->stacks;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void llama_grammar_accept_chr(
|
||||||
|
struct llama_grammar & grammar,
|
||||||
|
const llama_grammar_stack & stack,
|
||||||
|
uint32_t chr,
|
||||||
|
llama_grammar_stacks & new_stacks) {
|
||||||
|
if (stack.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const llama_grammar_element * pos = stack.back();
|
||||||
|
|
||||||
|
// ignore if this turns into a token
|
||||||
|
if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto match = llama_grammar_match_char(pos, chr);
|
||||||
|
if (match.first) {
|
||||||
|
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
|
||||||
|
if (!llama_grammar_is_end_of_sequence(match.second)) {
|
||||||
|
new_stack.push_back(match.second);
|
||||||
|
}
|
||||||
|
llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
|
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
|
||||||
llama_grammar_stacks stacks_new;
|
llama_grammar_stacks stacks_new;
|
||||||
stacks_new.reserve(grammar->stacks.size());
|
stacks_new.reserve(grammar->stacks.size());
|
||||||
|
|
||||||
for (const auto & stack : grammar->stacks) {
|
for (const auto & stack : grammar->stacks) {
|
||||||
if (stack.empty()) {
|
llama_grammar_accept_chr(*grammar, stack, chr, stacks_new);
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto match = llama_grammar_match_char(stack.back(), chr);
|
|
||||||
if (match.first) {
|
|
||||||
const llama_grammar_element * pos = match.second;
|
|
||||||
|
|
||||||
// update top of stack to next element, if any
|
|
||||||
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
|
|
||||||
if (!llama_grammar_is_end_of_sequence(pos)) {
|
|
||||||
new_stack.push_back(pos);
|
|
||||||
}
|
|
||||||
llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
grammar->stacks = std::move(stacks_new);
|
grammar->stacks = std::move(stacks_new);
|
||||||
|
|
@ -875,6 +985,22 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
||||||
|
|
||||||
const llama_grammar_element * stack_pos = stack.back();
|
const llama_grammar_element * stack_pos = stack.back();
|
||||||
|
|
||||||
|
// if the top of the stack is a token rule, then we only need to check the token id
|
||||||
|
if (stack_pos->type == LLAMA_GRETYPE_TOKEN || stack_pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
|
||||||
|
for (const auto & tok : candidates) {
|
||||||
|
if (*tok.code_points == 0) {
|
||||||
|
// reached the end of a token consumed by char rules, reject iff it ended
|
||||||
|
// in a partial response
|
||||||
|
if (tok.partial_utf8.n_remain != 0) {
|
||||||
|
rejects.push_back(tok);
|
||||||
|
}
|
||||||
|
} else if (!llama_grammar_match_token(stack_pos, tok.id)) {
|
||||||
|
rejects.push_back(tok);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rejects;
|
||||||
|
}
|
||||||
|
|
||||||
llama_grammar_candidates next_candidates;
|
llama_grammar_candidates next_candidates;
|
||||||
next_candidates.reserve(candidates.size());
|
next_candidates.reserve(candidates.size());
|
||||||
|
|
||||||
|
|
@ -887,7 +1013,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
||||||
rejects.push_back(tok);
|
rejects.push_back(tok);
|
||||||
}
|
}
|
||||||
} else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
|
} else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
|
||||||
next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
|
next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8, tok.id });
|
||||||
} else {
|
} else {
|
||||||
rejects.push_back(tok);
|
rejects.push_back(tok);
|
||||||
}
|
}
|
||||||
|
|
@ -905,7 +1031,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
||||||
|
|
||||||
auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
|
auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
|
||||||
for (const auto & tok : next_rejects) {
|
for (const auto & tok : next_rejects) {
|
||||||
rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
|
rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8, tok.id });
|
||||||
}
|
}
|
||||||
|
|
||||||
return rejects;
|
return rejects;
|
||||||
|
|
@ -976,6 +1102,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||||
/* .lazy = */ false,
|
/* .lazy = */ false,
|
||||||
/* .awaiting_trigger = */ false,
|
/* .awaiting_trigger = */ false,
|
||||||
/* .trigger_buffer = */ "",
|
/* .trigger_buffer = */ "",
|
||||||
|
/* .trigger_buffer_positions = */ {},
|
||||||
/* .trigger_tokens = */ {},
|
/* .trigger_tokens = */ {},
|
||||||
/* .trigger_patterns = */ {},
|
/* .trigger_patterns = */ {},
|
||||||
};
|
};
|
||||||
|
|
@ -990,7 +1117,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||||
size_t num_trigger_patterns,
|
size_t num_trigger_patterns,
|
||||||
const llama_token * trigger_tokens,
|
const llama_token * trigger_tokens,
|
||||||
size_t num_trigger_tokens) {
|
size_t num_trigger_tokens) {
|
||||||
llama_grammar_parser parser;
|
llama_grammar_parser parser(vocab);
|
||||||
|
|
||||||
// if there is a grammar, parse it
|
// if there is a grammar, parse it
|
||||||
// rules will be empty (default) if there are parse errors
|
// rules will be empty (default) if there are parse errors
|
||||||
|
|
@ -1081,6 +1208,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||||
/* .lazy = */ lazy,
|
/* .lazy = */ lazy,
|
||||||
/* .awaiting_trigger = */ lazy,
|
/* .awaiting_trigger = */ lazy,
|
||||||
/* .trigger_buffer = */ "",
|
/* .trigger_buffer = */ "",
|
||||||
|
/* .trigger_buffer_positions = */ {},
|
||||||
std::move(vec_trigger_tokens),
|
std::move(vec_trigger_tokens),
|
||||||
std::move(vec_trigger_patterns),
|
std::move(vec_trigger_patterns),
|
||||||
};
|
};
|
||||||
|
|
@ -1103,6 +1231,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
|
||||||
grammar.lazy,
|
grammar.lazy,
|
||||||
grammar.awaiting_trigger,
|
grammar.awaiting_trigger,
|
||||||
grammar.trigger_buffer,
|
grammar.trigger_buffer,
|
||||||
|
grammar.trigger_buffer_positions,
|
||||||
grammar.trigger_tokens,
|
grammar.trigger_tokens,
|
||||||
grammar.trigger_patterns,
|
grammar.trigger_patterns,
|
||||||
};
|
};
|
||||||
|
|
@ -1156,7 +1285,7 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
|
||||||
cur_p->data[i].logit = -INFINITY;
|
cur_p->data[i].logit = -INFINITY;
|
||||||
} else {
|
} else {
|
||||||
candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
|
candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
|
||||||
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
|
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second, id });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1175,10 +1304,12 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||||
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
|
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
|
||||||
grammar.awaiting_trigger = false;
|
grammar.awaiting_trigger = false;
|
||||||
grammar.trigger_buffer.clear();
|
grammar.trigger_buffer.clear();
|
||||||
llama_grammar_accept_str(grammar, piece);
|
llama_grammar_accept_token(grammar, token, piece);
|
||||||
LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
|
LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
|
||||||
return;
|
return;
|
||||||
} else {
|
} else {
|
||||||
|
auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size());
|
||||||
|
grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
|
||||||
grammar.trigger_buffer += piece;
|
grammar.trigger_buffer += piece;
|
||||||
|
|
||||||
std::smatch match;
|
std::smatch match;
|
||||||
|
|
@ -1196,10 +1327,23 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||||
if (start == std::string::npos) {
|
if (start == std::string::npos) {
|
||||||
start = match.position(0);
|
start = match.position(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// replay tokens that overlap with [start, end)
|
||||||
|
for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
|
||||||
|
auto [tok_start, tok_end] = tok_pos;
|
||||||
|
if (tok_end <= start) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t piece_start = (tok_start < start) ? start : tok_start; // allow for partial token pieces
|
||||||
|
size_t piece_len = tok_end - piece_start;
|
||||||
|
auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len);
|
||||||
|
llama_grammar_accept_token(grammar, tok, tok_piece);
|
||||||
|
}
|
||||||
|
|
||||||
auto constrained_str = grammar.trigger_buffer.substr(start);
|
auto constrained_str = grammar.trigger_buffer.substr(start);
|
||||||
// std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
|
|
||||||
grammar.trigger_buffer.clear();
|
grammar.trigger_buffer.clear();
|
||||||
llama_grammar_accept_str(grammar, constrained_str);
|
grammar.trigger_buffer_positions.clear();
|
||||||
LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
|
LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
@ -1218,7 +1362,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_grammar_accept_str(grammar, piece);
|
llama_grammar_accept_token(grammar, token, piece);
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
|
void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
|
||||||
|
|
@ -1235,3 +1379,59 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
|
||||||
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
|
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) {
|
||||||
|
// Note terminating 0 in decoded string
|
||||||
|
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
|
||||||
|
const auto & code_points = decoded.first;
|
||||||
|
|
||||||
|
llama_grammar_stacks stacks_new;
|
||||||
|
stacks_new.reserve(grammar.stacks.size());
|
||||||
|
|
||||||
|
for (const auto & stack : grammar.stacks) {
|
||||||
|
if (stack.empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const llama_grammar_element * pos = stack.back();
|
||||||
|
|
||||||
|
if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
|
||||||
|
if (llama_grammar_match_token(pos, token)) {
|
||||||
|
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
|
||||||
|
if (!llama_grammar_is_end_of_sequence(pos + 1)) {
|
||||||
|
new_stack.push_back(pos + 1);
|
||||||
|
}
|
||||||
|
llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
llama_grammar_stacks current_stacks = {stack};
|
||||||
|
|
||||||
|
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||||
|
llama_grammar_stacks next_stacks;
|
||||||
|
|
||||||
|
for (const auto & cur_stack : current_stacks) {
|
||||||
|
llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks);
|
||||||
|
}
|
||||||
|
|
||||||
|
current_stacks = std::move(next_stacks);
|
||||||
|
if (current_stacks.empty()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto & surviving_stack : current_stacks) {
|
||||||
|
if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) {
|
||||||
|
stacks_new.emplace_back(surviving_stack);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
grammar.stacks = std::move(stacks_new);
|
||||||
|
grammar.partial_utf8 = decoded.second;
|
||||||
|
|
||||||
|
if (grammar.stacks.empty()) {
|
||||||
|
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -36,11 +36,17 @@ enum llama_gretype {
|
||||||
|
|
||||||
// any character (.)
|
// any character (.)
|
||||||
LLAMA_GRETYPE_CHAR_ANY = 7,
|
LLAMA_GRETYPE_CHAR_ANY = 7,
|
||||||
|
|
||||||
|
// terminal element: token (<[token-id]>)
|
||||||
|
LLAMA_GRETYPE_TOKEN = 8,
|
||||||
|
|
||||||
|
// inverse token (!<[token-id]>)
|
||||||
|
LLAMA_GRETYPE_TOKEN_NOT = 9,
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef struct llama_grammar_element {
|
typedef struct llama_grammar_element {
|
||||||
enum llama_gretype type;
|
enum llama_gretype type;
|
||||||
uint32_t value; // Unicode code point or rule ID
|
uint32_t value; // Unicode code point, rule ID, or token ID
|
||||||
} llama_grammar_element;
|
} llama_grammar_element;
|
||||||
|
|
||||||
struct llama_partial_utf8 {
|
struct llama_partial_utf8 {
|
||||||
|
|
@ -52,6 +58,7 @@ struct llama_grammar_candidate {
|
||||||
size_t index;
|
size_t index;
|
||||||
const uint32_t * code_points;
|
const uint32_t * code_points;
|
||||||
llama_partial_utf8 partial_utf8;
|
llama_partial_utf8 partial_utf8;
|
||||||
|
llama_token id;
|
||||||
};
|
};
|
||||||
|
|
||||||
using llama_grammar_rule = std::vector< llama_grammar_element>;
|
using llama_grammar_rule = std::vector< llama_grammar_element>;
|
||||||
|
|
@ -77,10 +84,13 @@ std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
|
||||||
const llama_grammar_candidates & candidates);
|
const llama_grammar_candidates & candidates);
|
||||||
|
|
||||||
struct llama_grammar_parser {
|
struct llama_grammar_parser {
|
||||||
|
const llama_vocab * vocab;
|
||||||
std::map<std::string, uint32_t> symbol_ids;
|
std::map<std::string, uint32_t> symbol_ids;
|
||||||
|
|
||||||
llama_grammar_rules rules;
|
llama_grammar_rules rules;
|
||||||
|
|
||||||
|
llama_grammar_parser(const struct llama_vocab * vocab = nullptr) : vocab(vocab) {}
|
||||||
|
|
||||||
llama_grammar_stack c_rules() const;
|
llama_grammar_stack c_rules() const;
|
||||||
|
|
||||||
uint32_t get_symbol_id(const char * src, size_t len);
|
uint32_t get_symbol_id(const char * src, size_t len);
|
||||||
|
|
@ -112,6 +122,9 @@ struct llama_grammar_trigger_pattern {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_grammar {
|
struct llama_grammar {
|
||||||
|
// maintain a list of llama_tokens and their positions in the trigger_buffer
|
||||||
|
using token_pos = std::pair<llama_token, std::pair<size_t, size_t>>;
|
||||||
|
|
||||||
// note: allow null vocab for testing (not great)
|
// note: allow null vocab for testing (not great)
|
||||||
const llama_vocab * vocab;
|
const llama_vocab * vocab;
|
||||||
|
|
||||||
|
|
@ -127,6 +140,7 @@ struct llama_grammar {
|
||||||
bool lazy = false;
|
bool lazy = false;
|
||||||
bool awaiting_trigger = false; // Initialized to true for lazy grammars only
|
bool awaiting_trigger = false; // Initialized to true for lazy grammars only
|
||||||
std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
|
std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
|
||||||
|
std::vector<token_pos> trigger_buffer_positions; // Tokens buffered by lazy grammar. Used to replay when a trigger is found.
|
||||||
std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
|
std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
|
||||||
std::vector<llama_grammar_trigger_pattern>
|
std::vector<llama_grammar_trigger_pattern>
|
||||||
trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated
|
trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated
|
||||||
|
|
@ -171,3 +185,8 @@ void llama_grammar_accept_impl(
|
||||||
void llama_grammar_accept_str(
|
void llama_grammar_accept_str(
|
||||||
struct llama_grammar & grammar,
|
struct llama_grammar & grammar,
|
||||||
const std::string & piece);
|
const std::string & piece);
|
||||||
|
|
||||||
|
void llama_grammar_accept_token(
|
||||||
|
struct llama_grammar & grammar,
|
||||||
|
llama_token token,
|
||||||
|
const std::string & piece);
|
||||||
|
|
|
||||||
|
|
@ -32,13 +32,66 @@ static bool test_build_grammar_fails(const std::string & grammar_str) {
|
||||||
return grammar_fails;
|
return grammar_fails;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct token_and_piece {
|
||||||
|
llama_token token;
|
||||||
|
std::string piece;
|
||||||
|
};
|
||||||
|
|
||||||
|
// token() encodes a 32-bit ID as 5 bytes: a 0xff marker followed by the ID in big-endian order.
|
||||||
|
static std::string token(llama_token id) {
|
||||||
|
return std::string{
|
||||||
|
static_cast<char>(0xff),
|
||||||
|
static_cast<char>((id >> 24) & 0xff),
|
||||||
|
static_cast<char>((id >> 16) & 0xff),
|
||||||
|
static_cast<char>((id >> 8) & 0xff),
|
||||||
|
static_cast<char>(id & 0xff)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse_tokens() parses the token encodes above and UTF-8 text.
|
||||||
|
static std::vector<token_and_piece> parse_tokens(const std::string & input) {
|
||||||
|
std::vector<token_and_piece> result;
|
||||||
|
result.reserve(input.size());
|
||||||
|
size_t offset = 0;
|
||||||
|
while (offset < input.size()) {
|
||||||
|
try {
|
||||||
|
if (static_cast<unsigned char>(input[offset]) == 0xff) {
|
||||||
|
if (offset + 5 > input.size()) {
|
||||||
|
throw std::runtime_error("not enough bytes for token id");
|
||||||
|
}
|
||||||
|
uint32_t val =
|
||||||
|
(static_cast<unsigned char>(input[offset + 1]) << 24) |
|
||||||
|
(static_cast<unsigned char>(input[offset + 2]) << 16) |
|
||||||
|
(static_cast<unsigned char>(input[offset + 3]) << 8) |
|
||||||
|
(static_cast<unsigned char>(input[offset + 4]));
|
||||||
|
auto piece = "<[" + std::to_string(val) + "]>";
|
||||||
|
result.push_back({static_cast<llama_token>(val), piece});
|
||||||
|
offset += 5;
|
||||||
|
} else {
|
||||||
|
uint32_t cpt = unicode_cpt_from_utf8(input, offset);
|
||||||
|
result.push_back({0, unicode_cpt_to_utf8(cpt)});
|
||||||
|
}
|
||||||
|
} catch (const std::invalid_argument & /*ex*/) {
|
||||||
|
// Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize
|
||||||
|
++offset;
|
||||||
|
result.push_back({0, unicode_cpt_to_utf8(0xFFFD)}); // replacement character
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
static bool match_string(const std::string & input, llama_grammar * grammar) {
|
static bool match_string(const std::string & input, llama_grammar * grammar) {
|
||||||
const auto cpts = unicode_cpts_from_utf8(input);
|
const auto parsed = parse_tokens(input);
|
||||||
|
|
||||||
auto & stacks_cur = llama_grammar_get_stacks(grammar);
|
auto & stacks_cur = llama_grammar_get_stacks(grammar);
|
||||||
|
|
||||||
for (const auto & cpt : cpts) {
|
for (const auto & in : parsed) {
|
||||||
llama_grammar_accept(grammar, cpt);
|
try {
|
||||||
|
llama_grammar_accept_token(*grammar, in.token, in.piece);
|
||||||
|
} catch (const std::runtime_error & /*e*/) {
|
||||||
|
// normally this shouldn't get hit because of llama_grammar_apply
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
if (stacks_cur.empty()) {
|
if (stacks_cur.empty()) {
|
||||||
// no stacks means that the grammar failed to match at this point
|
// no stacks means that the grammar failed to match at this point
|
||||||
|
|
@ -426,6 +479,30 @@ static void test_simple_grammar() {
|
||||||
"12a45",
|
"12a45",
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Test case for a simple grammar with tokens
|
||||||
|
test_grammar(
|
||||||
|
"simple grammar with tokens",
|
||||||
|
R"""(
|
||||||
|
root ::= <[10]> content <[11]>
|
||||||
|
content ::= (!<[11]>)*)""",
|
||||||
|
// Passing strings
|
||||||
|
{
|
||||||
|
token(10) + "hello world" + token(11),
|
||||||
|
token(10) + "text with " + token(12) + " other tokens " + token(13) + " mixed in" + token(11),
|
||||||
|
token(10) + token(11),
|
||||||
|
token(10) + token(12) + token(13) + token(14) + token(15) + token(11),
|
||||||
|
token(10) + "a" + token(11),
|
||||||
|
},
|
||||||
|
// Failing strings
|
||||||
|
{
|
||||||
|
token(10) + "missing end token",
|
||||||
|
token(10),
|
||||||
|
"missing start token" + token(11),
|
||||||
|
token(10) + token(11) + token(11), // double end token
|
||||||
|
token(11) + "wrong order" + token(10),
|
||||||
|
}
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void test_complex_grammar() {
|
static void test_complex_grammar() {
|
||||||
|
|
@ -487,6 +564,34 @@ static void test_complex_grammar() {
|
||||||
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
|
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Test case for a more complex grammar with tokens
|
||||||
|
test_grammar(
|
||||||
|
"complex grammar with tokens",
|
||||||
|
R"""(
|
||||||
|
root ::= reasoning+ content tool-call*
|
||||||
|
reasoning ::= <[10]> (!<[11]>)* <[11]>
|
||||||
|
content ::= <[20]> (!<[21]>)* <[21]>
|
||||||
|
tool-call ::= <[12]> name <[13]> args <[14]>
|
||||||
|
name ::= (!<[13]>)+
|
||||||
|
args ::= (!<[14]>)*)""",
|
||||||
|
// Passing strings
|
||||||
|
{
|
||||||
|
token(10) + "I am thinking" + token(11) + token(20) + "hello world!" + token(21) + token(12) + "search" + token(13) + "query=test" + token(14),
|
||||||
|
token(10) + "reasoning 1" + token(11) + token(10) + "reasoning 2" + token(11) + token(20) + token(21) + token(12) + "tool" + token(13) + token(14),
|
||||||
|
token(10) + token(11) + token(20) + "content" + token(21),
|
||||||
|
token(10) + "think" + token(12) + " nested" + token(11) + token(20) + token(10) + "more content" + token(21) + token(12) + "fn" + token(13) + "x=1,y=2" + token(14) + token(12) + "fn2" + token(13) + token(14),
|
||||||
|
token(10) + "reasoning" + token(11) + token(10) + "more" + token(11) + token(10) + "even more" + token(11) + token(20) + "text" + token(21) + token(12) + "a" + token(13) + "b" + token(14) + token(12) + "c" + token(13) + "d" + token(14),
|
||||||
|
},
|
||||||
|
// Failing strings
|
||||||
|
{
|
||||||
|
token(20) + "content only" + token(21),
|
||||||
|
token(10) + "no closing reasoning",
|
||||||
|
token(10) + token(11) + token(20) + "no closing content",
|
||||||
|
token(10) + token(11) + token(20) + token(21) + token(12) + "incomplete tool",
|
||||||
|
token(10) + token(11) + token(11) + token(20) + token(21),
|
||||||
|
}
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void test_special_chars() {
|
static void test_special_chars() {
|
||||||
|
|
|
||||||
|
|
@ -515,5 +515,19 @@ int main()
|
||||||
{LLAMA_GRETYPE_END, 0},
|
{LLAMA_GRETYPE_END, 0},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// <[1000]> = "<think>"
|
||||||
|
// <[1001]> = "</think>"
|
||||||
|
verify_parsing(R"""(
|
||||||
|
root ::= <[1000]> !<[1001]> <[1001]>
|
||||||
|
)""", {
|
||||||
|
{"root", 0}
|
||||||
|
}, {
|
||||||
|
// root (index 0)
|
||||||
|
{LLAMA_GRETYPE_TOKEN, 1000},
|
||||||
|
{LLAMA_GRETYPE_TOKEN_NOT, 1001},
|
||||||
|
{LLAMA_GRETYPE_TOKEN, 1001},
|
||||||
|
{LLAMA_GRETYPE_END, 0},
|
||||||
|
});
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -202,7 +202,7 @@ int main()
|
||||||
uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
|
uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
|
||||||
cp[0] = 37 + i;
|
cp[0] = 37 + i;
|
||||||
cp[1] = 0;
|
cp[1] = 0;
|
||||||
next_candidates[i] = {i, cp, {}};
|
next_candidates[i] = {i, cp, {}, 0};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {
|
std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue