llama : require backend samplers to be of type llama_sampler_chain

This commit is contained in:
Georgi Gerganov 2025-12-09 15:38:37 +02:00
parent 07003f1ffb
commit 92ff767918
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 26 additions and 4 deletions

View File

@@ -369,7 +369,7 @@ extern "C" {
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
// backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
// backend sampler chain configuration (make sure the caller keeps the sampler chains alive) [EXPERIMENTAL]
// note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
struct llama_sampler_seq_config * samplers;
size_t n_samplers;
@@ -1243,7 +1243,15 @@ extern "C" {
// important: takes ownership of the sampler object and will free it when llama_sampler_free is called
LLAMA_API void llama_sampler_chain_add( struct llama_sampler * chain, struct llama_sampler * smpl);
LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);
// returns the i-th sampler in the chain, or NULL if:
// - the sampler is NULL
// - the sampler is not a llama_sampler_chain
// - the index is out of bounds (and i != -1)
// note: if i == -1, returns the chain itself (can be used to check if the sampler is a chain)
LLAMA_API struct llama_sampler * llama_sampler_chain_get( struct llama_sampler * chain, int32_t i);
// the total number of samplers in the chain
LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain);
// after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed

View File

@@ -66,7 +66,9 @@ llama_context::llama_context(
for (size_t i = 0; i < params.n_samplers; ++i) {
const auto & config = params.samplers[i];
// TODO: assert this is a llama_sampler_chain instance
if (llama_sampler_chain_get(config.sampler, -1) == nullptr) {
throw std::runtime_error("the backend samplers must be of type llama_sampler_chain");
}
if (set_sampler(config.seq_id, config.sampler)) {
const int n_samplers = llama_sampler_chain_n(config.sampler);

View File

@@ -803,7 +803,19 @@ void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler
});
}
struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
struct llama_sampler * llama_sampler_chain_get(struct llama_sampler * chain, int32_t i) {
if (chain == nullptr) {
return nullptr;
}
if (chain->iface != &llama_sampler_chain_i) {
return nullptr;
}
if (i == -1) {
return chain;
}
const auto * p = (const llama_sampler_chain *) chain->ctx;
if (i < 0 || (size_t) i >= p->samplers.size()) {