diff --git a/common/common.cpp b/common/common.cpp index 6a6f5fec3d..8f2dfd8215 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -976,8 +976,6 @@ struct common_init_result common_init_context_from_model( const llama_vocab * vocab = llama_model_get_vocab(model); auto cparams = common_context_params_to_llama(params); - cparams.samplers = params.backend_samplers; - cparams.n_samplers = params.n_backend_samplers; llama_context * lctx = llama_init_from_model(model, cparams); if (lctx == NULL) { @@ -1247,6 +1245,9 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.type_k = params.cache_type_k; cparams.type_v = params.cache_type_v; + cparams.samplers = params.backend_samplers; + cparams.n_samplers = params.n_backend_samplers; + return cparams; } diff --git a/tools/main/main.cpp b/tools/main/main.cpp index 64855b646f..cae778e551 100644 --- a/tools/main/main.cpp +++ b/tools/main/main.cpp @@ -146,8 +146,9 @@ int main(int argc, char ** argv) { // Configure backend sampler if configured llama_sampler * backend_sampler = common_sampler_backend_init(model, sparams); + llama_sampler_seq_config sampler_config = { 0, backend_sampler }; + if (backend_sampler) { - llama_sampler_seq_config sampler_config = { 0, backend_sampler }; params.backend_samplers = &sampler_config; params.n_backend_samplers = 1; }