From 04f2822a868333797c0d12c79c1a9ff776466aa8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 1 Dec 2025 17:52:07 +0200 Subject: [PATCH] sampling : do not create empty samplers --- src/llama-sampling.cpp | 195 ++++++++++++++++++++++++++++------------- 1 file changed, 134 insertions(+), 61 deletions(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index ca6c3670b1..cf5b0e010f 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -468,7 +468,41 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte return token; } -// backend sampling (empty iface) +// empty sampler + +struct llama_sampler_empty { + const char * name; +}; + +static struct llama_sampler * llama_sampler_init_empty(const char * name); + +static const char * llama_sampler_empty_name(const struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_empty *) smpl->ctx; + return ctx->name; +} + +static void llama_sampler_empty_accept(struct llama_sampler * smpl, llama_token token) { + GGML_UNUSED(smpl); + GGML_UNUSED(token); +} + +static void llama_sampler_empty_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + GGML_UNUSED(smpl); + GGML_UNUSED(cur_p); +} + +static void llama_sampler_empty_reset(struct llama_sampler * smpl) { + GGML_UNUSED(smpl); +} + +static struct llama_sampler * llama_sampler_empty_clone(const struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_empty *) smpl->ctx; + return llama_sampler_init_empty(ctx->name); +} + +static void llama_sampler_empty_free(struct llama_sampler * smpl) { + delete (llama_sampler_empty *) smpl->ctx; +} static void llama_sampler_empty_backend_init( struct llama_sampler * smpl, @@ -503,6 +537,27 @@ static void llama_sampler_empty_backend_set_input(struct llama_sampler * smpl) { GGML_UNUSED(smpl); } +static struct llama_sampler_i llama_sampler_empty_i = { + /* .name = */ llama_sampler_empty_name, + /* .accept = */ llama_sampler_empty_accept, + /* .apply = */ llama_sampler_empty_apply, + /* .reset = */ llama_sampler_empty_reset, + /* .clone = */ llama_sampler_empty_clone, + /* .free = */ llama_sampler_empty_free, + /* .backend_init = */ llama_sampler_empty_backend_init, + /* .backend_accept = */ llama_sampler_empty_backend_accept, + /* .backend_apply = */ llama_sampler_empty_backend_apply, + /* .backend_set_input = */ llama_sampler_empty_backend_set_input, +}; + +struct llama_sampler * llama_sampler_init_empty(const char * name) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_empty_i, + /* .ctx = */ new llama_sampler_empty { + /* .name = */ name, + } + ); +} // sampler chain @@ -1040,6 +1095,12 @@ static struct llama_sampler_i llama_sampler_top_k_i = { }; struct llama_sampler * llama_sampler_init_top_k(int32_t k) { + const bool is_empty = (k <= 0); + + if (is_empty) { + return llama_sampler_init_empty("top-k?"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_top_k_i, /* .ctx = */ new llama_sampler_top_k { @@ -1226,6 +1287,12 @@ static struct llama_sampler_i llama_sampler_top_p_i = { }; struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { + const bool is_empty = (p <= 0.0f); + + if (is_empty) { + return llama_sampler_init_empty("top-p?"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_top_p_i, /* .ctx = */ new llama_sampler_top_p { @@ -1378,6 +1445,12 @@ static struct llama_sampler_i llama_sampler_min_p_i = { }; struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { + const bool is_empty = (p <= 0.0f); + + if (is_empty) { + return llama_sampler_init_empty("min-p?"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_min_p_i, /* .ctx = */ new llama_sampler_min_p { @@ -1482,24 +1555,19 @@ static struct llama_sampler_i llama_sampler_typical_i = { }; struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { - auto * res = llama_sampler_init( + const bool is_empty = (p >= 1.0f); + + if (is_empty) { + return llama_sampler_init_empty("typical?"); + } + + return llama_sampler_init( /* .iface = */ &llama_sampler_typical_i, /* .ctx = */ new llama_sampler_typical { /* .p = */ p, /* .min_keep = */ min_keep, } ); - - const bool is_empty = (p >= 1.0f); - - if (is_empty) { - res->iface->backend_init = llama_sampler_empty_backend_init; - res->iface->backend_accept = llama_sampler_empty_backend_accept; - res->iface->backend_apply = llama_sampler_empty_backend_apply; - res->iface->backend_set_input = llama_sampler_empty_backend_set_input; - } - - return res; } // temp @@ -1535,6 +1603,7 @@ static void llama_sampler_temp_backend_apply( auto * ctx_data = (llama_sampler_temp *) smpl->ctx; if (ctx_data->temp <= 0.0f) { + // TODO: this is incorrect - find the most probable token instead return; } @@ -1562,6 +1631,12 @@ static struct llama_sampler_i llama_sampler_temp_i = { }; struct llama_sampler * llama_sampler_init_temp(float temp) { + const bool is_empty = temp == 1.0f; + + if (is_empty) { + return llama_sampler_init_empty("temp?"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_temp_i, /* .ctx = */ new llama_sampler_temp { @@ -1662,14 +1737,19 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) { delete (llama_sampler_temp_ext *) smpl->ctx; } +// TODO: deduplicate with llama_sampler_temp_backend_apply static void llama_sampler_temp_ext_backend_apply( struct llama_sampler * smpl, struct ggml_context * ctx, struct ggml_cgraph * gf, struct llama_sampler_data * data) { - auto * ctx_data = (llama_sampler_temp *) smpl->ctx; + auto * ctx_data = (llama_sampler_temp_ext *) smpl->ctx; + + // TODO: implement + GGML_ASSERT(ctx_data->delta <= 0.0f && "not implemented"); if (ctx_data->temp <= 0.0f) { + // TODO: this is incorrect - find the most probable token instead return; } @@ -1697,6 +1777,12 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = { }; struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { + const bool is_empty = temp == 1.0f && delta <= 0.0f; + + if (is_empty) { + return llama_sampler_init_empty("temp-ext?"); + } + auto * res = llama_sampler_init( /* .iface = */ &llama_sampler_temp_ext_i, /* .ctx = */ new llama_sampler_temp_ext { @@ -1803,9 +1889,15 @@ static struct llama_sampler_i llama_sampler_xtc_i = { }; struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { + const bool is_empty = (p <= 0.0f || t > 0.5f); + + if (is_empty) { + return llama_sampler_init_empty("xtc?"); + } + const auto seed_cur = get_rng_seed(seed); - auto * res = llama_sampler_init( + return llama_sampler_init( /* .iface = */ &llama_sampler_xtc_i, /* .ctx = */ new llama_sampler_xtc { /* .probability = */ p, @@ -1816,17 +1908,6 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, /* .rng = */ std::mt19937(seed_cur), } ); - - const bool is_empty = (p <= 0.0f || t > 0.5f); - - if (is_empty) { - res->iface->backend_init = llama_sampler_empty_backend_init; - res->iface->backend_accept = llama_sampler_empty_backend_accept; - res->iface->backend_apply = llama_sampler_empty_backend_apply; - res->iface->backend_set_input = llama_sampler_empty_backend_set_input; - } - - return res; } // mirostat @@ -1927,7 +2008,8 @@ static struct llama_sampler_i llama_sampler_mirostat_i = { }; struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { - auto seed_cur = get_rng_seed(seed); + const auto seed_cur = get_rng_seed(seed); + return llama_sampler_init( /* .iface = */ &llama_sampler_mirostat_i, /* .ctx = */ new llama_sampler_mirostat { @@ -2368,7 +2450,13 @@ struct llama_sampler * llama_sampler_init_penalties( float penalty_present) { penalty_last_n = std::max(penalty_last_n, 0); - auto * res = llama_sampler_init( + const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)); + + if (is_empty) { + return llama_sampler_init_empty("penalties?"); + } + + return llama_sampler_init( /* .iface = */ &llama_sampler_penalties_i, /* .ctx = */ new llama_sampler_penalties { /* .penalty_last_n = */ penalty_last_n, @@ -2379,17 +2467,6 @@ struct llama_sampler * llama_sampler_init_penalties( /* .token_count = */ {}, } ); - - const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)); - - if (is_empty) { - res->iface->backend_init = llama_sampler_empty_backend_init; - res->iface->backend_accept = llama_sampler_empty_backend_accept; - res->iface->backend_apply = llama_sampler_empty_backend_apply; - res->iface->backend_set_input = llama_sampler_empty_backend_set_input; - } - - return res; } // top-n-sigma @@ -2466,23 +2543,18 @@ static struct llama_sampler_i llama_sampler_top_n_sigma_i = { }; struct llama_sampler * llama_sampler_init_top_n_sigma(float n) { - auto * res = llama_sampler_init( + const bool is_empty = (n <= 0.0f); + + if (is_empty) { + return llama_sampler_init_empty("top-n-sigma?"); + } + + return llama_sampler_init( /* .iface = */ &llama_sampler_top_n_sigma_i, /* .ctx = */ new llama_sampler_top_n_sigma { /* .n = */ n, } ); - - const bool is_empty = (n <= 0.0f); - - if (is_empty) { - res->iface->backend_init = llama_sampler_empty_backend_init; - res->iface->backend_accept = llama_sampler_empty_backend_accept; - res->iface->backend_apply = llama_sampler_empty_backend_apply; - res->iface->backend_set_input = llama_sampler_empty_backend_set_input; - } - - return res; } // DRY @@ -2818,6 +2890,10 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0); + if (!dry_enabled) { + return llama_sampler_init_empty("dry?"); + } + if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) { // Process sequence breakers for (size_t i = 0; i < num_breakers; ++i) { @@ -2841,7 +2917,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, } } - auto * res = llama_sampler_init( + return llama_sampler_init( /* .iface = */ &llama_sampler_dry_i, /* .ctx = */ new llama_sampler_dry { /* .total_context_size = */ n_ctx_train, @@ -2855,15 +2931,6 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, /* .last_tokens = */ dry_enabled ? ring_buffer(effective_dry_penalty_last_n) : ring_buffer(0), } ); - - if (!dry_enabled) { - res->iface->backend_init = llama_sampler_empty_backend_init; - res->iface->backend_accept = llama_sampler_empty_backend_accept; - res->iface->backend_apply = llama_sampler_empty_backend_apply; - res->iface->backend_set_input = llama_sampler_empty_backend_set_input; - } - - return res; } // wrapper for test-sampling.cpp @@ -3035,6 +3102,12 @@ struct llama_sampler * llama_sampler_init_logit_bias( int32_t n_vocab, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { + const bool is_empty = n_logit_bias <= 0; + + if (is_empty) { + return llama_sampler_init_empty("logit-bias?"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_logit_bias_i, /* .ctx = */ new llama_sampler_logit_bias {