sampling : use host buffer type for inputs
This commit is contained in:
parent
92ff767918
commit
34b407b41c
|
|
@ -66,7 +66,7 @@ llama_context::llama_context(
|
||||||
for (size_t i = 0; i < params.n_samplers; ++i) {
|
for (size_t i = 0; i < params.n_samplers; ++i) {
|
||||||
const auto & config = params.samplers[i];
|
const auto & config = params.samplers[i];
|
||||||
|
|
||||||
if (llama_sampler_chain_get(config.sampler, -1) != nullptr) {
|
if (llama_sampler_chain_get(config.sampler, -1) == nullptr) {
|
||||||
throw std::runtime_error("the backend samplers must be of type llama_sampler_chain");
|
throw std::runtime_error("the backend samplers must be of type llama_sampler_chain");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -922,6 +922,11 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
|
||||||
|
|
||||||
if (sampler && can_offload) {
|
if (sampler && can_offload) {
|
||||||
ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output());
|
ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output());
|
||||||
|
auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output());
|
||||||
|
if (host_buft) {
|
||||||
|
buft = host_buft;
|
||||||
|
}
|
||||||
|
|
||||||
sampler->iface->backend_init(sampler, buft);
|
sampler->iface->backend_init(sampler, buft);
|
||||||
|
|
||||||
sampling.samplers[seq_id] = sampler;
|
sampling.samplers[seq_id] = sampler;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue