llama : require backend samplers to be of type llama_sampler_chain
parent 07003f1ffb
commit 92ff767918
@@ -369,7 +369,7 @@ extern "C" {
         // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
         // ref: https://github.com/ggml-org/llama.cpp/pull/14363

-        // backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
+        // backend sampler chain configuration (make sure the caller keeps the sampler chains alive) [EXPERIMENTAL]
+        // note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
         struct llama_sampler_seq_config * samplers;
         size_t n_samplers;
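As a rough illustration of the new requirement (not part of this commit), a backend sampler entry could be set up as in the sketch below. The chain construction calls are the existing llama.cpp sampler API; the wiring into the context parameters follows the samplers/n_samplers fields above and the config.seq_id/config.sampler accesses later in this diff, so treat the exact field layout as an assumption.

// illustrative sketch only (not from this commit): build a per-sequence
// backend sampler as a chain, as the new note requires
struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

// assumed field names, taken from config.seq_id / config.sampler in this diff
struct llama_sampler_seq_config cfg;
cfg.seq_id  = 0;     // sequence the chain applies to
cfg.sampler = chain; // must be a chain, not a bare sampler

struct llama_context_params cparams = llama_context_default_params();
cparams.samplers   = &cfg;
cparams.n_samplers = 1;
// the caller keeps `chain` alive for the lifetime of the context, per the comment above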
@@ -1243,7 +1243,15 @@ extern "C" {
     // important: takes ownership of the sampler object and will free it when llama_sampler_free is called
     LLAMA_API void llama_sampler_chain_add( struct llama_sampler * chain, struct llama_sampler * smpl);
-    LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);
+
+    // return NULL if:
+    // - the sampler is NULL
+    // - the sampler is not a llama_sampler_chain
+    // - the index is out of bounds, unless i == -1
+    // - if i == -1, returns the chain itself (can be used to check if the sampler is a chain)
+    LLAMA_API struct llama_sampler * llama_sampler_chain_get( struct llama_sampler * chain, int32_t i);

     // the total number of samplers in the chain
     LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain);

     // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
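A small usage sketch of the updated getter (illustrative, not part of the commit): with i == -1 it doubles as a type check, and valid indices can be walked with llama_sampler_chain_n.

// illustrative only: print the members of a sampler, but only if it is a chain
#include <stdio.h>
#include "llama.h"

static void print_chain_members(struct llama_sampler * smpl) {
    if (llama_sampler_chain_get(smpl, -1) == NULL) {
        return; // NULL or not a llama_sampler_chain
    }
    const int n = llama_sampler_chain_n(smpl);
    for (int i = 0; i < n; ++i) {
        struct llama_sampler * s = llama_sampler_chain_get(smpl, i);
        printf("  %d: %s\n", i, llama_sampler_name(s));
    }
}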
@@ -66,7 +66,9 @@ llama_context::llama_context(
     for (size_t i = 0; i < params.n_samplers; ++i) {
         const auto & config = params.samplers[i];

-        // TODO: assert this is a llama_sampler_chain instance
+        if (llama_sampler_chain_get(config.sampler, -1) == nullptr) {
+            throw std::runtime_error("the backend samplers must be of type llama_sampler_chain");
+        }

         if (set_sampler(config.seq_id, config.sampler)) {
             const int n_samplers = llama_sampler_chain_n(config.sampler);
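In practice the new check means a bare sampler is rejected at context creation. A minimal sketch of the failing and passing cases, reusing the cfg from the sketch above (the sampler constructors are the existing API, the rest is illustrative):

// throws "the backend samplers must be of type llama_sampler_chain":
cfg.sampler = llama_sampler_init_greedy();

// accepted: wrap the same sampler in a chain first
struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(chain, llama_sampler_init_greedy());
cfg.sampler = chain;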
@@ -803,7 +803,19 @@ void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler
     });
 }

-struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
+struct llama_sampler * llama_sampler_chain_get(struct llama_sampler * chain, int32_t i) {
+    if (chain == nullptr) {
+        return nullptr;
+    }
+
+    if (chain->iface != &llama_sampler_chain_i) {
+        return nullptr;
+    }
+
+    if (i == -1) {
+        return chain;
+    }
+
     const auto * p = (const llama_sampler_chain *) chain->ctx;

     if (i < 0 || (size_t) i >= p->samplers.size()) {
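The implementation identifies a chain by comparing the sampler's iface pointer against the internal llama_sampler_chain_i interface, and the const qualifier is dropped presumably because for i == -1 the function hands back the chain pointer itself through the non-const return type. A hypothetical internal helper following the same pattern (not part of the commit) could look like:

// hypothetical, not from this commit: boolean form of the same iface check
static bool llama_sampler_is_chain(const struct llama_sampler * smpl) {
    return smpl != nullptr && smpl->iface == &llama_sampler_chain_i;
}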