common : simplify sampler chain initialization

This commit is contained in:
Georgi Gerganov 2025-12-01 17:10:32 +02:00
parent 217469f07f
commit 4032ce2378
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
7 changed files with 171 additions and 178 deletions

View File

@ -212,18 +212,12 @@ struct common_params_sampling {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
bool backend_sampling = false; // enable backend sampling bool backend_sampling = false;
bool has_logit_bias() const { bool has_logit_bias() const {
return !logit_bias.empty(); return !logit_bias.empty();
} }
bool is_disabled(enum common_sampler_type type) const;
// remove disabled samplers
// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
void filter_disabled();
// print the parameters into a string // print the parameters into a string
std::string print() const; std::string print() const;
}; };
@ -661,7 +655,7 @@ std::vector<common_file_info> fs_list_files(const std::string & path);
struct common_sampler; struct common_sampler;
// note: defines object's lifetime // note: defines the model, context, samplers, etc. lifetimes
struct common_init_result { struct common_init_result {
common_init_result(common_params & params); common_init_result(common_params & params);
~common_init_result(); ~common_init_result();

View File

@ -163,84 +163,6 @@ struct common_sampler {
mutable int64_t t_total_us = 0; mutable int64_t t_total_us = 0;
}; };
// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
// Returns true for the sampler types that currently have a backend (GPU) implementation.
static bool common_sampler_type_has_backend_support(enum common_sampler_type type) {
    return type == COMMON_SAMPLER_TYPE_TOP_K       ||
           type == COMMON_SAMPLER_TYPE_TEMPERATURE ||
           type == COMMON_SAMPLER_TYPE_MIN_P       ||
           type == COMMON_SAMPLER_TYPE_TOP_P;
}
// Returns true when the given sampler type is a no-op under the current
// parameter values and can therefore be dropped from the chain.
bool common_params_sampling::is_disabled(enum common_sampler_type type) const {
    switch (type) {
        case COMMON_SAMPLER_TYPE_PENALTIES:
            // no history to penalize, or all penalty strengths at their neutral values
            if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
                return true;
            }
            break;
        case COMMON_SAMPLER_TYPE_DRY:
            if (dry_multiplier == 0.0f || dry_base < 1.0f || dry_penalty_last_n == 0) {
                return true;
            }
            break;
        case COMMON_SAMPLER_TYPE_TYPICAL_P:
            if (typ_p >= 1.0) {
                return true;
            }
            break;
        case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
            if (top_n_sigma <= 0.0) {
                return true;
            }
            break;
        case COMMON_SAMPLER_TYPE_TOP_K:
            if (top_k <= 0) {
                return true;
            }
            break;
        case COMMON_SAMPLER_TYPE_TEMPERATURE:
            // with no dynatemp range the extended sampler degenerates to plain temp
            if (dynatemp_range <= 0.0f) {
                return true;
            }
            break;
        case COMMON_SAMPLER_TYPE_MIN_P:
            if (min_p <= 0.0f) {
                return true;
            }
            break;
        case COMMON_SAMPLER_TYPE_TOP_P:
            if (top_p >= 1.0f) {
                return true;
            }
            break;
        case COMMON_SAMPLER_TYPE_XTC:
            // XTC is a no-op when probability is 0 or the threshold exceeds 0.5
            // (was `xtc_threshold == 0.50f`, which only disabled at exactly 0.5)
            if (xtc_probability <= 0.0f || xtc_threshold > 0.5f) {
                return true;
            }
            break;
        default:
            break;
    }

