common : simplify sampler chain initialization

Georgi Gerganov 2025-12-01 17:10:32 +02:00
parent 217469f07f
commit 4032ce2378
7 changed files with 171 additions and 178 deletions

View File

@@ -212,18 +212,12 @@ struct common_params_sampling {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
-bool backend_sampling = false; // enable backend sampling
+bool backend_sampling = false;
bool has_logit_bias() const {
return !logit_bias.empty();
}
-bool is_disabled(enum common_sampler_type type) const;
-// remove disabled samplers
-// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
-void filter_disabled();
// print the parameters into a string
std::string print() const;
};
@@ -661,7 +655,7 @@ std::vector<common_file_info> fs_list_files(const std::string & path);
struct common_sampler;
-// note: defines object's lifetime
+// note: defines the model, context, samplers, etc. lifetimes
struct common_init_result {
common_init_result(common_params & params);
~common_init_result();

View File

@@ -163,84 +163,6 @@ struct common_sampler {
mutable int64_t t_total_us = 0;
};
-// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
-static bool common_sampler_type_has_backend_support(enum common_sampler_type type) {
-switch (type) {
-case COMMON_SAMPLER_TYPE_TOP_K:
-case COMMON_SAMPLER_TYPE_TEMPERATURE:
-case COMMON_SAMPLER_TYPE_MIN_P:
-case COMMON_SAMPLER_TYPE_TOP_P:
-return true;
-default:
-return false;
-}
-}
-bool common_params_sampling::is_disabled(enum common_sampler_type type) const {
-switch (type) {
-case COMMON_SAMPLER_TYPE_PENALTIES:
-if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
-return true;
-}
-break;
-case COMMON_SAMPLER_TYPE_DRY:
-if (dry_multiplier == 0.0f || dry_base < 1.0f || dry_penalty_last_n == 0) {
-return true;
-}
-break;
-case COMMON_SAMPLER_TYPE_TYPICAL_P:
-if (typ_p >= 1.0) {
-return true;
-}
-break;
-case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
-if (top_n_sigma <= 0.0) {
-return true;
-}
-break;
-case COMMON_SAMPLER_TYPE_TOP_K:
-if (top_k <= 0) {
-return true;
-}
-break;
-case COMMON_SAMPLER_TYPE_TEMPERATURE:
-if (dynatemp_range <= 0.0f) {
-return true;
-}
-break;
-case COMMON_SAMPLER_TYPE_MIN_P:
-if (min_p <= 0.0f) {
-return true;
-}
-break;
-case COMMON_SAMPLER_TYPE_TOP_P:
-if (top_p >= 1.0f) {
-return true;
-}
-break;
-case COMMON_SAMPLER_TYPE_XTC:
-if (xtc_probability <= 0.0f || xtc_threshold > 0.50f) {
-return true;
-}
-break;
-default:
-break;
-}
-return false;
-}
-void common_params_sampling::filter_disabled() {
-for (auto it = samplers.begin(); it != samplers.end();) {
-if (is_disabled(*it)) {
-LOG_WRN("%s: removing disabled sampler %s\n", __func__, common_sampler_type_to_str(*it).c_str());
-it = samplers.erase(it);
-} else {
-++it;
-}
-}
-}
std::string common_params_sampling::print() const {
char result[1024];
@@ -257,7 +179,7 @@ std::string common_params_sampling::print() const {
return std::string(result);
}
-struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
@@ -324,11 +246,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
}
-// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
-if (params.backend_sampling) {
-params.filter_disabled();
-}
auto * result = new common_sampler {
/* .params = */ params,
/* .grmr = */ grmr,
@@ -339,54 +256,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
/* .cur_p = */ {},
};
-size_t idx_smpl = 0;
-bool is_backend = true;
-is_backend = is_backend && params.backend_sampling;
-is_backend = is_backend && (params.samplers.size() == 0 || common_sampler_type_has_backend_support(params.samplers[idx_smpl]));
+std::vector<llama_sampler *> samplers;
if (params.has_logit_bias()) {
-llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain,
-llama_sampler_init_logit_bias(
-llama_vocab_n_tokens(vocab),
-params.logit_bias.size(),
-params.logit_bias.data()));
+samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
}
if (params.mirostat == 0) {
-// backend samplers are added first
-while (is_backend && idx_smpl < params.samplers.size()) {
-const auto & cnstr = params.samplers[idx_smpl++];
-if (!common_sampler_type_has_backend_support(cnstr)) {
-is_backend = false;
---idx_smpl;
-break;
-}
-switch (cnstr) {
-case COMMON_SAMPLER_TYPE_TOP_K:
-llama_sampler_chain_add(result->chain_backend, llama_sampler_init_top_k(params.top_k));
-break;
-case COMMON_SAMPLER_TYPE_TEMPERATURE:
-llama_sampler_chain_add(result->chain_backend, llama_sampler_init_temp(params.temp));
-break;
-case COMMON_SAMPLER_TYPE_MIN_P:
-llama_sampler_chain_add(result->chain_backend, llama_sampler_init_min_p(params.min_p, params.min_keep));
-break;
-case COMMON_SAMPLER_TYPE_TOP_P:
-llama_sampler_chain_add(result->chain_backend, llama_sampler_init_top_p(params.top_p, params.min_keep));
-break;
-default:
-GGML_ASSERT(false && "unsupported backend sampler");
-}
-}
-// Add remaining CPU samplers
-while (idx_smpl < params.samplers.size()) {
-const auto & cnstr = params.samplers[idx_smpl++];
+for (const auto & cnstr : params.samplers) {
switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY:
{
@@ -396,52 +272,63 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
c_breakers.push_back(str.c_str());
}
-llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
}
break;
case COMMON_SAMPLER_TYPE_TOP_K:
-llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
+samplers.push_back(llama_sampler_init_top_k (params.top_k));
break;
case COMMON_SAMPLER_TYPE_TOP_P:
-llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
+samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
-llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
+samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
break;
case COMMON_SAMPLER_TYPE_MIN_P:
-llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
+samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_XTC:
-llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break;
case COMMON_SAMPLER_TYPE_TYPICAL_P:
-llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
+samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TEMPERATURE:
-llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break;
case COMMON_SAMPLER_TYPE_INFILL:
-llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
+samplers.push_back(llama_sampler_init_infill (vocab));
break;
case COMMON_SAMPLER_TYPE_PENALTIES:
-llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break;
default:
GGML_ASSERT(false && "unknown sampler type");
}
}
-llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, llama_sampler_init_dist(params.seed));
+samplers.push_back(llama_sampler_init_dist(params.seed));
} else if (params.mirostat == 1) {
-llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+samplers.push_back(llama_sampler_init_temp(params.temp));
+samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
} else if (params.mirostat == 2) {
-llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
+samplers.push_back(llama_sampler_init_temp(params.temp));
+samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
} else {
GGML_ASSERT(false && "unknown mirostat version");
}
+bool is_backend = params.backend_sampling;
+// split in two chains: backend -> CPU
+for (auto * smpl : samplers) {
+if (!smpl->iface->backend_apply) {
+is_backend = false;
+}
+llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, smpl);
+}
return result;
}
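With this rewrite, common_sampler_init first collects every configured sampler into the flat `samplers` vector and defers the backend/CPU decision to the split loop above, so callers no longer pre-filter disabled samplers and `params` can be taken by const reference again. A minimal caller-side sketch (hypothetical parameter values; common_sampler_sample, common_sampler_accept and common_sampler_free are the existing helpers declared in the sampling header):

    common_params_sampling sparams;
    sparams.backend_sampling = true; // opt in; samplers without backend support still fall back to the CPU chain
    sparams.top_k = 40;
    sparams.temp  = 0.8f;

    common_sampler * smpl = common_sampler_init(model, sparams); // sparams is no longer mutated

    // ... after llama_decode() on a batch ...
    const llama_token id = common_sampler_sample(smpl, ctx, /*idx=*/ -1);
    common_sampler_accept(smpl, id, /*accept_grammar=*/ true);

    common_sampler_free(smpl);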

View File

@@ -36,8 +36,7 @@ struct common_sampler;
// llama_sampler API overloads
-// TODO: params should become const again [LLAMA_SAMPLER_BACKEND]
-struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params);
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
void common_sampler_free(struct common_sampler * gsmpl);

View File

@@ -73,16 +73,10 @@ int main(int argc, char ** argv) {
for (int32_t i = 0; i < n_parallel; ++i) {
llama_sampler * smpl = llama_sampler_chain_init(sparams);
-if (params.sampling.backend_sampling) {
-llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
-llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
-llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
-} else {
-llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
-llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
-llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
-llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
-}
+llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
+llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
+llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
+llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
sampler_configs.push_back({ i, smpl });
}

View File

@@ -1212,7 +1212,7 @@ extern "C" {
};
struct llama_sampler {
-const struct llama_sampler_i * iface;
+struct llama_sampler_i * iface;
llama_sampler_context_t ctx;
};
@@ -1220,7 +1220,7 @@ extern "C" {
LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
// mirror of llama_sampler_i:
-LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
+LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_i * iface, llama_sampler_context_t ctx);
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
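Dropping const from `iface` is what allows the sampler constructors later in this commit to patch backend hooks onto their interface tables after construction. For reference, a minimal custom sampler against this API could look as follows — a sketch only, with hypothetical my_noop_* names; the initializer order mirrors the llama_sampler_i initializers visible elsewhere in this commit, and the trailing backend_* members are left value-initialized (nullptr):

    static const char * my_noop_name(const struct llama_sampler * /*smpl*/) {
        return "my-noop";
    }

    static void my_noop_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * /*cur_p*/) {
        // no-op: leaves the candidate array untouched
    }

    static struct llama_sampler_i my_noop_i = {
        /* .name   = */ my_noop_name,
        /* .accept = */ nullptr,
        /* .apply  = */ my_noop_apply,
        /* .reset  = */ nullptr,
        /* .clone  = */ nullptr,
        /* .free   = */ nullptr,
        // backend_* hooks default to nullptr: the chain splitter in
        // common_sampler_init routes this sampler (and everything after it)
        // to the CPU chain
    };

    struct llama_sampler * my_noop_init(void) {
        return llama_sampler_init(&my_noop_i, /*ctx=*/ nullptr); // iface must now be non-const
    }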

View File

@@ -2102,7 +2102,7 @@ void llm_graph_context::build_sampling() const {
ggml_build_forward_expand(gf, data.sampled);
}
-if (data.probs != nullptr) {
+if (data.probs != nullptr) {
res->t_sampled_probs[seq_id] = data.probs;
ggml_build_forward_expand(gf, data.probs);
}

View File

@@ -349,7 +349,7 @@ static uint32_t get_rng_seed(uint32_t seed) {
// llama_sampler API
struct llama_sampler * llama_sampler_init(
-const struct llama_sampler_i * iface,
+struct llama_sampler_i * iface,
llama_sampler_context_t ctx) {
return new llama_sampler {
/* .iface = */ iface,
@@ -468,6 +468,42 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
return token;
}
+// backend sampling (empty iface)
+static void llama_sampler_empty_backend_init(
+struct llama_sampler * smpl,
+ggml_backend_buffer_type_t buft) {
+GGML_UNUSED(smpl);
+GGML_UNUSED(buft);
+}
+static void llama_sampler_empty_backend_accept(
+struct llama_sampler * smpl,
+ggml_context * ctx,
+ggml_cgraph * gf,
+struct ggml_tensor * selected_token) {
+GGML_UNUSED(smpl);
+GGML_UNUSED(ctx);
+GGML_UNUSED(gf);
+GGML_UNUSED(selected_token);
+}
+static void llama_sampler_empty_backend_apply(
+struct llama_sampler * smpl,
+struct ggml_context * ctx,
+struct ggml_cgraph * gf,
+struct llama_sampler_data * data) {
+GGML_UNUSED(smpl);
+GGML_UNUSED(ctx);
+GGML_UNUSED(gf);
+GGML_UNUSED(data);
+}
+static void llama_sampler_empty_backend_set_input(struct llama_sampler * smpl) {
+GGML_UNUSED(smpl);
+}
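These empty hooks replace the removed common_params_sampling::filter_disabled() pre-pass: instead of erasing a no-op sampler from the chain up front, its constructor installs the empty backend interface, so the splitter in common_sampler_init sees a non-null backend_apply and keeps the sampler on the backend chain (where it does nothing) rather than demoting every later sampler to the CPU chain. For example, per the llama_sampler_init_typical hunk below:

    // sketch: typical-p with p >= 1.0f never modifies the candidates, so it
    // advertises the empty backend hooks and stays backend-compatible
    llama_sampler * smpl = llama_sampler_init_typical(/*p=*/ 1.0f, /*min_keep=*/ 1);
    // now smpl->iface->backend_apply == llama_sampler_empty_backend_apply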
// sampler chain
static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
@@ -1171,7 +1207,7 @@ static void llama_sampler_top_p_backend_apply(
ggml_set_output(data->candidates);
ggml_build_forward_expand(gf, data->candidates);
+ggml_set_output(data->logits);
ggml_build_forward_expand(gf, data->logits);
}
@@ -1446,13 +1482,24 @@ static struct llama_sampler_i llama_sampler_typical_i = {
};
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
-return llama_sampler_init(
+auto * res = llama_sampler_init(
/* .iface = */ &llama_sampler_typical_i,
/* .ctx = */ new llama_sampler_typical {
/* .p = */ p,
/* .min_keep = */ min_keep,
}
);
+const bool is_empty = (p >= 1.0f);
+if (is_empty) {
+res->iface->backend_init = llama_sampler_empty_backend_init;
+res->iface->backend_accept = llama_sampler_empty_backend_accept;
+res->iface->backend_apply = llama_sampler_empty_backend_apply;
+res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
+}
+return res;
}
// temp
@@ -1615,6 +1662,27 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
delete (llama_sampler_temp_ext *) smpl->ctx;
}
+static void llama_sampler_temp_ext_backend_apply(
+struct llama_sampler * smpl,
+struct ggml_context * ctx,
+struct ggml_cgraph * gf,
+struct llama_sampler_data * data) {
+auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
+if (ctx_data->temp <= 0.0f) {
+return;
+}
+struct ggml_tensor * scaled = ggml_scale(ctx, data->logits, 1.0f / ctx_data->temp);
+ggml_set_name(scaled, "temp_scaled");
+// Make sure the scaled tensor is contiguous for subsequent operations
+data->logits = ggml_cont(ctx, scaled);
+ggml_set_name(data->logits, "temp_scaled_logits");
+ggml_build_forward_expand(gf, data->logits);
+}
static struct llama_sampler_i llama_sampler_temp_ext_i = {
/* .name = */ llama_sampler_temp_ext_name,
/* .accept = */ nullptr,
@@ -1629,7 +1697,7 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
};
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
-return llama_sampler_init(
+auto * res = llama_sampler_init(
/* .iface = */ &llama_sampler_temp_ext_i,
/* .ctx = */ new llama_sampler_temp_ext {
/* .temp = */ temp,
@@ -1637,6 +1705,14 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
/* .exponent = */ exponent,
}
);
+const bool is_backend = delta <= 0.0f;
+if (is_backend) {
+res->iface->backend_apply = llama_sampler_temp_ext_backend_apply;
+}
+return res;
}
// xtc
@@ -1727,8 +1803,9 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
};
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
-auto seed_cur = get_rng_seed(seed);
-return llama_sampler_init(
+const auto seed_cur = get_rng_seed(seed);
+auto * res = llama_sampler_init(
/* .iface = */ &llama_sampler_xtc_i,
/* .ctx = */ new llama_sampler_xtc {
/* .probability = */ p,
@@ -1739,6 +1816,17 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
/* .rng = */ std::mt19937(seed_cur),
}
);
+const bool is_empty = (p <= 0.0f || t > 0.5f);
+if (is_empty) {
+res->iface->backend_init = llama_sampler_empty_backend_init;
+res->iface->backend_accept = llama_sampler_empty_backend_accept;
+res->iface->backend_apply = llama_sampler_empty_backend_apply;
+res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
+}
+return res;
}
// mirostat
@@ -2280,7 +2368,7 @@ struct llama_sampler * llama_sampler_init_penalties(
float penalty_present) {
penalty_last_n = std::max(penalty_last_n, 0);
-return llama_sampler_init(
+auto * res = llama_sampler_init(
/* .iface = */ &llama_sampler_penalties_i,
/* .ctx = */ new llama_sampler_penalties {
/* .penalty_last_n = */ penalty_last_n,
@@ -2291,6 +2379,17 @@ struct llama_sampler * llama_sampler_init_penalties(
/* .token_count = */ {},
}
);
+const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
+if (is_empty) {
+res->iface->backend_init = llama_sampler_empty_backend_init;
+res->iface->backend_accept = llama_sampler_empty_backend_accept;
+res->iface->backend_apply = llama_sampler_empty_backend_apply;
+res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
+}
+return res;
}
// top-n-sigma
@@ -2317,9 +2416,7 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
for (size_t i = 0; i < cur_p->size; ++i) {
// Only count non-negative infinity values
if (cur_p->data[i].logit != -INFINITY) {
-if (cur_p->data[i].logit > max) {
-max = cur_p->data[i].logit;
-}
+max = std::max(max, cur_p->data[i].logit);
logits_sum += cur_p->data[i].logit;
valid_count++;
}
@@ -2369,12 +2466,23 @@ static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
};
struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
-return llama_sampler_init(
+auto * res = llama_sampler_init(
/* .iface = */ &llama_sampler_top_n_sigma_i,
/* .ctx = */ new llama_sampler_top_n_sigma {
/* .n = */ n,
}
);
+const bool is_empty = (n <= 0.0f);
+if (is_empty) {
+res->iface->backend_init = llama_sampler_empty_backend_init;
+res->iface->backend_accept = llama_sampler_empty_backend_accept;
+res->iface->backend_apply = llama_sampler_empty_backend_apply;
+res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
+}
+return res;
}
// DRY
@@ -2733,7 +2841,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
}
}
-return llama_sampler_init(
+auto * res = llama_sampler_init(
/* .iface = */ &llama_sampler_dry_i,
/* .ctx = */ new llama_sampler_dry {
/* .total_context_size = */ n_ctx_train,
@@ -2747,6 +2855,15 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
}
);
+if (!dry_enabled) {
+res->iface->backend_init = llama_sampler_empty_backend_init;
+res->iface->backend_accept = llama_sampler_empty_backend_accept;
+res->iface->backend_apply = llama_sampler_empty_backend_apply;
+res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
+}
+return res;
}
// wrapper for test-sampling.cpp
@@ -2854,6 +2971,8 @@ static void llama_sampler_logit_bias_backend_apply(
// Add the sparse logit_bias to the logits
struct ggml_tensor * logit_biased = ggml_add_inplace(ctx, data->logits, sctx->inp_logit_bias);
data->logits = logit_biased;
+ggml_build_forward_expand(gf, logit_biased);
}