sampling : do not create empty samplers

This commit is contained in:
Georgi Gerganov 2025-12-01 17:52:07 +02:00
parent 4032ce2378
commit 04f2822a86
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
1 changed files with 134 additions and 61 deletions

View File

@ -468,7 +468,41 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
return token;
}
// backend sampling (empty iface)
// empty sampler
struct llama_sampler_empty {
const char * name;
};
static struct llama_sampler * llama_sampler_init_empty(const char * name);
static const char * llama_sampler_empty_name(const struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_empty *) smpl->ctx;
return ctx->name;
}
static void llama_sampler_empty_accept(struct llama_sampler * smpl, llama_token token) {
GGML_UNUSED(smpl);
GGML_UNUSED(token);
}
static void llama_sampler_empty_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
GGML_UNUSED(smpl);
GGML_UNUSED(cur_p);
}
static void llama_sampler_empty_reset(struct llama_sampler * smpl) {
GGML_UNUSED(smpl);
}
static struct llama_sampler * llama_sampler_empty_clone(const struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_empty *) smpl->ctx;
return llama_sampler_init_empty(ctx->name);
}
static void llama_sampler_empty_free(struct llama_sampler * smpl) {
delete (llama_sampler_empty *) smpl->ctx;
}
static void llama_sampler_empty_backend_init(
struct llama_sampler * smpl,
@ -503,6 +537,27 @@ static void llama_sampler_empty_backend_set_input(struct llama_sampler * smpl) {
GGML_UNUSED(smpl);
}
static struct llama_sampler_i llama_sampler_empty_i = {
/* .name = */ llama_sampler_empty_name,
/* .accept = */ llama_sampler_empty_accept,
/* .apply = */ llama_sampler_empty_apply,
/* .reset = */ llama_sampler_empty_reset,
/* .clone = */ llama_sampler_empty_clone,
/* .free = */ llama_sampler_empty_free,
/* .backend_init = */ llama_sampler_empty_backend_init,
/* .backend_accept = */ llama_sampler_empty_backend_accept,
/* .backend_apply = */ llama_sampler_empty_backend_apply,
/* .backend_set_input = */ llama_sampler_empty_backend_set_input,
};
struct llama_sampler * llama_sampler_init_empty(const char * name) {
return llama_sampler_init(
/* .iface = */ &llama_sampler_empty_i,
/* .ctx = */ new llama_sampler_empty {
/* .name = */ name,
}
);
}
// sampler chain
@ -1040,6 +1095,12 @@ static struct llama_sampler_i llama_sampler_top_k_i = {
};
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
const bool is_empty = (k <= 0);
if (is_empty) {
return llama_sampler_init_empty("top-k?");
}
return llama_sampler_init(
/* .iface = */ &llama_sampler_top_k_i,
/* .ctx = */ new llama_sampler_top_k {
@ -1226,6 +1287,12 @@ static struct llama_sampler_i llama_sampler_top_p_i = {
};
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
const bool is_empty = (p <= 0.0f);
if (is_empty) {
return llama_sampler_init_empty("top-p?");
}
return llama_sampler_init(
/* .iface = */ &llama_sampler_top_p_i,
/* .ctx = */ new llama_sampler_top_p {
@ -1378,6 +1445,12 @@ static struct llama_sampler_i llama_sampler_min_p_i = {
};
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
const bool is_empty = (p <= 0.0f);
if (is_empty) {
return llama_sampler_init_empty("min-p?");
}
return llama_sampler_init(
/* .iface = */ &llama_sampler_min_p_i,
/* .ctx = */ new llama_sampler_min_p {
@ -1482,24 +1555,19 @@ static struct llama_sampler_i llama_sampler_typical_i = {
};
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
auto * res = llama_sampler_init(
const bool is_empty = (p >= 1.0f);
if (is_empty) {
return llama_sampler_init_empty("typical?");
}
return llama_sampler_init(
/* .iface = */ &llama_sampler_typical_i,
/* .ctx = */ new llama_sampler_typical {
/* .p = */ p,
/* .min_keep = */ min_keep,
}
);
const bool is_empty = (p >= 1.0f);
if (is_empty) {
res->iface->backend_init = llama_sampler_empty_backend_init;
res->iface->backend_accept = llama_sampler_empty_backend_accept;
res->iface->backend_apply = llama_sampler_empty_backend_apply;
res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
}
return res;
}
// temp
@ -1535,6 +1603,7 @@ static void llama_sampler_temp_backend_apply(
auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
if (ctx_data->temp <= 0.0f) {
// TODO: this is incorrect - find the most probable token instead
return;
}
@ -1562,6 +1631,12 @@ static struct llama_sampler_i llama_sampler_temp_i = {
};
struct llama_sampler * llama_sampler_init_temp(float temp) {
const bool is_empty = temp == 1.0f;
if (is_empty) {
return llama_sampler_init_empty("temp?");
}
return llama_sampler_init(
/* .iface = */ &llama_sampler_temp_i,
/* .ctx = */ new llama_sampler_temp {
@ -1662,14 +1737,19 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
delete (llama_sampler_temp_ext *) smpl->ctx;
}
// TODO: deduplicate with llama_sampler_temp_backend_apply
static void llama_sampler_temp_ext_backend_apply(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_data * data) {
auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
auto * ctx_data = (llama_sampler_temp_ext *) smpl->ctx;
// TODO: implement
GGML_ASSERT(ctx_data->delta <= 0.0f && "not implemented");
if (ctx_data->temp <= 0.0f) {
// TODO: this is incorrect - find the most probable token instead
return;
}
@ -1697,6 +1777,12 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
};
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
const bool is_empty = temp == 1.0f && delta <= 0.0f;
if (is_empty) {
return llama_sampler_init_empty("temp-ext?");
}
auto * res = llama_sampler_init(
/* .iface = */ &llama_sampler_temp_ext_i,
/* .ctx = */ new llama_sampler_temp_ext {
@ -1803,9 +1889,15 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
};
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
const bool is_empty = (p <= 0.0f || t > 0.5f);
if (is_empty) {
return llama_sampler_init_empty("xtc?");
}
const auto seed_cur = get_rng_seed(seed);
auto * res = llama_sampler_init(
return llama_sampler_init(
/* .iface = */ &llama_sampler_xtc_i,
/* .ctx = */ new llama_sampler_xtc {
/* .probability = */ p,
@ -1816,17 +1908,6 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
/* .rng = */ std::mt19937(seed_cur),
}
);
const bool is_empty = (p <= 0.0f || t > 0.5f);
if (is_empty) {
res->iface->backend_init = llama_sampler_empty_backend_init;
res->iface->backend_accept = llama_sampler_empty_backend_accept;
res->iface->backend_apply = llama_sampler_empty_backend_apply;
res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
}
return res;
}
// mirostat
@ -1927,7 +2008,8 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
};
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
auto seed_cur = get_rng_seed(seed);
const auto seed_cur = get_rng_seed(seed);
return llama_sampler_init(
/* .iface = */ &llama_sampler_mirostat_i,
/* .ctx = */ new llama_sampler_mirostat {
@ -2368,7 +2450,13 @@ struct llama_sampler * llama_sampler_init_penalties(
float penalty_present) {
penalty_last_n = std::max(penalty_last_n, 0);
auto * res = llama_sampler_init(
const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
if (is_empty) {
return llama_sampler_init_empty("penalties?");
}
return llama_sampler_init(
/* .iface = */ &llama_sampler_penalties_i,
/* .ctx = */ new llama_sampler_penalties {
/* .penalty_last_n = */ penalty_last_n,
@ -2379,17 +2467,6 @@ struct llama_sampler * llama_sampler_init_penalties(
/* .token_count = */ {},
}
);
const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
if (is_empty) {
res->iface->backend_init = llama_sampler_empty_backend_init;
res->iface->backend_accept = llama_sampler_empty_backend_accept;
res->iface->backend_apply = llama_sampler_empty_backend_apply;
res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
}
return res;
}
// top-n-sigma
@ -2466,23 +2543,18 @@ static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
};
struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
auto * res = llama_sampler_init(
const bool is_empty = (n <= 0.0f);
if (is_empty) {
return llama_sampler_init_empty("top-n-sigma?");
}
return llama_sampler_init(
/* .iface = */ &llama_sampler_top_n_sigma_i,
/* .ctx = */ new llama_sampler_top_n_sigma {
/* .n = */ n,
}
);
const bool is_empty = (n <= 0.0f);
if (is_empty) {
res->iface->backend_init = llama_sampler_empty_backend_init;
res->iface->backend_accept = llama_sampler_empty_backend_accept;
res->iface->backend_apply = llama_sampler_empty_backend_apply;
res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
}
return res;
}
// DRY
@ -2818,6 +2890,10 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
if (!dry_enabled) {
return llama_sampler_init_empty("dry?");
}
if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
// Process sequence breakers
for (size_t i = 0; i < num_breakers; ++i) {
@ -2841,7 +2917,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
}
}
auto * res = llama_sampler_init(
return llama_sampler_init(
/* .iface = */ &llama_sampler_dry_i,
/* .ctx = */ new llama_sampler_dry {
/* .total_context_size = */ n_ctx_train,
@ -2855,15 +2931,6 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
}
);
if (!dry_enabled) {
res->iface->backend_init = llama_sampler_empty_backend_init;
res->iface->backend_accept = llama_sampler_empty_backend_accept;
res->iface->backend_apply = llama_sampler_empty_backend_apply;
res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
}
return res;
}
// wrapper for test-sampling.cpp
@ -3035,6 +3102,12 @@ struct llama_sampler * llama_sampler_init_logit_bias(
int32_t n_vocab,
int32_t n_logit_bias,
const llama_logit_bias * logit_bias) {
const bool is_empty = n_logit_bias <= 0;
if (is_empty) {
return llama_sampler_init_empty("logit-bias?");
}
return llama_sampler_init(
/* .iface = */ &llama_sampler_logit_bias_i,
/* .ctx = */ new llama_sampler_logit_bias {