sampling : use host buffer type for inputs
This commit is contained in:
parent
92ff767918
commit
34b407b41c
|
|
@ -66,7 +66,7 @@ llama_context::llama_context(
|
|||
for (size_t i = 0; i < params.n_samplers; ++i) {
|
||||
const auto & config = params.samplers[i];
|
||||
|
||||
if (llama_sampler_chain_get(config.sampler, -1) != nullptr) {
|
||||
if (llama_sampler_chain_get(config.sampler, -1) == nullptr) {
|
||||
throw std::runtime_error("the backend samplers must be of type llama_sampler_chain");
|
||||
}
|
||||
|
||||
|
|
@ -922,6 +922,11 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
|
|||
|
||||
if (sampler && can_offload) {
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output());
|
||||
auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output());
|
||||
if (host_buft) {
|
||||
buft = host_buft;
|
||||
}
|
||||
|
||||
sampler->iface->backend_init(sampler, buft);
|
||||
|
||||
sampling.samplers[seq_id] = sampler;
|
||||
|
|
|
|||
Loading…
Reference in New Issue