sampling : use host buffer type for inputs

This commit is contained in:
Georgi Gerganov 2025-12-09 17:53:17 +02:00
parent 92ff767918
commit 34b407b41c
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
1 changed files with 6 additions and 1 deletions

View File

@ -66,7 +66,7 @@ llama_context::llama_context(
for (size_t i = 0; i < params.n_samplers; ++i) {
const auto & config = params.samplers[i];
if (llama_sampler_chain_get(config.sampler, -1) != nullptr) {
if (llama_sampler_chain_get(config.sampler, -1) == nullptr) {
throw std::runtime_error("the backend samplers must be of type llama_sampler_chain");
}
@ -922,6 +922,11 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
if (sampler && can_offload) {
ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output());
auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output());
if (host_buft) {
buft = host_buft;
}
sampler->iface->backend_init(sampler, buft);
sampling.samplers[seq_id] = sampler;