From 0f7805f32a95fff2aa35f7ef34c3d6cb8f5d456c Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 26 Nov 2025 13:12:36 +0100 Subject: [PATCH] common : add get_active_samplers function to check enabled samplers This commit adds a function to check if a sampler is actually enabled, meaning that it does not have values that disable its effect. This is then used by the backend samplers initialization to avoid considering samplers that are not enabled when determining the split point between them. The motivation is that this allows the default sampler chain for `--samplers` to be used, since a sampler that is not enabled will no longer cause the backend samplers to be skipped. For example, before this change, if the penalties sampler was included in the samplers list but had default values that disable it, it would cause the backend samplers to be skipped entirely. This commit also contains some refactoring to remove some code duplication. --- common/common.cpp | 1 - common/sampling.cpp | 240 ++++++++++++++++++++++---------------------- 2 files changed, 119 insertions(+), 122 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index bf81370730..6606ab5c59 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1019,7 +1019,6 @@ struct common_init_result common_init_from_params(common_params & params) { auto cparams = common_context_params_to_llama(params); - // backend sampling initialization if (params.sampling.backend_sampling) { llama_sampler * backend_chain = common_sampler_backend_init(model, params.sampling); if (backend_chain != nullptr) { diff --git a/common/sampling.cpp b/common/sampling.cpp index 9f7795aa41..9f1ce46680 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -165,10 +165,6 @@ struct common_sampler { mutable int64_t t_total_us = 0; }; -static bool sampler_enabled(const struct common_params_sampling & params, enum common_sampler_type type) { - return std::find(params.samplers.begin(), params.samplers.end(), type) != 
params.samplers.end(); -} - static bool sampler_backend_supported(enum common_sampler_type type) { switch (type) { case COMMON_SAMPLER_TYPE_TOP_K: @@ -180,10 +176,100 @@ static bool sampler_backend_supported(enum common_sampler_type type) { } } +static bool is_sampler_enabled(enum common_sampler_type type, const struct common_params_sampling & params) { + switch (type) { + case COMMON_SAMPLER_TYPE_PENALTIES: + if (params.penalty_last_n == 64 && + fabs(params.penalty_repeat) <= 1.0f && + fabs(params.penalty_freq) <= 0.0f && + fabs(params.penalty_present) <= 0.0f) { + return false; + } + break; + case COMMON_SAMPLER_TYPE_DRY: + if (params.dry_multiplier == 0.0f && params.dry_base == 1.75f) { + return false; + } + break; + case COMMON_SAMPLER_TYPE_TYPICAL_P: + if (params.typ_p == 1.0) { + return false; + } + break; + case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: + if (params.top_n_sigma == -1.0) { + return false; + } + break; + case COMMON_SAMPLER_TYPE_TOP_K: + if (params.top_k <= 0) { + return false; + } + break; + case COMMON_SAMPLER_TYPE_TEMPERATURE: + if (params.temp <= 0.0f) { + return false; + } + break; + case COMMON_SAMPLER_TYPE_MIN_P: + if (params.min_p <= 0.0f) { + return false; + } + break; + case COMMON_SAMPLER_TYPE_TOP_P: + if (params.top_p >= 1.0f) { + return false; + } + break; + case COMMON_SAMPLER_TYPE_XTC: + if (params.xtc_probability == 0.0f && params.xtc_threshold == 0.10f) { + return false; + } + break; + default: + break; + } + return true; +} + static bool has_logit_bias(const struct common_params_sampling & params) { return !params.logit_bias.empty(); } +struct active_samplers { + std::vector backend_samplers; + std::vector cpu_samplers; +}; + +static struct active_samplers get_active_samplers(const struct common_params_sampling & params) { + struct active_samplers result; + + if (params.mirostat != 0) { + // Mirostat is CPU-only and overrides other samplers + for (const auto & sampler_type : params.samplers) { + if (is_sampler_enabled(sampler_type, 
params)) { + result.cpu_samplers.push_back(sampler_type); + } + } + return result; + } + + bool backend_supported = params.backend_sampling; + + for (const auto & sampler_type : params.samplers) { + if (!is_sampler_enabled(sampler_type, params)) { + continue; + } + + if (backend_supported && sampler_backend_supported(sampler_type)) { + result.backend_samplers.push_back(sampler_type); + } else { + result.cpu_samplers.push_back(sampler_type); + } + } + return result; +} + std::string common_params_sampling::print() const { char result[1024]; @@ -200,6 +286,14 @@ std::string common_params_sampling::print() const { return std::string(result); } +struct backend_chain_data { + struct llama_sampler * chain; + size_t count; +}; + +static struct backend_chain_data backend_samplers_init(const struct llama_model * model, const struct common_params_sampling & params, + struct active_samplers get_active_samplers); + struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) { const llama_vocab * vocab = llama_model_get_vocab(model); @@ -277,69 +371,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co /* .cur_p = */ {}, }; - size_t backend_sampler_count = 0; - if (params.backend_sampling && params.mirostat == 0) { - if (has_logit_bias(params)) { - backend_sampler_count++; - } - - // Find the longest contiguous chain of backend-supported samplers from the start - for (const auto & sampler_type : params.samplers) { - if (sampler_backend_supported(sampler_type)) { - backend_sampler_count++; - } else { - break; - } - } - } - - // If the samplers combination is supported then we can build the backend chain. 
- if (backend_sampler_count > 0 || (params.backend_sampling && has_logit_bias(params))) { - llama_sampler_chain_params backend_params = llama_sampler_chain_default_params(); - backend_params.no_perf = params.no_perf; - result->backend_chain = llama_sampler_chain_init(backend_params); - - if (has_logit_bias(params)) { - llama_sampler_chain_add(result->backend_chain, - llama_sampler_backend_init_logit_bias( - llama_vocab_n_tokens(vocab), - params.logit_bias.size(), - params.logit_bias.data())); - } - - size_t backend_idx = 0; - for (const auto & sampler_type : params.samplers) { - if (backend_idx >= backend_sampler_count - has_logit_bias(params)) { - break; - } - - switch (sampler_type) { - case COMMON_SAMPLER_TYPE_TOP_K: - if (params.top_k > 0) { - llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_top_k(params.top_k)); - } - backend_idx++; - break; - case COMMON_SAMPLER_TYPE_TEMPERATURE: - if (params.temp > 0.0f) { - llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_temp(params.temp)); - } - backend_idx++; - break; - case COMMON_SAMPLER_TYPE_MIN_P: - if (params.min_p > 0.0f) { - llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_min_p(params.min_p)); - } - backend_idx++; - break; - default: - GGML_ASSERT(false && "unsupported backend sampler"); - } - } - } - - size_t cpu_start_idx = backend_sampler_count - has_logit_bias(params); - bool cpu_has_samplers = cpu_start_idx < params.samplers.size(); + struct active_samplers active_samplers = get_active_samplers(params); + backend_chain_data backend_data = backend_samplers_init(model, params, active_samplers); + result->backend_chain = backend_data.chain; // Build CPU chain if (!params.backend_sampling || !has_logit_bias(params)) { @@ -352,8 +386,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co if (params.mirostat == 0) { // Add remaining CPU samplers - for (size_t i = cpu_start_idx; i < params.samplers.size(); i++) 
{ - const auto & cnstr = params.samplers[i]; + for (const auto & cnstr : active_samplers.cpu_samplers) { switch (cnstr) { case COMMON_SAMPLER_TYPE_DRY: { @@ -398,10 +431,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co } } - // If all samplers are on backend, add dist to backend; otherwise add to CPU - if (result->backend_chain && !cpu_has_samplers) { - llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_dist(params.seed)); - } else { + if (!active_samplers.cpu_samplers.empty()) { llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); } } else if (params.mirostat == 1) { @@ -417,35 +447,17 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co return result; } -struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params) { - if (!params.backend_sampling || params.mirostat != 0) { - return nullptr; + +static struct backend_chain_data backend_samplers_init(const struct llama_model * model, const struct common_params_sampling & params, + struct active_samplers active_samplers) { + if (active_samplers.backend_samplers.empty()) { + return { nullptr, 0 }; } const llama_vocab * vocab = llama_model_get_vocab(model); - // Determine the split point for backend sampling using the same logic as common_sampler_init - size_t backend_sampler_count = 0; - if (has_logit_bias(params)) { - backend_sampler_count++; - } - - // Find the longest contiguous chain of backend-supported samplers from the start - for (const auto & sampler_type : params.samplers) { - if (sampler_backend_supported(sampler_type)) { - backend_sampler_count++; - } else { - break; - } - } - - if (backend_sampler_count == 0 && !has_logit_bias(params)) { - return nullptr; - } - llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); chain_params.no_perf = params.no_perf; - struct llama_sampler * chain = 
llama_sampler_chain_init(chain_params); // Add logit_bias to backend chain if present @@ -456,46 +468,32 @@ struct llama_sampler * common_sampler_backend_init(const struct llama_model * mo params.logit_bias.data())); } - size_t backend_idx = 0; - for (const auto & sampler_type : params.samplers) { - if (backend_idx >= backend_sampler_count - has_logit_bias(params)) { - break; - } - + for (const auto & sampler_type : active_samplers.backend_samplers) { switch (sampler_type) { case COMMON_SAMPLER_TYPE_TOP_K: - if (params.top_k > 0) { - llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k)); - } - backend_idx++; + llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k)); break; case COMMON_SAMPLER_TYPE_TEMPERATURE: - if (params.temp > 0.0f) { - llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp)); - } - backend_idx++; + llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp)); break; case COMMON_SAMPLER_TYPE_MIN_P: - if (params.min_p > 0.0f) { - llama_sampler_chain_add(chain, llama_sampler_backend_init_min_p(params.min_p)); - } - backend_idx++; + llama_sampler_chain_add(chain, llama_sampler_backend_init_min_p(params.min_p)); break; default: GGML_ASSERT(false && "unsupported backend sampler"); } } - // Determine if we should add dist sampler to backend chain - // Only add it if all samplers from params.samplers are on the backend - size_t cpu_start_idx = backend_sampler_count - has_logit_bias(params); - bool cpu_has_samplers = cpu_start_idx < params.samplers.size(); - - if (!cpu_has_samplers) { + if (active_samplers.cpu_samplers.empty()) { llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(params.seed)); } - return chain; + return { chain, active_samplers.backend_samplers.size() + has_logit_bias(params) }; +} + +struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params) { + struct active_samplers 
active_samplers = get_active_samplers(params); + return backend_samplers_init(model, params, active_samplers).chain; } void common_sampler_free(struct common_sampler * gsmpl) {