diff --git a/common/sampling.cpp b/common/sampling.cpp index bccec35dac..ebe61f32ca 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -309,15 +309,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co } struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params) { + if (!params.backend_sampling) { + return nullptr; + } const llama_vocab * vocab = llama_model_get_vocab(model); llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); chain_params.no_perf = params.no_perf; struct llama_sampler * chain = llama_sampler_chain_init(chain_params); - if (!params.backend_sampling) { - return chain; // return empty chain - } const bool enable_temp = params.temp > 0.0f && sampler_enabled(params, COMMON_SAMPLER_TYPE_TEMPERATURE); const bool enable_top_k = params.top_k > 0 && sampler_enabled(params, COMMON_SAMPLER_TYPE_TOP_K);