common : restore grammar-based rejection sampling (#18137)
* common : restart grammar-based rejection sampling * sampling : allow null samplers
This commit is contained in:
parent
a2c199e479
commit
4301e27319
|
|
@ -104,10 +104,9 @@ struct ring_buffer {
|
||||||
struct common_sampler {
|
struct common_sampler {
|
||||||
common_params_sampling params;
|
common_params_sampling params;
|
||||||
|
|
||||||
|
struct llama_sampler * grmr;
|
||||||
struct llama_sampler * chain;
|
struct llama_sampler * chain;
|
||||||
|
|
||||||
bool grammar;
|
|
||||||
|
|
||||||
ring_buffer<llama_token> prev;
|
ring_buffer<llama_token> prev;
|
||||||
|
|
||||||
std::vector<llama_token_data> cur;
|
std::vector<llama_token_data> cur;
|
||||||
|
|
@ -167,15 +166,14 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||||
|
|
||||||
lparams.no_perf = params.no_perf;
|
lparams.no_perf = params.no_perf;
|
||||||
|
|
||||||
|
llama_sampler * grmr = nullptr;
|
||||||
llama_sampler * chain = llama_sampler_chain_init(lparams);
|
llama_sampler * chain = llama_sampler_chain_init(lparams);
|
||||||
|
|
||||||
bool grammar = false;
|
|
||||||
std::vector<llama_sampler *> samplers;
|
std::vector<llama_sampler *> samplers;
|
||||||
|
|
||||||
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
|
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
|
||||||
#ifdef LLAMA_USE_LLGUIDANCE
|
#ifdef LLAMA_USE_LLGUIDANCE
|
||||||
samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()));
|
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
|
||||||
grammar = true;
|
|
||||||
#else
|
#else
|
||||||
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
||||||
#endif // LLAMA_USE_LLGUIDANCE
|
#endif // LLAMA_USE_LLGUIDANCE
|
||||||
|
|
@ -224,15 +222,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||||
|
|
||||||
if (!params.grammar.empty()) {
|
if (!params.grammar.empty()) {
|
||||||
if (params.grammar_lazy) {
|
if (params.grammar_lazy) {
|
||||||
samplers.push_back(
|
grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
|
||||||
llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
|
|
||||||
trigger_patterns_c.data(), trigger_patterns_c.size(),
|
trigger_patterns_c.data(), trigger_patterns_c.size(),
|
||||||
trigger_tokens.data(), trigger_tokens.size()));
|
trigger_tokens.data(), trigger_tokens.size());
|
||||||
} else {
|
} else {
|
||||||
samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"));
|
grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
|
||||||
}
|
}
|
||||||
|
|
||||||
grammar = true;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -303,8 +298,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||||
|
|
||||||
auto * result = new common_sampler {
|
auto * result = new common_sampler {
|
||||||
/* .params = */ params,
|
/* .params = */ params,
|
||||||
|
/* .grmr = */ grmr,
|
||||||
/* .chain = */ chain,
|
/* .chain = */ chain,
|
||||||
/* .grammar = */ grammar,
|
|
||||||
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
||||||
/* .cur = */ {},
|
/* .cur = */ {},
|
||||||
/* .cur_p = */ {},
|
/* .cur_p = */ {},
|
||||||
|
|
@ -315,6 +310,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||||
|
|
||||||
void common_sampler_free(struct common_sampler * gsmpl) {
|
void common_sampler_free(struct common_sampler * gsmpl) {
|
||||||
if (gsmpl) {
|
if (gsmpl) {
|
||||||
|
llama_sampler_free(gsmpl->grmr);
|
||||||
llama_sampler_free(gsmpl->chain);
|
llama_sampler_free(gsmpl->chain);
|
||||||
|
|
||||||
delete gsmpl;
|
delete gsmpl;
|
||||||
|
|
@ -324,24 +320,11 @@ void common_sampler_free(struct common_sampler * gsmpl) {
|
||||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
||||||
const auto tm = gsmpl->tm();
|
const auto tm = gsmpl->tm();
|
||||||
|
|
||||||
if (gsmpl->grammar) {
|
if (gsmpl->grmr && accept_grammar) {
|
||||||
const int n_smpl = llama_sampler_chain_n(gsmpl->chain);
|
llama_sampler_accept(gsmpl->grmr, token);
|
||||||
|
}
|
||||||
|
|
||||||
for (int i = 0; i < n_smpl; i++) {
|
|
||||||
auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
|
|
||||||
|
|
||||||
// the grammar sampler is always the first one
|
|
||||||
if (i == 0) {
|
|
||||||
if (accept_grammar) {
|
|
||||||
llama_sampler_accept(smpl, token);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
llama_sampler_accept(smpl, token);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
llama_sampler_accept(gsmpl->chain, token);
|
llama_sampler_accept(gsmpl->chain, token);
|
||||||
}
|
|
||||||
|
|
||||||
gsmpl->prev.push_back(token);
|
gsmpl->prev.push_back(token);
|
||||||
}
|
}
|
||||||
|
|
@ -353,8 +336,8 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
|
||||||
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
||||||
return new common_sampler {
|
return new common_sampler {
|
||||||
/* .params = */ gsmpl->params,
|
/* .params = */ gsmpl->params,
|
||||||
|
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
||||||
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
||||||
/* .grammar = */ gsmpl->grammar,
|
|
||||||
/* .prev = */ gsmpl->prev,
|
/* .prev = */ gsmpl->prev,
|
||||||
/* .cur = */ gsmpl->cur,
|
/* .cur = */ gsmpl->cur,
|
||||||
/* .cur_p = */ gsmpl->cur_p,
|
/* .cur_p = */ gsmpl->cur_p,
|
||||||
|
|
@ -410,7 +393,7 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
|
||||||
return gsmpl->chain;
|
return gsmpl->chain;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
|
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
|
||||||
llama_synchronize(ctx);
|
llama_synchronize(ctx);
|
||||||
|
|
||||||
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
|
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
|
||||||
|
|
@ -418,11 +401,42 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||||
|
|
||||||
llama_token id = LLAMA_TOKEN_NULL;
|
llama_token id = LLAMA_TOKEN_NULL;
|
||||||
|
|
||||||
|
auto & grmr = gsmpl->grmr;
|
||||||
auto & chain = gsmpl->chain;
|
auto & chain = gsmpl->chain;
|
||||||
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
|
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
|
||||||
|
|
||||||
gsmpl->set_logits(ctx, idx);
|
gsmpl->set_logits(ctx, idx);
|
||||||
|
|
||||||
|
if (grammar_first) {
|
||||||
|
llama_sampler_apply(grmr, &cur_p);
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_sampler_apply(chain, &cur_p);
|
||||||
|
|
||||||
|
id = cur_p.data[cur_p.selected].id;
|
||||||
|
|
||||||
|
if (grammar_first) {
|
||||||
|
return id;
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if it the sampled token fits the grammar (grammar-based rejection sampling)
|
||||||
|
{
|
||||||
|
llama_token_data single_token_data = { id, 1.0f, 0.0f };
|
||||||
|
llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
|
||||||
|
|
||||||
|
llama_sampler_apply(grmr, &single_token_data_array);
|
||||||
|
|
||||||
|
const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
|
||||||
|
if (is_valid) {
|
||||||
|
return id;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// resampling:
|
||||||
|
// if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
|
||||||
|
gsmpl->set_logits(ctx, idx);
|
||||||
|
|
||||||
|
llama_sampler_apply(grmr, &cur_p);
|
||||||
llama_sampler_apply(chain, &cur_p);
|
llama_sampler_apply(chain, &cur_p);
|
||||||
|
|
||||||
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
|
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
|
||||||
|
|
@ -432,7 +446,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||||
return id;
|
return id;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) {
|
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
|
||||||
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
|
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
|
||||||
|
|
||||||
std::vector<llama_token> result;
|
std::vector<llama_token> result;
|
||||||
|
|
@ -440,7 +454,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
|
||||||
|
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
for (; i < draft.size(); i++) {
|
for (; i < draft.size(); i++) {
|
||||||
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
|
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
|
||||||
|
|
||||||
common_sampler_accept(gsmpl, id, true);
|
common_sampler_accept(gsmpl, id, true);
|
||||||
|
|
||||||
|
|
@ -452,7 +466,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
|
||||||
}
|
}
|
||||||
|
|
||||||
if (i == draft.size()) {
|
if (i == draft.size()) {
|
||||||
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
|
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
|
||||||
|
|
||||||
common_sampler_accept(gsmpl, id, true);
|
common_sampler_accept(gsmpl, id, true);
|
||||||
|
|
||||||
|
|
@ -462,13 +476,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) {
|
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
|
||||||
std::vector<int> idxs(draft.size() + 1);
|
std::vector<int> idxs(draft.size() + 1);
|
||||||
for (size_t i = 0; i < idxs.size(); ++i) {
|
for (size_t i = 0; i < idxs.size(); ++i) {
|
||||||
idxs[i] = i;
|
idxs[i] = i;
|
||||||
}
|
}
|
||||||
|
|
||||||
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
|
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
|
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,10 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
|
||||||
// - check if the token fits the grammar (if any)
|
// - check if the token fits the grammar (if any)
|
||||||
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
|
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
|
||||||
//
|
//
|
||||||
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx);
|
// if grammar_first is true, the grammar is applied before the samplers (slower)
|
||||||
|
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
|
||||||
|
//
|
||||||
|
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
|
||||||
|
|
||||||
// generalized version of common_sampler_sample
|
// generalized version of common_sampler_sample
|
||||||
//
|
//
|
||||||
|
|
@ -75,10 +78,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||||
//
|
//
|
||||||
// returns at least 1 token, up to idxs.size()
|
// returns at least 1 token, up to idxs.size()
|
||||||
//
|
//
|
||||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft);
|
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
|
||||||
|
|
||||||
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
|
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
|
||||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft);
|
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
|
||||||
|
|
||||||
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
|
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -315,7 +315,7 @@ llama_tokens common_speculative_gen_draft(
|
||||||
for (int i = 0; i < params.n_draft; ++i) {
|
for (int i = 0; i < params.n_draft; ++i) {
|
||||||
common_batch_clear(batch);
|
common_batch_clear(batch);
|
||||||
|
|
||||||
common_sampler_sample(smpl, ctx_dft, 0);
|
common_sampler_sample(smpl, ctx_dft, 0, true);
|
||||||
|
|
||||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -242,7 +242,7 @@ int main(int argc, char ** argv) {
|
||||||
bool accept = false;
|
bool accept = false;
|
||||||
if (params.sampling.temp > 0) {
|
if (params.sampling.temp > 0) {
|
||||||
// stochastic verification
|
// stochastic verification
|
||||||
common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
|
common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
|
||||||
|
|
||||||
auto & dist_tgt = *common_sampler_get_candidates(smpl, true);
|
auto & dist_tgt = *common_sampler_get_candidates(smpl, true);
|
||||||
|
|
||||||
|
|
@ -491,7 +491,7 @@ int main(int argc, char ** argv) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft);
|
common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
|
||||||
|
|
||||||
const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);
|
const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -362,23 +362,39 @@ const char * llama_sampler_name(const struct llama_sampler * smpl) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
|
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
|
||||||
|
if (!smpl) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (smpl->iface->accept) {
|
if (smpl->iface->accept) {
|
||||||
smpl->iface->accept(smpl, token);
|
smpl->iface->accept(smpl, token);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
|
void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
|
||||||
|
if (!smpl) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
GGML_ASSERT(smpl->iface->apply);
|
GGML_ASSERT(smpl->iface->apply);
|
||||||
smpl->iface->apply(smpl, cur_p);
|
smpl->iface->apply(smpl, cur_p);
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_sampler_reset(struct llama_sampler * smpl) {
|
void llama_sampler_reset(struct llama_sampler * smpl) {
|
||||||
|
if (!smpl) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (smpl->iface->reset) {
|
if (smpl->iface->reset) {
|
||||||
smpl->iface->reset(smpl);
|
smpl->iface->reset(smpl);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
|
struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
|
||||||
|
if (!smpl) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
if (smpl->iface->clone) {
|
if (smpl->iface->clone) {
|
||||||
return smpl->iface->clone(smpl);
|
return smpl->iface->clone(smpl);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue