From 92ff76791834b8c746a4bd3b047b72ec09a5b184 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 9 Dec 2025 15:38:37 +0200
Subject: [PATCH] llama : require backend samplers to be of type
 llama_sampler_chain

---
 include/llama.h        | 12 ++++++++++--
 src/llama-context.cpp  |  4 +++-
 src/llama-sampling.cpp | 14 +++++++++++++-
 3 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/include/llama.h b/include/llama.h
index e01d06766d..abe71e7560 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -369,7 +369,7 @@ extern "C" {
         // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
         // ref: https://github.com/ggml-org/llama.cpp/pull/14363
 
-        // backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
+        // backend sampler chain configuration (make sure the caller keeps the sampler chains alive) [EXPERIMENTAL]
         // note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
         struct llama_sampler_seq_config * samplers;
         size_t n_samplers;
@@ -1243,7 +1243,15 @@ extern "C" {
     // important: takes ownership of the sampler object and will free it when llama_sampler_free is called
     LLAMA_API void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl);
-    LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);
+
+    // returns NULL if:
+    // - the sampler is NULL
+    // - the sampler is not a llama_sampler_chain
+    // - the index is out of bounds (unless i == -1)
+    // if i == -1, returns the chain itself (can be used to check whether a sampler is a chain)
+    LLAMA_API struct llama_sampler * llama_sampler_chain_get(struct llama_sampler * chain, int32_t i);
+
+    // the total number of samplers in the chain
     LLAMA_API int llama_sampler_chain_n(const struct llama_sampler * chain);
 
     // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
     LLAMA_API struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i);
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 3461b20ebc..93fd51df04 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -66,7 +66,9 @@ llama_context::llama_context(
     for (size_t i = 0; i < params.n_samplers; ++i) {
         const auto & config = params.samplers[i];
 
-        // TODO: assert this is a llama_sampler_chain instance
+        if (llama_sampler_chain_get(config.sampler, -1) == nullptr) {
+            throw std::runtime_error("the backend samplers must be of type llama_sampler_chain");
+        }
 
         if (set_sampler(config.seq_id, config.sampler)) {
             const int n_samplers = llama_sampler_chain_n(config.sampler);
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 9eee48f753..2c1127666f 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -803,7 +803,19 @@ void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler
     });
 }
 
-struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
+struct llama_sampler * llama_sampler_chain_get(struct llama_sampler * chain, int32_t i) {
+    if (chain == nullptr) {
+        return nullptr;
+    }
+
+    if (chain->iface != &llama_sampler_chain_i) {
+        return nullptr;
+    }
+
+    if (i == -1) {
+        return chain;
+    }
+
     const auto * p = (const llama_sampler_chain *) chain->ctx;
 
     if (i < 0 || (size_t) i >= p->samplers.size()) {
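---

Editor's note (not part of the patch): a minimal usage sketch of the calling
convention this change enforces, for reviewers. The llama_sampler_seq_config
field names `seq_id` and `sampler` are assumed from the constructor loop in
llama-context.cpp above (the struct definition is not shown in this patch, so
designated initializers are used to avoid depending on field order); all other
calls are the existing public sampler API.

    /* usage sketch: hypothetical test program, not part of the patch */
    #include <assert.h>
    #include <stddef.h>

    #include "llama.h"

    int main(void) {
        // a valid backend sampler: a chain built with llama_sampler_chain_init
        struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
        llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
        llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));
        llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

        // the i == -1 probe returns the chain itself for chains ...
        assert(llama_sampler_chain_get(chain, -1) == chain);

        // ... and NULL for a non-chain sampler - this is exactly the check
        // that the llama_context constructor now performs
        struct llama_sampler * greedy = llama_sampler_init_greedy();
        assert(llama_sampler_chain_get(greedy, -1) == NULL);

        // attach the chain to sequence 0 via the experimental context params
        // (field names assumed from the constructor loop above)
        struct llama_sampler_seq_config cfg = {
            .seq_id  = 0,
            .sampler = chain,
        };

        struct llama_context_params cparams = llama_context_default_params();
        cparams.samplers   = &cfg;
        cparams.n_samplers = 1;

        // ... init the model and context as usual; per the comment in
        // llama.h, the caller must keep `chain` alive while the context lives

        llama_sampler_free(greedy);
        llama_sampler_free(chain);

        return 0;
    }

Passing a non-chain sampler (e.g. `greedy` above) as `cfg.sampler` now makes
the context constructor throw instead of failing later. The i == -1 probe also
explains why the `const` qualifier was dropped from llama_sampler_chain_get:
the function may return its chain argument as a mutable pointer.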