diff --git a/common/common.cpp b/common/common.cpp
index bf81370730..6606ab5c59 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1019,7 +1019,6 @@ struct common_init_result common_init_from_params(common_params & params) {
 
     auto cparams = common_context_params_to_llama(params);
 
-    // backend sampling initialization
     if (params.sampling.backend_sampling) {
         llama_sampler * backend_chain = common_sampler_backend_init(model, params.sampling);
         if (backend_chain != nullptr) {
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 9f7795aa41..9f1ce46680 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -165,10 +165,6 @@ struct common_sampler {
     mutable int64_t t_total_us = 0;
 };
 
-static bool sampler_enabled(const struct common_params_sampling & params, enum common_sampler_type type) {
-    return std::find(params.samplers.begin(), params.samplers.end(), type) != params.samplers.end();
-}
-
 static bool sampler_backend_supported(enum common_sampler_type type) {
     switch (type) {
         case COMMON_SAMPLER_TYPE_TOP_K:
@@ -180,10 +176,100 @@ static bool sampler_backend_supported(enum common_sampler_type type) {
     }
 }
 
+static bool is_sampler_enabled(enum common_sampler_type type, const struct common_params_sampling & params) {
+    switch (type) {
+        case COMMON_SAMPLER_TYPE_PENALTIES:
+            if (params.penalty_last_n == 64 &&
+                fabs(params.penalty_repeat) <= 1.0f &&
+                fabs(params.penalty_freq) <= 0.0f &&
+                fabs(params.penalty_present) <= 0.0f) {
+                return false;
+            }
+            break;
+        case COMMON_SAMPLER_TYPE_DRY:
+            if (params.dry_multiplier == 0.0f && params.dry_base == 1.75f) {
+                return false;
+            }
+            break;
+        case COMMON_SAMPLER_TYPE_TYPICAL_P:
+            if (params.typ_p == 1.0) {
+                return false;
+            }
+            break;
+        case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
+            if (params.top_n_sigma == -1.0) {
+                return false;
+            }
+            break;
+        case COMMON_SAMPLER_TYPE_TOP_K:
+            if (params.top_k <= 0) {
+                return false;
+            }
+            break;
+        case COMMON_SAMPLER_TYPE_TEMPERATURE:
+            if (params.temp <= 0.0f) {
+                return false;
+            }
+            break;
+        case COMMON_SAMPLER_TYPE_MIN_P:
+            if (params.min_p <= 0.0f) {
+                return false;
+            }
+            break;
+        case COMMON_SAMPLER_TYPE_TOP_P:
+            if (params.top_p >= 1.0f) {
+                return false;
+            }
+            break;
+        case COMMON_SAMPLER_TYPE_XTC:
+            if (params.xtc_probability == 0.0f && params.xtc_threshold == 0.10f) {
+                return false;
+            }
+            break;
+        default:
+            break;
+    }
+    return true;
+}
+
 static bool has_logit_bias(const struct common_params_sampling & params) {
     return !params.logit_bias.empty();
 }
 
+struct active_samplers {
+    std::vector<enum common_sampler_type> backend_samplers;
+    std::vector<enum common_sampler_type> cpu_samplers;
+};
+
+static struct active_samplers get_active_samplers(const struct common_params_sampling & params) {
+    struct active_samplers result;
+
+    if (params.mirostat != 0) {
+        // Mirostat is CPU-only and overrides other samplers
+        for (const auto & sampler_type : params.samplers) {
+            if (is_sampler_enabled(sampler_type, params)) {
+                result.cpu_samplers.push_back(sampler_type);
+            }
+        }
+        return result;
+    }
+
+    bool backend_supported = params.backend_sampling;
+
+    for (const auto & sampler_type : params.samplers) {
+        if (!is_sampler_enabled(sampler_type, params)) {
+            continue;
+        }
+
+        if (backend_supported && sampler_backend_supported(sampler_type)) {
+            result.backend_samplers.push_back(sampler_type);
+        } else {
+            result.cpu_samplers.push_back(sampler_type);
+        }
+    }
+    return result;
+}
+
 std::string common_params_sampling::print() const {
     char result[1024];
 
@@ -200,6 +286,14 @@ std::string common_params_sampling::print() const {
     return std::string(result);
 }
 
+struct backend_chain_data {
+    struct llama_sampler * chain;
+    size_t count;
+};
+
+static struct backend_chain_data backend_samplers_init(const struct llama_model * model, const struct common_params_sampling & params,
+                                                       struct active_samplers active_samplers);
+
 struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
@@ -277,69 +371,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
         /* .cur_p = */ {},
     };
 
-    size_t backend_sampler_count = 0;
-    if (params.backend_sampling && params.mirostat == 0) {
-        if (has_logit_bias(params)) {
-            backend_sampler_count++;
-        }
-
-        // Find the longest contiguous chain of backend-supported samplers from the start
-        for (const auto & sampler_type : params.samplers) {
-            if (sampler_backend_supported(sampler_type)) {
-                backend_sampler_count++;
-            } else {
-                break;
-            }
-        }
-    }
-
-    // If the samplers combination is supported then we can build the backend chain.
-    if (backend_sampler_count > 0 || (params.backend_sampling && has_logit_bias(params))) {
-        llama_sampler_chain_params backend_params = llama_sampler_chain_default_params();
-        backend_params.no_perf = params.no_perf;
-        result->backend_chain = llama_sampler_chain_init(backend_params);
-
-        if (has_logit_bias(params)) {
-            llama_sampler_chain_add(result->backend_chain,
-                    llama_sampler_backend_init_logit_bias(
-                        llama_vocab_n_tokens(vocab),
-                        params.logit_bias.size(),
-                        params.logit_bias.data()));
-        }
-
-        size_t backend_idx = 0;
-        for (const auto & sampler_type : params.samplers) {
-            if (backend_idx >= backend_sampler_count - has_logit_bias(params)) {
-                break;
-            }
-
-            switch (sampler_type) {
-                case COMMON_SAMPLER_TYPE_TOP_K:
-                    if (params.top_k > 0) {
-                        llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_top_k(params.top_k));
-                    }
-                    backend_idx++;
-                    break;
-                case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                    if (params.temp > 0.0f) {
-                        llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_temp(params.temp));
-                    }
-                    backend_idx++;
-                    break;
-                case COMMON_SAMPLER_TYPE_MIN_P:
-                    if (params.min_p > 0.0f) {
-                        llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_min_p(params.min_p));
-                    }
-                    backend_idx++;
-                    break;
-                default:
-                    GGML_ASSERT(false && "unsupported backend sampler");
-            }
-        }
-    }
-
-    size_t cpu_start_idx = backend_sampler_count - has_logit_bias(params);
-    bool cpu_has_samplers = cpu_start_idx < params.samplers.size();
+    struct active_samplers active_samplers = get_active_samplers(params);
+    backend_chain_data backend_data = backend_samplers_init(model, params, active_samplers);
+    result->backend_chain = backend_data.chain;
 
     // Build CPU chain
     if (!params.backend_sampling || !has_logit_bias(params)) {
@@ -352,8 +386,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
     if (params.mirostat == 0) {
         // Add remaining CPU samplers
-        for (size_t i = cpu_start_idx; i < params.samplers.size(); i++) {
-            const auto & cnstr = params.samplers[i];
+        for (const auto & cnstr : active_samplers.cpu_samplers) {
             switch (cnstr) {
                 case COMMON_SAMPLER_TYPE_DRY:
                     {
@@ -398,10 +431,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             }
         }
 
-        // If all samplers are on backend, add dist to backend; otherwise add to CPU
-        if (result->backend_chain && !cpu_has_samplers) {
-            llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_dist(params.seed));
-        } else {
+        if (!active_samplers.cpu_samplers.empty()) {
             llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
         }
     } else if (params.mirostat == 1) {
@@ -417,35 +447,17 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
     return result;
 }
 
-struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params) {
-    if (!params.backend_sampling || params.mirostat != 0) {
-        return nullptr;
+
+static struct backend_chain_data backend_samplers_init(const struct llama_model * model, const struct common_params_sampling & params,
+                                                       struct active_samplers active_samplers) {
+    if (active_samplers.backend_samplers.empty()) {
+        return { nullptr, 0 };
     }
 
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
-    // Determine the split point for backend sampling using the same logic as common_sampler_init
-    size_t backend_sampler_count = 0;
-    if (has_logit_bias(params)) {
-        backend_sampler_count++;
-    }
-
-    // Find the longest contiguous chain of backend-supported samplers from the start
-    for (const auto & sampler_type : params.samplers) {
-        if (sampler_backend_supported(sampler_type)) {
-            backend_sampler_count++;
-        } else {
-            break;
-        }
-    }
-
-    if (backend_sampler_count == 0 && !has_logit_bias(params)) {
-        return nullptr;
-    }
-
     llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
     chain_params.no_perf = params.no_perf;
-
     struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
 
     // Add logit_bias to backend chain if present
@@ -456,46 +468,32 @@ struct llama_sampler * common_sampler_backend_init(const struct llama_model * mo
                     params.logit_bias.data()));
     }
 
-    size_t backend_idx = 0;
-    for (const auto & sampler_type : params.samplers) {
-        if (backend_idx >= backend_sampler_count - has_logit_bias(params)) {
-            break;
-        }
-
+    for (const auto & sampler_type : active_samplers.backend_samplers) {
         switch (sampler_type) {
             case COMMON_SAMPLER_TYPE_TOP_K:
-                if (params.top_k > 0) {
-                    llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k));
-                }
-                backend_idx++;
+                llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k));
                 break;
             case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                if (params.temp > 0.0f) {
-                    llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp));
-                }
-                backend_idx++;
+                llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp));
                 break;
            case COMMON_SAMPLER_TYPE_MIN_P:
-                if (params.min_p > 0.0f) {
-                    llama_sampler_chain_add(chain, llama_sampler_backend_init_min_p(params.min_p));
-                }
-                backend_idx++;
+                llama_sampler_chain_add(chain, llama_sampler_backend_init_min_p(params.min_p));
                 break;
             default:
                 GGML_ASSERT(false && "unsupported backend sampler");
         }
     }
 
-    // Determine if we should add dist sampler to backend chain
-    // Only add it if all samplers from params.samplers are on the backend
-    size_t cpu_start_idx = backend_sampler_count - has_logit_bias(params);
-    bool cpu_has_samplers = cpu_start_idx < params.samplers.size();
-
-    if (!cpu_has_samplers) {
+    if (active_samplers.cpu_samplers.empty()) {
         llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(params.seed));
     }
 
-    return chain;
+    return { chain, active_samplers.backend_samplers.size() + has_logit_bias(params) };
+}
+
+struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params) {
+    struct active_samplers active_samplers = get_active_samplers(params);
+    return backend_samplers_init(model, params, active_samplers).chain;
 }
 
 void common_sampler_free(struct common_sampler * gsmpl) {
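
Illustrative sketch (not part of the patch): the core of this change is the partitioning rule in get_active_samplers, where every enabled sampler that the backend supports goes on the backend chain and everything else stays on the CPU chain, with mirostat forcing the whole set onto the CPU. Unlike the removed prefix-based split, backend-supported samplers are now picked up wherever they appear in params.samplers. The standalone program below mirrors that rule with a simplified enum and a stubbed supported-set; those names are stand-ins, not the real definitions from common/sampling.cpp.

// Standalone illustration of the sampler-partitioning rule introduced by this patch.
// The enum values and backend_supports() are simplified stand-ins for the real
// common_sampler_type enum and sampler_backend_supported() in common/sampling.cpp.
#include <cstdio>
#include <vector>

enum sampler_type { TOP_K, TOP_P, MIN_P, TEMPERATURE, XTC };

static bool backend_supports(sampler_type t) {
    // In the patch only top-k, temperature and min-p have backend implementations.
    return t == TOP_K || t == TEMPERATURE || t == MIN_P;
}

struct split {
    std::vector<sampler_type> backend;
    std::vector<sampler_type> cpu;
};

static split partition(const std::vector<sampler_type> & enabled, bool backend_sampling, bool mirostat) {
    split out;
    if (mirostat) {
        out.cpu = enabled; // mirostat is CPU-only and overrides the other samplers
        return out;
    }
    for (sampler_type t : enabled) {
        if (backend_sampling && backend_supports(t)) {
            out.backend.push_back(t); // runs on the backend chain
        } else {
            out.cpu.push_back(t);     // falls back to the CPU chain
        }
    }
    return out;
}

int main() {
    const std::vector<sampler_type> enabled = { TOP_K, TOP_P, MIN_P, TEMPERATURE, XTC };
    const split s = partition(enabled, /*backend_sampling=*/true, /*mirostat=*/false);
    // Prints "backend: 3, cpu: 2": top-k, min-p and temperature go to the backend,
    // while top-p and xtc stay on the CPU even though they sit between supported samplers.
    std::printf("backend: %zu, cpu: %zu\n", s.backend.size(), s.cpu.size());
    return 0;
}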