sampling : stop short if backend sampler sampled a token

This commit modifies the graph building logic to immediately continue
when a token has already been sampled by the backend sampler.

It also updates the test for backend temperature sampling to include
top-k and distribution samplers in the chain, verifying that they do not
produce any logits (since they are never run).
This commit is contained in:
Daniel Bevenius 2025-12-04 08:13:49 +01:00
parent cce3b2a8ad
commit 87b2719eca
No known key found for this signature in database
2 changed files with 3 additions and 0 deletions

View File

@ -2100,6 +2100,7 @@ void llm_graph_context::build_sampling() const {
if (data.sampled != nullptr) {
    res->t_sampled[seq_id] = data.sampled;
    ggml_build_forward_expand(gf, data.sampled);
    continue;
}
if (data.probs != nullptr) {

View File

@ -441,6 +441,8 @@ static void test_backend_temp_sampling(const char * model_path) {
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp(temp));
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_top_k(40));
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(18));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {
    { seq_id, backend_sampler_chain },