common : restore grammar-based rejection sampling (#18137)

* common : restart grammar-based rejection sampling

* sampling : allow null samplers
This commit is contained in:
Georgi Gerganov 2025-12-17 19:46:00 +02:00 committed by GitHub
parent a2c199e479
commit 4301e27319
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 76 additions and 43 deletions

View File

@ -104,10 +104,9 @@ struct ring_buffer {
struct common_sampler { struct common_sampler {
common_params_sampling params; common_params_sampling params;
struct llama_sampler * grmr;
struct llama_sampler * chain; struct llama_sampler * chain;
bool grammar;
ring_buffer<llama_token> prev; ring_buffer<llama_token> prev;
std::vector<llama_token_data> cur; std::vector<llama_token_data> cur;
@ -167,15 +166,14 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
lparams.no_perf = params.no_perf; lparams.no_perf = params.no_perf;
llama_sampler * grmr = nullptr;
llama_sampler * chain = llama_sampler_chain_init(lparams); llama_sampler * chain = llama_sampler_chain_init(lparams);
bool grammar = false;
std::vector<llama_sampler *> samplers; std::vector<llama_sampler *> samplers;
if (params.grammar.compare(0, 11, "%llguidance") == 0) { if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE #ifdef LLAMA_USE_LLGUIDANCE
samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str())); grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
grammar = true;
#else #else
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE #endif // LLAMA_USE_LLGUIDANCE
@ -224,15 +222,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
if (!params.grammar.empty()) { if (!params.grammar.empty()) {
if (params.grammar_lazy) { if (params.grammar_lazy) {
samplers.push_back( grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(), trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size())); trigger_tokens.data(), trigger_tokens.size());
} else { } else {
samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root")); grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
} }
grammar = true;
} }
} }
@ -303,8 +298,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
auto * result = new common_sampler { auto * result = new common_sampler {
/* .params = */ params, /* .params = */ params,
/* .grmr = */ grmr,
/* .chain = */ chain, /* .chain = */ chain,
/* .grammar = */ grammar,
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)), /* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {}, /* .cur = */ {},
/* .cur_p = */ {}, /* .cur_p = */ {},
@ -315,6 +310,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
void common_sampler_free(struct common_sampler * gsmpl) { void common_sampler_free(struct common_sampler * gsmpl) {
if (gsmpl) { if (gsmpl) {
llama_sampler_free(gsmpl->grmr);
llama_sampler_free(gsmpl->chain); llama_sampler_free(gsmpl->chain);
delete gsmpl; delete gsmpl;
@ -324,24 +320,11 @@ void common_sampler_free(struct common_sampler * gsmpl) {
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) { void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
const auto tm = gsmpl->tm(); const auto tm = gsmpl->tm();
if (gsmpl->grammar) { if (gsmpl->grmr && accept_grammar) {
const int n_smpl = llama_sampler_chain_n(gsmpl->chain); llama_sampler_accept(gsmpl->grmr, token);
}
for (int i = 0; i < n_smpl; i++) {
auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
// the grammar sampler is always the first one
if (i == 0) {
if (accept_grammar) {
llama_sampler_accept(smpl, token);
}
} else {
llama_sampler_accept(smpl, token);
}
}
} else {
llama_sampler_accept(gsmpl->chain, token); llama_sampler_accept(gsmpl->chain, token);
}
gsmpl->prev.push_back(token); gsmpl->prev.push_back(token);
} }
@ -353,8 +336,8 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) { struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
return new common_sampler { return new common_sampler {
/* .params = */ gsmpl->params, /* .params = */ gsmpl->params,
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
/* .chain = */ llama_sampler_clone(gsmpl->chain), /* .chain = */ llama_sampler_clone(gsmpl->chain),
/* .grammar = */ gsmpl->grammar,
/* .prev = */ gsmpl->prev, /* .prev = */ gsmpl->prev,
/* .cur = */ gsmpl->cur, /* .cur = */ gsmpl->cur,
/* .cur_p = */ gsmpl->cur_p, /* .cur_p = */ gsmpl->cur_p,
@ -410,7 +393,7 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
return gsmpl->chain; return gsmpl->chain;
} }
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) { llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
llama_synchronize(ctx); llama_synchronize(ctx);
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
@ -418,11 +401,42 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
llama_token id = LLAMA_TOKEN_NULL; llama_token id = LLAMA_TOKEN_NULL;
auto & grmr = gsmpl->grmr;
auto & chain = gsmpl->chain; auto & chain = gsmpl->chain;
auto & cur_p = gsmpl->cur_p; // initialized by set_logits auto & cur_p = gsmpl->cur_p; // initialized by set_logits
gsmpl->set_logits(ctx, idx); gsmpl->set_logits(ctx, idx);
if (grammar_first) {
llama_sampler_apply(grmr, &cur_p);
}
llama_sampler_apply(chain, &cur_p);
id = cur_p.data[cur_p.selected].id;
if (grammar_first) {
return id;
}
// check if the sampled token fits the grammar (grammar-based rejection sampling)
{
llama_token_data single_token_data = { id, 1.0f, 0.0f };
llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
llama_sampler_apply(grmr, &single_token_data_array);
const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
if (is_valid) {
return id;
}
}
// resampling:
// if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
gsmpl->set_logits(ctx, idx);
llama_sampler_apply(grmr, &cur_p);
llama_sampler_apply(chain, &cur_p); llama_sampler_apply(chain, &cur_p);
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration"); GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
@ -432,7 +446,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
return id; return id;
} }
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) { std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
std::vector<llama_token> result; std::vector<llama_token> result;
@ -440,7 +454,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
size_t i = 0; size_t i = 0;
for (; i < draft.size(); i++) { for (; i < draft.size(); i++) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]); const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
common_sampler_accept(gsmpl, id, true); common_sampler_accept(gsmpl, id, true);
@ -452,7 +466,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
} }
if (i == draft.size()) { if (i == draft.size()) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]); const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
common_sampler_accept(gsmpl, id, true); common_sampler_accept(gsmpl, id, true);
@ -462,13 +476,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
return result; return result;
} }
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) { std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
std::vector<int> idxs(draft.size() + 1); std::vector<int> idxs(draft.size() + 1);
for (size_t i = 0; i < idxs.size(); ++i) { for (size_t i = 0; i < idxs.size(); ++i) {
idxs[i] = i; idxs[i] = i;
} }
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft); return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
} }
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {

View File

@ -57,7 +57,10 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
// - check if the token fits the grammar (if any) // - check if the token fits the grammar (if any)
// - if not: resample by first applying the grammar constraints and then sampling again (slower path) // - if not: resample by first applying the grammar constraints and then sampling again (slower path)
// //
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx); // if grammar_first is true, the grammar is applied before the samplers (slower)
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
//
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
// generalized version of common_sampler_sample // generalized version of common_sampler_sample
// //
@ -75,10 +78,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
// //
// returns at least 1 token, up to idxs.size() // returns at least 1 token, up to idxs.size()
// //
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft); std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
// assume idxs == [ 0, 1, 2, ..., draft.size() ] // assume idxs == [ 0, 1, 2, ..., draft.size() ]
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft); std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);

