diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index a621c4ebf5..c0ff7d1791 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -2100,6 +2100,7 @@ void llm_graph_context::build_sampling() const {
         if (data.sampled != nullptr) {
             res->t_sampled[seq_id] = data.sampled;
             ggml_build_forward_expand(gf, data.sampled);
+            continue;
         }
 
         if (data.probs != nullptr) {
diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp
index f56cce6350..eb3a0e248d 100644
--- a/tests/test-backend-sampler.cpp
+++ b/tests/test-backend-sampler.cpp
@@ -441,6 +441,8 @@ static void test_backend_temp_sampling(const char * model_path) {
     struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
     struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
     llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp(temp));
+    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_top_k(40));
+    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(18));
     std::vector backend_sampler_configs = {
         { seq_id, backend_sampler_chain },