    return false;
}
void common_params_sampling::filter_disabled() {
for (auto it = samplers.begin(); it != samplers.end();) {
if (is_disabled(*it)) {
LOG_WRN("%s: removing disabled sampler %s\n", __func__, common_sampler_type_to_str(*it).c_str());
it = samplers.erase(it);
} else {
++it;
}
}
}
std::string common_params_sampling::print() const { std::string common_params_sampling::print() const {
char result[1024]; char result[1024];
@ -257,7 +179,7 @@ std::string common_params_sampling::print() const {
return std::string(result); return std::string(result);
} }
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) { struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
@ -324,11 +246,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
} }
} }
// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
if (params.backend_sampling) {
params.filter_disabled();
}
auto * result = new common_sampler { auto * result = new common_sampler {
/* .params = */ params, /* .params = */ params,
/* .grmr = */ grmr, /* .grmr = */ grmr,
@ -339,54 +256,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
/* .cur_p = */ {}, /* .cur_p = */ {},
}; };
size_t idx_smpl = 0; std::vector<llama_sampler *> samplers;
bool is_backend = true;
is_backend = is_backend && params.backend_sampling;
is_backend = is_backend && (params.samplers.size() == 0 || common_sampler_type_has_backend_support(params.samplers[idx_smpl]));
if (params.has_logit_bias()) { if (params.has_logit_bias()) {
llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
llama_sampler_init_logit_bias(
llama_vocab_n_tokens(vocab),
params.logit_bias.size(),
params.logit_bias.data()));
} }
if (params.mirostat == 0) { if (params.mirostat == 0) {
// backend samplers are added first for (const auto & cnstr : params.samplers) {
while (is_backend && idx_smpl < params.samplers.size()) {
const auto & cnstr = params.samplers[idx_smpl++];
if (!common_sampler_type_has_backend_support(cnstr)) {
is_backend = false;
--idx_smpl;
break;
}
switch (cnstr) {
case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain_backend, llama_sampler_init_top_k(params.top_k));
break;
case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain_backend, llama_sampler_init_temp(params.temp));
break;
case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain_backend, llama_sampler_init_min_p(params.min_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TOP_P:
llama_sampler_chain_add(result->chain_backend, llama_sampler_init_top_p(params.top_p, params.min_keep));
break;
default:
GGML_ASSERT(false && "unsupported backend sampler");
}
}
// Add remaining CPU samplers
while (idx_smpl < params.samplers.size()) {
const auto & cnstr = params.samplers[idx_smpl++];
switch (cnstr) { switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY: case COMMON_SAMPLER_TYPE_DRY:
{ {
@ -396,52 +272,63 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
c_breakers.push_back(str.c_str()); c_breakers.push_back(str.c_str());
} }
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
} }
break; break;
case COMMON_SAMPLER_TYPE_TOP_K: case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); samplers.push_back(llama_sampler_init_top_k (params.top_k));
break; break;
case COMMON_SAMPLER_TYPE_TOP_P: case COMMON_SAMPLER_TYPE_TOP_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma)); samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
break; break;
case COMMON_SAMPLER_TYPE_MIN_P: case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_XTC: case COMMON_SAMPLER_TYPE_XTC:
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break; break;
case COMMON_SAMPLER_TYPE_TYPICAL_P: case COMMON_SAMPLER_TYPE_TYPICAL_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_TEMPERATURE: case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break; break;
case COMMON_SAMPLER_TYPE_INFILL: case COMMON_SAMPLER_TYPE_INFILL:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab)); samplers.push_back(llama_sampler_init_infill (vocab));
break; break;
case COMMON_SAMPLER_TYPE_PENALTIES: case COMMON_SAMPLER_TYPE_PENALTIES:
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break; break;
default: default:
GGML_ASSERT(false && "unknown sampler type"); GGML_ASSERT(false && "unknown sampler type");
} }
} }
llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, llama_sampler_init_dist(params.seed)); samplers.push_back(llama_sampler_init_dist(params.seed));
} else if (params.mirostat == 1) { } else if (params.mirostat == 1) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); samplers.push_back(llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
} else if (params.mirostat == 2) { } else if (params.mirostat == 2) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); samplers.push_back(llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
} else { } else {
GGML_ASSERT(false && "unknown mirostat version"); GGML_ASSERT(false && "unknown mirostat version");
} }
bool is_backend = params.backend_sampling;
// split in two chains: backend -> CPU
for (auto * smpl : samplers) {
if (!smpl->iface->backend_apply) {
is_backend = false;
}
llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, smpl);
}
return result; return result;
} }