View File

@ -315,7 +315,7 @@ llama_tokens common_speculative_gen_draft(
for (int i = 0; i < params.n_draft; ++i) { for (int i = 0; i < params.n_draft; ++i) {
common_batch_clear(batch); common_batch_clear(batch);
common_sampler_sample(smpl, ctx_dft, 0); common_sampler_sample(smpl, ctx_dft, 0, true);
const auto * cur_p = common_sampler_get_candidates(smpl, true); const auto * cur_p = common_sampler_get_candidates(smpl, true);

View File

@ -242,7 +242,7 @@ int main(int argc, char ** argv) {
bool accept = false; bool accept = false;
if (params.sampling.temp > 0) { if (params.sampling.temp > 0) {
// stochastic verification // stochastic verification
common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]); common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
auto & dist_tgt = *common_sampler_get_candidates(smpl, true); auto & dist_tgt = *common_sampler_get_candidates(smpl, true);
@ -491,7 +491,7 @@ int main(int argc, char ** argv) {
continue; continue;
} }
common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft); common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true); const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);

View File

@ -362,23 +362,39 @@ const char * llama_sampler_name(const struct llama_sampler * smpl) {
} }
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
if (!smpl) {
return;
}
if (smpl->iface->accept) { if (smpl->iface->accept) {
smpl->iface->accept(smpl, token); smpl->iface->accept(smpl, token);
} }
} }
void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) { void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
if (!smpl) {
return;
}
GGML_ASSERT(smpl->iface->apply); GGML_ASSERT(smpl->iface->apply);
smpl->iface->apply(smpl, cur_p); smpl->iface->apply(smpl, cur_p);
} }
void llama_sampler_reset(struct llama_sampler * smpl) { void llama_sampler_reset(struct llama_sampler * smpl) {
if (!smpl) {
return;
}
if (smpl->iface->reset) { if (smpl->iface->reset) {
smpl->iface->reset(smpl); smpl->iface->reset(smpl);
} }
} }
struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) { struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
if (!smpl) {
return nullptr;
}
if (smpl->iface->clone) { if (smpl->iface->clone) {
return smpl->iface->clone(smpl); return smpl->iface->clone(smpl);
} }