diff --git a/examples/batched/README.md b/examples/batched/README.md
index 6013aab01f..5f2c59e1d6 100644
--- a/examples/batched/README.md
+++ b/examples/batched/README.md
@@ -42,3 +42,15 @@ llama_print_timings: prompt eval time = 4089.11 ms / 118 tokens ( 34.65 ms
 llama_print_timings:        eval time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second)
 llama_print_timings:       total time = 4156.04 ms
 ```
+
+### Using backend samplers
+It is possible to run this example using backend samplers so that sampling is
+performed on the backend device, like a GPU.
+```bash
+./llama-batched \
+    -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf -p "Hello my name is" \
+    -np 4 -kvu \
+    --backend_sampling --top-k 80 --backend_dist
+```
+The `--verbose` flag can be added to see more detailed output and also show
+that the backend samplers are being used.
diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp
index 1a5de5928a..e9d1fc95c2 100644
--- a/examples/batched/batched.cpp
+++ b/examples/batched/batched.cpp
@@ -2,6 +2,7 @@
 #include "common.h"
 #include "log.h"
 #include "llama.h"
+#include "sampling.h"
 
 #include <algorithm>
 #include <cstdio>
@@ -64,6 +65,18 @@ int main(int argc, char ** argv) {
     ctx_params.n_ctx   = n_kv_req;
     ctx_params.n_batch = std::max(n_predict, n_parallel);
 
+    std::vector sampler_configs(n_parallel);
+    if (params.sampling.backend_sampling) {
+        for (int32_t i = 0; i < n_parallel; ++i) {
+            llama_sampler * backend_sampler = common_sampler_backend_init(model, params.sampling);
+            if (backend_sampler) {
+                sampler_configs[i] = { i, backend_sampler };
+            }
+        }
+        ctx_params.samplers   = sampler_configs.data();
+        ctx_params.n_samplers = n_parallel;
+    }
+
     llama_context * ctx = llama_init_from_model(model, ctx_params);
 
     auto sparams = llama_sampler_chain_default_params();
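
For readers following the `batched.cpp` hunk above, the sketch below condenses the new wiring into a single helper: one backend sampler chain is created per parallel sequence and handed to the context through the new `samplers`/`n_samplers` fields of `llama_context_params`. The helper name and the `llama_sampler_seq_config` element type are placeholders used here for illustration only; the concrete per-sequence config type comes from the `llama.h` side of the change, which this diff does not show.
```cpp
// Sketch only: how batched.cpp attaches one backend sampler per parallel
// sequence before creating the context. "llama_sampler_seq_config" is a
// placeholder name for the per-sequence config type (a sequence id plus a
// llama_sampler pointer) introduced elsewhere in this change.
#include "common.h"
#include "sampling.h"
#include "llama.h"

#include <cstdint>
#include <vector>

static llama_context * init_ctx_with_backend_samplers(
        llama_model                           * model,
        common_params                         & params,
        llama_context_params                   ctx_params,
        int32_t                                 n_parallel,
        // caller-owned, mirrors the sampler_configs vector kept in main()
        std::vector<llama_sampler_seq_config> & sampler_configs) {
    sampler_configs.resize(n_parallel);

    if (params.sampling.backend_sampling) {
        for (int32_t i = 0; i < n_parallel; ++i) {
            // Build a sampler chain that can run on the backend device (e.g. a GPU).
            llama_sampler * backend_sampler = common_sampler_backend_init(model, params.sampling);
            if (backend_sampler) {
                // Map sequence i to its own sampler.
                sampler_configs[i] = { i, backend_sampler };
            }
        }
        // Hand the per-sequence samplers to the context at creation time.
        ctx_params.samplers   = sampler_configs.data();
        ctx_params.n_samplers = n_parallel;
    }

    // Without --backend_sampling this behaves like a plain llama_init_from_model() call.
    return llama_init_from_model(model, ctx_params);
}
```
Keeping the config vector in the caller mirrors the example's `main`, where `sampler_configs` stays alive alongside the context; building a separate sampler per sequence presumably keeps sampling state independent across the `n_parallel` streams.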