View File

@ -36,8 +36,7 @@ struct common_sampler;
// llama_sampler API overloads // llama_sampler API overloads
// TODO: params should become const again [LLAMA_SAMPLER_BACKEND] struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params);
void common_sampler_free(struct common_sampler * gsmpl); void common_sampler_free(struct common_sampler * gsmpl);

View File

@ -73,16 +73,10 @@ int main(int argc, char ** argv) {
for (int32_t i = 0; i < n_parallel; ++i) { for (int32_t i = 0; i < n_parallel; ++i) {
llama_sampler * smpl = llama_sampler_chain_init(sparams); llama_sampler * smpl = llama_sampler_chain_init(sparams);
if (params.sampling.backend_sampling) { llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k)); llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp)); llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed)); llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
} else {
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
}
sampler_configs.push_back({ i, smpl }); sampler_configs.push_back({ i, smpl });
} }

View File

@ -1212,7 +1212,7 @@ extern "C" {
}; };
struct llama_sampler { struct llama_sampler {
const struct llama_sampler_i * iface; struct llama_sampler_i * iface;
llama_sampler_context_t ctx; llama_sampler_context_t ctx;
}; };
@ -1220,7 +1220,7 @@ extern "C" {
LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl); LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
// mirror of llama_sampler_i: // mirror of llama_sampler_i:
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx); LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_i * iface, llama_sampler_context_t ctx);
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);

View File

@ -2102,7 +2102,7 @@ void llm_graph_context::build_sampling() const {
ggml_build_forward_expand(gf, data.sampled); ggml_build_forward_expand(gf, data.sampled);
} }
if (data.probs != nullptr) { if (data.probs != nullptr) {
res->t_sampled_probs[seq_id] = data.probs; res->t_sampled_probs[seq_id] = data.probs;
ggml_build_forward_expand(gf, data.probs); ggml_build_forward_expand(gf, data.probs);
} }

View File

@ -349,7 +349,7 @@ static uint32_t get_rng_seed(uint32_t seed) {
// llama_sampler API // llama_sampler API
struct llama_sampler * llama_sampler_init( struct llama_sampler * llama_sampler_init(
const struct llama_sampler_i * iface, struct llama_sampler_i * iface,
llama_sampler_context_t ctx) { llama_sampler_context_t ctx) {
return new llama_sampler { return new llama_sampler {
/* .iface = */ iface, /* .iface = */ iface,
@ -468,6 +468,42 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
return token; return token;
} }
// backend sampling (empty iface)
//
// No-op backend callbacks. They are installed on samplers whose parameters
// make them a pass-through, so such samplers can still participate in a
// backend chain without contributing any graph operations.

static void llama_sampler_empty_backend_init(
        struct llama_sampler * /*smpl*/,
        ggml_backend_buffer_type_t /*buft*/) {
}

static void llama_sampler_empty_backend_accept(
        struct llama_sampler * /*smpl*/,
        ggml_context * /*ctx*/,
        ggml_cgraph * /*gf*/,
        struct ggml_tensor * /*selected_token*/) {
}

static void llama_sampler_empty_backend_apply(
        struct llama_sampler * /*smpl*/,
        struct ggml_context * /*ctx*/,
        struct ggml_cgraph * /*gf*/,
        struct llama_sampler_data * /*data*/) {
}

static void llama_sampler_empty_backend_set_input(struct llama_sampler * /*smpl*/) {
}
// sampler chain // sampler chain
static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) { static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
@ -1171,7 +1207,7 @@ static void llama_sampler_top_p_backend_apply(
ggml_set_output(data->candidates); ggml_set_output(data->candidates);
ggml_build_forward_expand(gf, data->candidates); ggml_build_forward_expand(gf, data->candidates);
ggml_set_output(data->logits); ggml_set_output(data->logits);
ggml_build_forward_expand(gf, data->logits); ggml_build_forward_expand(gf, data->logits);
} }
@ -1446,13 +1482,24 @@ static struct llama_sampler_i llama_sampler_typical_i = {
}; };
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
return llama_sampler_init( auto * res = llama_sampler_init(
/* .iface = */ &llama_sampler_typical_i, /* .iface = */ &llama_sampler_typical_i,
/* .ctx = */ new llama_sampler_typical { /* .ctx = */ new llama_sampler_typical {
/* .p = */ p, /* .p = */ p,
/* .min_keep = */ min_keep, /* .min_keep = */ min_keep,
} }
); );
const bool is_empty = (p >= 1.0f);
if (is_empty) {
res->iface->backend_init = llama_sampler_empty_backend_init;
res->iface->backend_accept = llama_sampler_empty_backend_accept;
res->iface->backend_apply = llama_sampler_empty_backend_apply;
res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
}
return res;
} }
// temp // temp
@ -1615,6 +1662,27 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
delete (llama_sampler_temp_ext *) smpl->ctx; delete (llama_sampler_temp_ext *) smpl->ctx;
} }
static void llama_sampler_temp_ext_backend_apply(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_data * data) {
auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
if (ctx_data->temp <= 0.0f) {
return;
}
struct ggml_tensor * scaled = ggml_scale(ctx, data->logits, 1.0f / ctx_data->temp);
ggml_set_name(scaled, "temp_scaled");
// Make sure the scaled tensor is contiguous for subsequent operations
data->logits = ggml_cont(ctx, scaled);
ggml_set_name(data->logits, "temp_scaled_logits");
ggml_build_forward_expand(gf, data->logits);
}
static struct llama_sampler_i llama_sampler_temp_ext_i = { static struct llama_sampler_i llama_sampler_temp_ext_i = {
/* .name = */ llama_sampler_temp_ext_name, /* .name = */ llama_sampler_temp_ext_name,
/* .accept = */ nullptr, /* .accept = */ nullptr,
@ -1629,7 +1697,7 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
}; };
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
return llama_sampler_init( auto * res = llama_sampler_init(
/* .iface = */ &llama_sampler_temp_ext_i, /* .iface = */ &llama_sampler_temp_ext_i,
/* .ctx = */ new llama_sampler_temp_ext { /* .ctx = */ new llama_sampler_temp_ext {
/* .temp = */ temp, /* .temp = */ temp,
@ -1637,6 +1705,14 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
/* .exponent = */ exponent, /* .exponent = */ exponent,
} }
); );
const bool is_backend = delta <= 0.0f;
if (is_backend) {
res->iface->backend_apply = llama_sampler_temp_ext_backend_apply;
}
return res;
} }
// xtc // xtc
@ -1727,8 +1803,9 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
}; };
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
auto seed_cur = get_rng_seed(seed); const auto seed_cur = get_rng_seed(seed);
return llama_sampler_init(
auto * res = llama_sampler_init(
/* .iface = */ &llama_sampler_xtc_i, /* .iface = */ &llama_sampler_xtc_i,
/* .ctx = */ new llama_sampler_xtc { /* .ctx = */ new llama_sampler_xtc {
/* .probability = */ p, /* .probability = */ p,
@ -1739,6 +1816,17 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
/* .rng = */ std::mt19937(seed_cur), /* .rng = */ std::mt19937(seed_cur),
} }
); );
const bool is_empty = (p <= 0.0f || t > 0.5f);
if (is_empty) {
res->iface->backend_init = llama_sampler_empty_backend_init;
res->iface->backend_accept = llama_sampler_empty_backend_accept;
res->iface->backend_apply = llama_sampler_empty_backend_apply;
res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
}
return res;
} }
// mirostat // mirostat
@ -2280,7 +2368,7 @@ struct llama_sampler * llama_sampler_init_penalties(
float penalty_present) { float penalty_present) {
penalty_last_n = std::max(penalty_last_n, 0); penalty_last_n = std::max(penalty_last_n, 0);
return llama_sampler_init( auto * res = llama_sampler_init(
/* .iface = */ &llama_sampler_penalties_i, /* .iface = */ &llama_sampler_penalties_i,
/* .ctx = */ new llama_sampler_penalties { /* .ctx = */ new llama_sampler_penalties {
/* .penalty_last_n = */ penalty_last_n, /* .penalty_last_n = */ penalty_last_n,
@ -2291,6 +2379,17 @@ struct llama_sampler * llama_sampler_init_penalties(
/* .token_count = */ {}, /* .token_count = */ {},
} }
); );
const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
if (is_empty) {
res->iface->backend_init = llama_sampler_empty_backend_init;
res->iface->backend_accept = llama_sampler_empty_backend_accept;
res->iface->backend_apply = llama_sampler_empty_backend_apply;
res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
}
return res;
} }
// top-n-sigma // top-n-sigma
@ -2317,9 +2416,7 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
for (size_t i = 0; i < cur_p->size; ++i) { for (size_t i = 0; i < cur_p->size; ++i) {
// Only count non-negative infinity values // Only count non-negative infinity values
if (cur_p->data[i].logit != -INFINITY) { if (cur_p->data[i].logit != -INFINITY) {
if (cur_p->data[i].logit > max) { max = std::max(max, cur_p->data[i].logit);
max = cur_p->data[i].logit;
}
logits_sum += cur_p->data[i].logit; logits_sum += cur_p->data[i].logit;
valid_count++; valid_count++;
} }
@ -2369,12 +2466,23 @@ static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
}; };
struct llama_sampler * llama_sampler_init_top_n_sigma(float n) { struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
return llama_sampler_init( auto * res = llama_sampler_init(
/* .iface = */ &llama_sampler_top_n_sigma_i, /* .iface = */ &llama_sampler_top_n_sigma_i,
/* .ctx = */ new llama_sampler_top_n_sigma { /* .ctx = */ new llama_sampler_top_n_sigma {
/* .n = */ n, /* .n = */ n,
} }
); );
const bool is_empty = (n <= 0.0f);
if (is_empty) {
res->iface->backend_init = llama_sampler_empty_backend_init;
res->iface->backend_accept = llama_sampler_empty_backend_accept;
res->iface->backend_apply = llama_sampler_empty_backend_apply;
res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
}
return res;
} }
// DRY // DRY
@ -2733,7 +2841,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
} }
} }
return llama_sampler_init( auto * res = llama_sampler_init(
/* .iface = */ &llama_sampler_dry_i, /* .iface = */ &llama_sampler_dry_i,
/* .ctx = */ new llama_sampler_dry { /* .ctx = */ new llama_sampler_dry {
/* .total_context_size = */ n_ctx_train, /* .total_context_size = */ n_ctx_train,
@ -2747,6 +2855,15 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0), /* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
} }
); );
if (!dry_enabled) {
res->iface->backend_init = llama_sampler_empty_backend_init;
res->iface->backend_accept = llama_sampler_empty_backend_accept;
res->iface->backend_apply = llama_sampler_empty_backend_apply;
res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
}
return res;
} }
// wrapper for test-sampling.cpp // wrapper for test-sampling.cpp
@ -2854,6 +2971,8 @@ static void llama_sampler_logit_bias_backend_apply(
// Add the sparse logit logit_bias to the logits // Add the sparse logit logit_bias to the logits
struct ggml_tensor * logit_biased = ggml_add_inplace(ctx, data->logits, sctx->inp_logit_bias); struct ggml_tensor * logit_biased = ggml_add_inplace(ctx, data->logits, sctx->inp_logit_bias);
data->logits = logit_biased;
ggml_build_forward_expand(gf, logit_biased); ggml_build_forward_expand(gf, logit_biased);
} }