diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 93fd51df04..73432e5d04 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -66,7 +66,7 @@ llama_context::llama_context( for (size_t i = 0; i < params.n_samplers; ++i) { const auto & config = params.samplers[i]; - if (llama_sampler_chain_get(config.sampler, -1) != nullptr) { + if (llama_sampler_chain_get(config.sampler, -1) == nullptr) { throw std::runtime_error("the backend samplers must be of type llama_sampler_chain"); } @@ -922,6 +922,11 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { if (sampler && can_offload) { ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output()); + auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output()); + if (host_buft) { + buft = host_buft; + } + sampler->iface->backend_init(sampler, buft); sampling.samplers[seq_id] = sampler;