llama : require backend samplers to be of type llama_sampler_chain
This commit is contained in:
parent
07003f1ffb
commit
92ff767918
|
|
@ -369,7 +369,7 @@ extern "C" {
|
||||||
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
|
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
|
||||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
|
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
|
||||||
|
|
||||||
// backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
|
// backend sampler chain configuration (make sure the caller keeps the sampler chains alive) [EXPERIMENTAL]
|
||||||
// note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
|
// note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
|
||||||
struct llama_sampler_seq_config * samplers;
|
struct llama_sampler_seq_config * samplers;
|
||||||
size_t n_samplers;
|
size_t n_samplers;
|
||||||
|
|
@ -1243,7 +1243,15 @@ extern "C" {
|
||||||
|
|
||||||
// important: takes ownership of the sampler object and will free it when llama_sampler_free is called
|
// important: takes ownership of the sampler object and will free it when llama_sampler_free is called
|
||||||
LLAMA_API void llama_sampler_chain_add( struct llama_sampler * chain, struct llama_sampler * smpl);
|
LLAMA_API void llama_sampler_chain_add( struct llama_sampler * chain, struct llama_sampler * smpl);
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);
|
|
||||||
|
// return NULL if:
|
||||||
|
// - the sampler is NULL
|
||||||
|
// - the sampler is not a llama_sampler_chain
|
||||||
|
// - the index is out of bounds, unless i == -1
|
||||||
|
// note: if i == -1 and the sampler is a llama_sampler_chain, returns the chain itself (can be used to check if the sampler is a chain)
|
||||||
|
LLAMA_API struct llama_sampler * llama_sampler_chain_get( struct llama_sampler * chain, int32_t i);
|
||||||
|
|
||||||
|
// the total number of samplers in the chain
|
||||||
LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain);
|
LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain);
|
||||||
|
|
||||||
// after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
|
// after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
|
||||||
|
|
|
||||||
|
|
@ -66,7 +66,9 @@ llama_context::llama_context(
|
||||||
for (size_t i = 0; i < params.n_samplers; ++i) {
|
for (size_t i = 0; i < params.n_samplers; ++i) {
|
||||||
const auto & config = params.samplers[i];
|
const auto & config = params.samplers[i];
|
||||||
|
|
||||||
// TODO: assert this is a llama_sampler_chain instance
|
if (llama_sampler_chain_get(config.sampler, -1) == nullptr) { // with i == -1 the getter returns the chain itself, or nullptr when the sampler is NULL or not a chain
|
||||||
|
throw std::runtime_error("the backend samplers must be of type llama_sampler_chain");
|
||||||
|
}
|
||||||
|
|
||||||
if (set_sampler(config.seq_id, config.sampler)) {
|
if (set_sampler(config.seq_id, config.sampler)) {
|
||||||
const int n_samplers = llama_sampler_chain_n(config.sampler);
|
const int n_samplers = llama_sampler_chain_n(config.sampler);
|
||||||
|
|
|
||||||
|
|
@ -803,7 +803,19 @@ void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
|
struct llama_sampler * llama_sampler_chain_get(struct llama_sampler * chain, int32_t i) {
|
||||||
|
if (chain == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (chain->iface != &llama_sampler_chain_i) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (i == -1) {
|
||||||
|
return chain;
|
||||||
|
}
|
||||||
|
|
||||||
const auto * p = (const llama_sampler_chain *) chain->ctx;
|
const auto * p = (const llama_sampler_chain *) chain->ctx;
|
||||||
|
|
||||||
if (i < 0 || (size_t) i >= p->samplers.size()) {
|
if (i < 0 || (size_t) i >= p->samplers.size()) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue