sampling : check backend support during init

Georgi Gerganov 2025-12-04 17:29:08 +02:00
parent 1bde70785d
commit 6958d41366
8 changed files with 369 additions and 178 deletions


@@ -1098,8 +1098,7 @@ common_init_result::common_init_result(common_params & params) :
for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
llama_sampler * backend_chain = common_sampler_chain_backend(pimpl->samplers[i].get());
pimpl->samplers_seq_config[i] = { i, backend_chain };
pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
}
cparams.samplers = pimpl->samplers_seq_config.data();
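In effect, the caller no longer pre-splits the chain: common init now passes the complete per-sequence sampler chain to the context, and the context itself decides per sampler what can run on the backend (see llama_sampler_chain_backend_init further down).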


@@ -106,7 +106,6 @@ struct common_sampler {
struct llama_sampler * grmr;
struct llama_sampler * chain;
struct llama_sampler * chain_backend;
ring_buffer<llama_token> prev;
@@ -119,7 +118,6 @@ struct common_sampler {
llama_sampler_reset(grmr);
llama_sampler_reset(chain);
llama_sampler_reset(chain_backend);
}
void set_logits(struct llama_context * ctx, int idx) {
@@ -247,13 +245,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
}
auto * result = new common_sampler {
/* .params = */ params,
/* .grmr = */ grmr,
/* .chain = */ llama_sampler_chain_init(lparams),
/* .chain_backend = */ llama_sampler_chain_init(lparams),
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
/* .cur_p = */ {},
/* .params = */ params,
/* .grmr = */ grmr,
/* .chain = */ llama_sampler_chain_init(lparams),
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
/* .cur_p = */ {},
};
std::vector<llama_sampler *> samplers;
@@ -318,15 +315,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
GGML_ASSERT(false && "unknown mirostat version");
}
bool is_backend = params.backend_sampling;
// split in two chains: backend -> CPU
for (auto * smpl : samplers) {
if (!smpl->iface->backend_apply) {
is_backend = false;
}
llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, smpl);
llama_sampler_chain_add(result->chain, smpl);
}
return result;
@@ -336,7 +326,6 @@ void common_sampler_free(struct common_sampler * gsmpl) {
if (gsmpl) {
llama_sampler_free(gsmpl->grmr);
llama_sampler_free(gsmpl->chain);
llama_sampler_free(gsmpl->chain_backend);
delete gsmpl;
}
@@ -360,13 +349,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
return new common_sampler {
/* .params = */ gsmpl->params,
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
/* .chain = */ llama_sampler_clone(gsmpl->chain),
/* .chain_backend = */ llama_sampler_clone(gsmpl->chain_backend),
/* .prev = */ gsmpl->prev,
/* .cur = */ gsmpl->cur,
/* .cur_p = */ gsmpl->cur_p,
/* .params = */ gsmpl->params,
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
/* .chain = */ llama_sampler_clone(gsmpl->chain),
/* .prev = */ gsmpl->prev,
/* .cur = */ gsmpl->cur,
/* .cur_p = */ gsmpl->cur_p,
};
}
@@ -415,8 +403,8 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
}
}
struct llama_sampler * common_sampler_chain_backend(const struct common_sampler * gsmpl) {
return gsmpl->chain_backend;
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
return gsmpl->chain;
}
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
@@ -424,11 +412,13 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
// return that token id directly.
{
const llama_token id = llama_get_sampled_token_ith(ctx, idx);
if (id != LLAMA_TOKEN_NULL) {
LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
return id;
}
}
llama_synchronize(ctx);
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
@@ -556,16 +546,12 @@ llama_token common_sampler_last(const struct common_sampler * gsmpl) {
}
std::string common_sampler_print(const struct common_sampler * gsmpl) {
std::string result = llama_sampler_chain_n(gsmpl->chain_backend) > 0 ? "*logits " : "logits ";
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain_backend); i++) {
const auto * smpl = llama_sampler_chain_get(gsmpl->chain_backend, i);
result += std::string("-> *") + llama_sampler_name(smpl) + " ";
}
std::string result = "logits ";
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
result += std::string("-> ") + llama_sampler_name(smpl) + " ";
result += std::string("-> ");
result += std::string(llama_sampler_name(smpl)) + " ";
}
return result;


@@ -48,7 +48,7 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
// arguments can be nullptr to skip printing
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
struct llama_sampler * common_sampler_chain_backend(const struct common_sampler * gsmpl);
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
// extended sampling implementation:
//


@@ -369,7 +369,8 @@ extern "C" {
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
// backend sampler chain configuration (does not keep a reference, so make sure the caller keeps the samplers alive)
// backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
// note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
struct llama_sampler_seq_config * samplers;
size_t n_samplers;
};
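For orientation, a minimal sketch (not part of this commit) of how a caller wires a per-sequence sampler chain into the context; the llama_sampler_seq_config field order is assumed from the common_init_result hunk above:

// build a chain for sequence 0; the context keeps no ownership,
// so both the config array and the chain must stay alive
llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
llama_sampler_chain_add(chain, llama_sampler_init_dist(1234));

llama_sampler_seq_config cfg = { /* .seq_id = */ 0, /* .sampler = */ chain };

llama_context_params cparams = llama_context_default_params();
cparams.samplers   = &cfg;
cparams.n_samplers = 1;

llama_context * lctx = llama_init_from_model(model, cparams);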
@@ -1193,21 +1194,27 @@ extern "C" {
struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL
void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL
// backend sampling interface
void (*backend_init)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft);
// backend sampling interface:
// return true if the backend supports all ops needed by the sampler
// note: call once per sampler
bool (*backend_init)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft);
// call after .backend_apply()
void (*backend_accept)(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct ggml_tensor * selected_token);
// call after .backend_init()
void (*backend_apply)(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_data * data);
// call before .backend_apply()
void (*backend_set_input)(struct llama_sampler * smpl);
};
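To make the new contract concrete, here is a sketch of a hypothetical backend-capable sampler (a fixed temperature of 2.0); the names are invented for illustration and it mirrors the shape of the built-in samplers changed below:

static const char * my_temp_name(const struct llama_sampler * /*smpl*/) {
    return "my-temp";
}

// CPU fallback, used when the device rejects the sampler
static void my_temp_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
    for (size_t i = 0; i < cur_p->size; ++i) {
        cur_p->data[i].logit *= 0.5f;
    }
}

// report support; called once per sampler when the context initializes the chain
static bool my_temp_backend_init(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(smpl);
    GGML_UNUSED(buft);
    return true; // ggml_scale is supported by every backend
}

// the same scaling, expressed as a graph op
static void my_temp_backend_apply(
        struct llama_sampler * smpl,
        struct ggml_context * ctx,
        struct ggml_cgraph * gf,
        struct llama_sampler_data * data) {
    data->logits = ggml_scale(ctx, data->logits, 0.5f);
    GGML_UNUSED(smpl);
    GGML_UNUSED(gf);
}

static struct llama_sampler_i my_temp_i = {
    /* .name              = */ my_temp_name,
    /* .accept            = */ nullptr,
    /* .apply             = */ my_temp_apply,
    /* .reset             = */ nullptr,
    /* .clone             = */ nullptr,
    /* .free              = */ nullptr,
    /* .backend_init      = */ my_temp_backend_init,
    /* .backend_accept    = */ nullptr,
    /* .backend_apply     = */ my_temp_backend_apply,
    /* .backend_set_input = */ nullptr,
};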


@@ -68,6 +68,8 @@ llama_context::llama_context(
for (size_t i = 0; i < params.n_samplers; ++i) {
const auto & config = params.samplers[i];
// TODO: assert this is a llama_sampler_chain instance
if (set_sampler(config.seq_id, config.sampler)) {
const int n_samplers = llama_sampler_chain_n(config.sampler);


@@ -504,11 +504,13 @@ static void llama_sampler_empty_free(struct llama_sampler * smpl) {
delete (llama_sampler_empty *) smpl->ctx;
}
static void llama_sampler_empty_backend_init(
static bool llama_sampler_empty_backend_init(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
GGML_UNUSED(smpl);
GGML_UNUSED(buft);
return true;
}
static void llama_sampler_empty_backend_accept(
@@ -559,6 +561,43 @@ struct llama_sampler * llama_sampler_init_empty(const char * name) {
);
}
// common backend sampler functionality
//
// +name : means that the sampler is supported and will run on the backend
// -name : means that a ggml operator required by the sampler is not supported by the backend
//
struct llama_sampler_backend {
llama_sampler_backend(const char * name) : name(name), name_ext(name), is_init(false), support(false) {}
const char * get_name() {
if (!is_init) {
return name.c_str();
}
if (support) {
name_ext = "+" + name;
} else {
name_ext = "-" + name;
}
return name_ext.c_str();
}
void init(bool support) {
GGML_ASSERT(this->is_init == false);
this->is_init = true;
this->support = support;
}
private:
std::string name;
std::string name_ext;
bool is_init;
bool support;
};
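With this helper, a chain's printout reflects per-sampler support after initialization: for example (samplers chosen for illustration), common_sampler_print would render "logits -> +top-k -> -dist" on a device that supports the argsort behind top-k but not the cumsum that dist probes for below.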
// sampler chain
static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
@@ -570,8 +609,8 @@ static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token
time_meas tm(chain->t_sample_us, chain->params.no_perf);
for (auto * smpl : chain->samplers) {
llama_sampler_accept(smpl, token);
for (auto & smpl : chain->samplers) {
llama_sampler_accept(smpl.ptr, token);
}
chain->n_sample++;
@@ -582,20 +621,28 @@ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_d
time_meas tm(chain->t_sample_us, chain->params.no_perf);
for (auto * smpl : chain->samplers) {
if (smpl->iface->apply == nullptr) {
bool is_backend = chain->is_init;
for (auto & smpl : chain->samplers) {
if (is_backend && smpl.is_backend) {
continue;
}
llama_sampler_apply(smpl, cur_p);
is_backend = false;
if (smpl.ptr->iface->apply == nullptr) {
continue;
}
llama_sampler_apply(smpl.ptr, cur_p);
}
}
static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
auto * chain = (llama_sampler_chain *) smpl->ctx;
for (auto * smpl : chain->samplers) {
llama_sampler_reset(smpl);
for (auto & smpl : chain->samplers) {
llama_sampler_reset(smpl.ptr);
}
}
@@ -604,8 +651,8 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl
auto * result = llama_sampler_chain_init(chain_src->params);
for (auto * smpl : chain_src->samplers) {
llama_sampler_chain_add(result, llama_sampler_clone(smpl));
for (const auto & smpl : chain_src->samplers) {
llama_sampler_chain_add(result, llama_sampler_clone(smpl.ptr));
}
return result;
@@ -614,23 +661,44 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl
static void llama_sampler_chain_free(struct llama_sampler * smpl) {
auto * chain = (llama_sampler_chain *) smpl->ctx;
for (auto * smpl : chain->samplers) {
llama_sampler_free(smpl);
for (auto & smpl : chain->samplers) {
llama_sampler_free(smpl.ptr);
}
delete chain;
}
static void llama_sampler_chain_backend_init(
static bool llama_sampler_chain_backend_init(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
auto * chain = (llama_sampler_chain *) smpl->ctx;
for (auto * smpl : chain->samplers) {
if (smpl->iface->backend_init) {
smpl->iface->backend_init(smpl,buft);
GGML_ASSERT(chain->is_init == false && "llama_sampler_chain_backend_init() called twice");
chain->is_init = true;
bool res = true;
for (auto & smpl : chain->samplers) {
bool res_cur = true;
// to be able to run a sampler on the backend, it has to:
// - have the .backend_init() API implemented
// - return true during .backend_init()
if (smpl.ptr->iface->backend_init) {
if (!smpl.ptr->iface->backend_init(smpl.ptr, buft)) {
res_cur = false;
}
} else {
res_cur = false;
}
smpl.is_backend = res_cur;
res = res && res_cur;
}
return res;
}
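Note the resulting split semantics: the backend accept/apply loops below break at the first sampler whose is_backend flag is false, and the CPU-side llama_sampler_chain_apply above skips exactly that leading run. The chain therefore executes as a backend prefix followed by a CPU suffix, and a backend-capable sampler that sits after an unsupported one still runs on the CPU, preserving the original order.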
static void llama_sampler_chain_backend_accept(
@@ -640,9 +708,13 @@ static void llama_sampler_chain_backend_accept(
struct ggml_tensor * selected_token) {
auto * chain = (llama_sampler_chain *) smpl->ctx;
for (auto * smpl : chain->samplers) {
if (smpl->iface->backend_accept) {
smpl->iface->backend_accept(smpl, ctx, gf, selected_token);
for (auto & smpl : chain->samplers) {
if (!smpl.is_backend) {
break;
}
if (smpl.ptr->iface->backend_accept) {
smpl.ptr->iface->backend_accept(smpl.ptr, ctx, gf, selected_token);
}
}
}
@@ -654,9 +726,15 @@ static void llama_sampler_chain_backend_apply(
struct llama_sampler_data * data) {
auto * chain = (llama_sampler_chain *) smpl->ctx;
for (auto * smpl : chain->samplers) {
if (smpl->iface->backend_apply) {
smpl->iface->backend_apply(smpl, ctx, gf, data);
GGML_ASSERT(chain->is_init && "llama_sampler_chain_backend_init() not called");
for (auto & smpl : chain->samplers) {
if (!smpl.is_backend) {
break;
}
if (smpl.ptr->iface->backend_apply) {
smpl.ptr->iface->backend_apply(smpl.ptr, ctx, gf, data);
}
}
}
@@ -664,9 +742,13 @@ static void llama_sampler_chain_backend_apply(
static void llama_sampler_chain_backend_set_input(struct llama_sampler * smpl) {
auto * chain = (llama_sampler_chain *) smpl->ctx;
for (auto * smpl : chain->samplers) {
if (smpl->iface->backend_set_input) {
smpl->iface->backend_set_input(smpl);
for (auto & smpl : chain->samplers) {
if (!smpl.is_backend) {
break;
}
if (smpl.ptr->iface->backend_set_input) {
smpl.ptr->iface->backend_set_input(smpl.ptr);
}
}
}
@@ -689,6 +771,7 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
/* .iface = */ &llama_sampler_chain_i,
/* .ctx = */ new llama_sampler_chain {
/* .params = */ params,
/* .is_init = */ false,
/* .samplers = */ {},
/* .t_sample_us = */ 0,
/* .n_sample = */ 0,
@@ -698,7 +781,10 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
auto * p = (llama_sampler_chain *) chain->ctx;
p->samplers.push_back(smpl);
p->samplers.push_back({
/* .is_backend = */ false,
/* .ptr = */ smpl,
});
}
struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
@@ -708,7 +794,7 @@ struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chai
return nullptr;
}
return p->samplers[i];
return p->samplers[i].ptr;
}
struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
@@ -718,7 +804,7 @@ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain,
return nullptr;
}
auto * result = p->samplers[i];
auto * result = p->samplers[i].ptr;
p->samplers.erase(p->samplers.begin() + i);
return result;
@@ -749,6 +835,15 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
}
}
static bool llama_sampler_greedy_backend_init(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
GGML_UNUSED(smpl);
GGML_UNUSED(buft);
return true;
}
static void llama_sampler_greedy_backend_apply(
struct llama_sampler * smpl,
struct ggml_context * ctx,
@@ -768,7 +863,7 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
/* .reset = */ nullptr,
/* .clone = */ nullptr,
/* .free = */ nullptr,
/* .backend_init = */ nullptr,
/* .backend_init = */ llama_sampler_greedy_backend_init,
/* .backend_accept = */ nullptr,
/* .backend_apply = */ llama_sampler_greedy_backend_apply,
/* .backend_set_input = */ nullptr,
@@ -783,23 +878,22 @@ struct llama_sampler * llama_sampler_init_greedy() {
// dist
struct llama_sampler_dist {
struct llama_sampler_dist : public llama_sampler_backend {
const uint32_t seed;
uint32_t seed_cur;
std::mt19937 rng;
// Only required for checking operation support and can be removed later.
ggml_backend_dev_t device;
// backend input
struct ggml_tensor * inp_uniform;
ggml_context_ptr inp_ctx;
ggml_context_ptr inp_ctx;
ggml_backend_buffer_ptr inp_buf;
};
static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
return "dist";
static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_dist *) smpl->ctx;
return sctx->get_name();
}
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -898,13 +992,61 @@ static void llama_sampler_dist_free(struct llama_sampler * smpl) {
delete (llama_sampler_dist *) smpl->ctx;
}
static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
static bool llama_sampler_dist_backend_init(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
auto * sctx = (llama_sampler_dist *) smpl->ctx;
GGML_ASSERT(sctx->inp_uniform != nullptr);
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
const float rnd = dist(sctx->rng);
ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
bool res = true;
// determine backend support
{
ggml_init_params params = {
/*.mem_size =*/ ggml_tensor_overhead()*8,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
ggml_context_ptr ctx_ptr { ggml_init(params) };
if (!ctx_ptr) {
throw std::runtime_error(format("failed to create ggml context"));
}
ggml_context * ctx = ctx_ptr.get();
ggml_tensor * probs = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024*1024);
ggml_tensor * op = ggml_cumsum(ctx, probs);
auto * device = ggml_backend_buft_get_device(buft);
GGML_ASSERT(device);
if (!ggml_backend_dev_supports_op(device, op)) {
res = false;
}
sctx->init(res);
}
if (res) {
ggml_init_params params = {
/*.mem_size =*/ ggml_tensor_overhead(),
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
sctx->inp_ctx.reset(ggml_init(params));
// Create the uniform random scalar input tensor. This will be set by
// llama_sampler_dist_backend_set_input after this graph is built.
sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
ggml_set_name(sctx->inp_uniform, "uniform");
ggml_set_input(sctx->inp_uniform);
// Allocate all tensors from our context to the backend
sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
}
return res;
}
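The support probe is allocation-free: with no_alloc set, the dummy tensor and the cumsum node only carry shape metadata, which is all ggml_backend_dev_supports_op needs to answer; the input buffer is allocated only once support is confirmed.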
static void llama_sampler_dist_backend_apply(
@@ -919,10 +1061,6 @@ static void llama_sampler_dist_backend_apply(
ggml_set_name(probs, "dist_probs");
struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
if (sctx->device && !ggml_backend_dev_supports_op(sctx->device, cumsum)) {
fprintf(stderr, "Warning: backend does not support cumsum operation required for dist sampling\n");
fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
}
ggml_set_name(cumsum, "cumsum");
// The uniform tensor has a random value and we subtract this tensor with
@@ -965,29 +1103,13 @@ static void llama_sampler_dist_backend_apply(
data->sampled = sampled_token;
}
static void llama_sampler_dist_backend_init(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_dist *) smpl->ctx;
GGML_ASSERT(sctx->inp_uniform != nullptr);
sctx->device = ggml_backend_buft_get_device(buft);
ggml_init_params params = {
/*.mem_size =*/ ggml_tensor_overhead(),
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
sctx->inp_ctx.reset(ggml_init(params));
// Create the uniform random scalar input tensor. This will be set by
// llama_sampler_dist_backend_set_input after this graph is built.
sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
ggml_set_name(sctx->inp_uniform, "uniform");
ggml_set_input(sctx->inp_uniform);
// Allocate all tensors from our context to the backend
sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
const float rnd = dist(sctx->rng);
ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
}
static struct llama_sampler_i llama_sampler_dist_i = {
@@ -1008,10 +1130,10 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
return llama_sampler_init(
/* .iface = */ &llama_sampler_dist_i,
/* .ctx = */ new llama_sampler_dist {
("dist"),
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
/* .device = */ nullptr,
/* .inp_uniform = */ nullptr,
/* .inp_ctx = */ nullptr,
/* .inp_buf = */ nullptr,
@@ -1021,15 +1143,13 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
// top-k
struct llama_sampler_top_k {
struct llama_sampler_top_k : public llama_sampler_backend {
const int32_t k;
// Only required for checking operation support and can be removed later.
ggml_backend_dev_t device;
};
static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
return "top-k";
static const char * llama_sampler_top_k_name(const struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_top_k *) smpl->ctx;
return sctx->get_name();
}
static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -1046,11 +1166,42 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
delete (llama_sampler_top_k *) smpl->ctx;
}
static void llama_sampler_top_k_backend_init(
static bool llama_sampler_top_k_backend_init(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
auto * ctx_data = (llama_sampler_top_k *) smpl->ctx;
ctx_data->device = ggml_backend_buft_get_device(buft);
auto * sctx = (llama_sampler_top_k *) smpl->ctx;
bool res = true;
// determine backend support
{
ggml_init_params params = {
/*.mem_size =*/ ggml_tensor_overhead()*8,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
ggml_context_ptr ctx_ptr { ggml_init(params) };
if (!ctx_ptr) {
throw std::runtime_error(format("failed to create ggml context"));
}
ggml_context * ctx = ctx_ptr.get();
ggml_tensor * logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024*1024);
ggml_tensor * op = ggml_top_k(ctx, logits, sctx->k);
auto * device = ggml_backend_buft_get_device(buft);
GGML_ASSERT(device);
if (!ggml_backend_dev_supports_op(device, op)) {
res = false;
}
sctx->init(res);
}
return res;
}
static void llama_sampler_top_k_backend_apply(
@@ -1058,26 +1209,17 @@ static void llama_sampler_top_k_backend_apply(
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_data * data) {
auto * sctx = (llama_sampler_top_k *) smpl->ctx;
auto * ctx_data = (llama_sampler_top_k *) smpl->ctx;
struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, ctx_data->k);
struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, sctx->k);
ggml_set_name(top_k, "top_k");
// top_k is a view of argsort - check if backend supports the underlying argsort operation
// by checking the source tensor (which is the argsort result)
if (ctx_data->device && top_k->src[0] && !ggml_backend_dev_supports_op(ctx_data->device, top_k->src[0])) {
fprintf(stderr, "Warning: backend does not support argsort operation required for top-k sampling\n");
fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
}
data->candidates = top_k;
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
ggml_set_name(top_k_rows, "top_k_rows");
data->logits = ggml_reshape_1d(ctx, top_k_rows, ctx_data->k);
data->candidates = top_k;
data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k);
GGML_UNUSED(gf);
}
@@ -1099,29 +1241,30 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
const bool is_empty = (k <= 0);
if (is_empty) {
return llama_sampler_init_empty("top-k?");
return llama_sampler_init_empty("?top-k");
}
return llama_sampler_init(
/* .iface = */ &llama_sampler_top_k_i,
/* .ctx = */ new llama_sampler_top_k {
/* .k = */ k,
/* .device = */ nullptr,
("top-k"),
/* .k = */ k,
}
);
}
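The renamed placeholder ("?top-k" instead of "top-k?") makes '?' a prefix marker for samplers that degenerate to a no-op at construction time (here k <= 0), presumably so it reads consistently with the '+'/'-' backend markers introduced above; the same rename is applied to the other samplers below.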
// top-p
struct llama_sampler_top_p {
struct llama_sampler_top_p : public llama_sampler_backend {
const float p;
const size_t min_keep;
std::vector<llama_token_data> buf_sort;
};
static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
return "top-p";
static const char * llama_sampler_top_p_name(const struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_top_p *) smpl->ctx;
return sctx->get_name();
}
static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -1188,11 +1331,15 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
delete (llama_sampler_top_p *) smpl->ctx;
}
static void llama_sampler_top_p_backend_init(
static bool llama_sampler_top_p_backend_init(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
GGML_UNUSED(smpl);
GGML_UNUSED(buft);
auto * sctx = (llama_sampler_top_p *) smpl->ctx;
sctx->init(true);
return true;
}
static void llama_sampler_top_p_backend_apply(
@@ -1287,12 +1434,13 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
const bool is_empty = p >= 1.0f;
if (is_empty) {
return llama_sampler_init_empty("top-p?");
return llama_sampler_init_empty("?top-p");
}
return llama_sampler_init(
/* .iface = */ &llama_sampler_top_p_i,
/* .ctx = */ new llama_sampler_top_p {
("top-p"),
/* .p = */ p,
/* .min_keep = */ min_keep,
/* .buf_sort = */ {},
@@ -1302,13 +1450,14 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
// min-p
struct llama_sampler_min_p {
struct llama_sampler_min_p : public llama_sampler_backend {
const float p;
const size_t min_keep;
};
static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
return "min-p";
static const char * llama_sampler_min_p_name(const struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_min_p *) smpl->ctx;
return sctx->get_name();
}
static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -1374,11 +1523,16 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
delete (llama_sampler_min_p *) smpl->ctx;
}
static void llama_sampler_min_p_backend_init(
static bool llama_sampler_min_p_backend_init(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
GGML_UNUSED(smpl);
GGML_UNUSED(buft);
auto * sctx = (llama_sampler_min_p *) smpl->ctx;
sctx->init(true);
return true;
}
static void llama_sampler_min_p_backend_apply(
@@ -1441,12 +1595,13 @@ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
const bool is_empty = (p <= 0.0f);
if (is_empty) {
return llama_sampler_init_empty("min-p?");
return llama_sampler_init_empty("?min-p");
}
return llama_sampler_init(
/* .iface = */ &llama_sampler_min_p_i,
/* .ctx = */ new llama_sampler_min_p {
("min-p"),
/* .p = */ p,
/* .min_keep = */ min_keep,
}
@@ -1550,7 +1705,7 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
const bool is_empty = (p >= 1.0f);
if (is_empty) {
return llama_sampler_init_empty("typical?");
return llama_sampler_init_empty("?typical");
}
return llama_sampler_init(
@@ -1564,12 +1719,13 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
// temp
struct llama_sampler_temp {
struct llama_sampler_temp : public llama_sampler_backend {
const float temp;
};
static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
return "temp";
static const char * llama_sampler_temp_name(const struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_temp *) smpl->ctx;
return sctx->get_name();
}
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -1616,13 +1772,25 @@ static void llama_sampler_backend_temp_sampling(
GGML_UNUSED(gf);
}
static bool llama_sampler_temp_backend_init(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
GGML_UNUSED(buft);
auto * sctx = (llama_sampler_temp *) smpl->ctx;
sctx->init(true);
return true;
}
static void llama_sampler_temp_backend_apply(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_data * data) {
auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
llama_sampler_backend_temp_sampling(ctx, gf, data, ctx_data->temp);
auto * sctx = (llama_sampler_temp *) smpl->ctx;
llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
}
static struct llama_sampler_i llama_sampler_temp_i = {
@@ -1632,7 +1800,7 @@ static struct llama_sampler_i llama_sampler_temp_i = {
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_temp_clone,
/* .free = */ llama_sampler_temp_free,
/* .backend_init = */ nullptr,
/* .backend_init = */ llama_sampler_temp_backend_init,
/* .backend_accept = */ nullptr,
/* .backend_apply = */ llama_sampler_temp_backend_apply,
/* .backend_set_input = */ nullptr,
@@ -1642,12 +1810,13 @@ struct llama_sampler * llama_sampler_init_temp(float temp) {
const bool is_empty = temp == 1.0f;
if (is_empty) {
return llama_sampler_init_empty("temp?");
return llama_sampler_init_empty("?temp");
}
return llama_sampler_init(
/* .iface = */ &llama_sampler_temp_i,
/* .ctx = */ new llama_sampler_temp {
("temp"),
/*.temp = */ temp,
}
);
@@ -1655,14 +1824,15 @@ struct llama_sampler * llama_sampler_init_temp(float temp) {
// temp-ext
struct llama_sampler_temp_ext {
struct llama_sampler_temp_ext : public llama_sampler_backend {
const float temp;
const float delta;
const float exponent;
};
static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
return "temp-ext";
static const char * llama_sampler_temp_ext_name(const struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
return sctx->get_name();
}
static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -1745,22 +1915,34 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
delete (llama_sampler_temp_ext *) smpl->ctx;
}
static bool llama_sampler_temp_ext_backend_init(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
GGML_UNUSED(buft);
auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
sctx->init(true);
return true;
}
static void llama_sampler_temp_ext_backend_apply(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_data * data) {
auto * ctx_data = (llama_sampler_temp_ext *) smpl->ctx;
auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
// Revert to standard temperature scaling if delta or temp are non-positive.
if (ctx_data->delta <= 0.0f || ctx_data->temp <= 0.0f) {
llama_sampler_backend_temp_sampling(ctx, gf, data, ctx_data->temp);
if (sctx->delta <= 0.0f || sctx->temp <= 0.0f) {
llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
return;
}
// Calculate min_temp, max_temp, and max_entropy.
const float min_temp = std::max(0.0f, ctx_data->temp - ctx_data->delta);
const float max_temp = ctx_data->temp + ctx_data->delta;
const float min_temp = std::max(0.0f, sctx->temp - sctx->delta);
const float max_temp = sctx->temp + sctx->delta;
const float max_entropy = logf(data->logits->ne[0]);
// Calculate the probabilities.
@@ -1791,7 +1973,7 @@ static void llama_sampler_temp_ext_backend_apply(
// Calculate powf(normalized_entropy, exponent) as
// norm_entropy^exponent = exp(exponent * log(norm_entropy))
struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy);
struct ggml_tensor * scaled_log = ggml_scale(ctx, log_norm_entropy, ctx_data->exponent);
struct ggml_tensor * scaled_log = ggml_scale(ctx, log_norm_entropy, sctx->exponent);
struct ggml_tensor * pow_entropy = ggml_exp(ctx, scaled_log);
// With pow_entropy computed we can now compute dyn_temp, scaling by
// (max_temp - min_temp) and then adding min_temp.
@@ -1815,7 +1997,7 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_temp_ext_clone,
/* .free = */ llama_sampler_temp_ext_free,
/* .backend_init = */ nullptr,
/* .backend_init = */ llama_sampler_temp_ext_backend_init,
/* .backend_accept = */ nullptr,
/* .backend_apply = */ llama_sampler_temp_ext_backend_apply,
/* .backend_set_input = */ nullptr,
@@ -1825,12 +2007,13 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
const bool is_empty = temp == 1.0f && delta <= 0.0f;
if (is_empty) {
return llama_sampler_init_empty("temp-ext?");
return llama_sampler_init_empty("?temp-ext");
}
auto * res = llama_sampler_init(
/* .iface = */ &llama_sampler_temp_ext_i,
/* .ctx = */ new llama_sampler_temp_ext {
("temp-ext"),
/* .temp = */ temp,
/* .delta = */ delta,
/* .exponent = */ exponent,
@@ -1931,7 +2114,7 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
const bool is_empty = (p <= 0.0f || t > 0.5f);
if (is_empty) {
return llama_sampler_init_empty("xtc?");
return llama_sampler_init_empty("?xtc");
}
const auto seed_cur = get_rng_seed(seed);
@@ -2492,7 +2675,7 @@ struct llama_sampler * llama_sampler_init_penalties(
const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
if (is_empty) {
return llama_sampler_init_empty("penalties?");
return llama_sampler_init_empty("?penalties");
}
return llama_sampler_init(
@@ -2585,7 +2768,7 @@ struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
const bool is_empty = (n <= 0.0f);
if (is_empty) {
return llama_sampler_init_empty("top-n-sigma?");
return llama_sampler_init_empty("?top-n-sigma");
}
return llama_sampler_init(
@@ -2930,7 +3113,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
if (!dry_enabled) {
return llama_sampler_init_empty("dry?");
return llama_sampler_init_empty("?dry");
}
if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
@@ -3003,7 +3186,7 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
// logit-bias
struct llama_sampler_logit_bias {
struct llama_sampler_logit_bias : public llama_sampler_backend {
const int32_t n_vocab;
const std::vector<llama_logit_bias> logit_bias;
@@ -3016,8 +3199,9 @@ struct llama_sampler_logit_bias {
ggml_backend_buffer_ptr inp_buf;
};
static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
return "logit-bias";
static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
return ctx->get_name();
}
static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -3097,13 +3281,15 @@ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * sm
ggml_backend_tensor_set(sctx->inp_logit_bias, logit_bias_sparse.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
}
static void llama_sampler_logit_bias_backend_init(
static bool llama_sampler_logit_bias_backend_init(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
sctx->init(true);
if (sctx->logit_bias.empty()) {
return;
return true;
}
ggml_init_params params = {
@@ -3120,6 +3306,8 @@ static void llama_sampler_logit_bias_backend_init(
// Allocate all tensors from our context to the backend
sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
return true;
}
static struct llama_sampler_i llama_sampler_logit_bias_i = {
@@ -3142,12 +3330,13 @@ struct llama_sampler * llama_sampler_init_logit_bias(
const bool is_empty = n_logit_bias <= 0;
if (is_empty) {
return llama_sampler_init_empty("logit-bias?");
return llama_sampler_init_empty("?logit-bias");
}
return llama_sampler_init(
/* .iface = */ &llama_sampler_logit_bias_i,
/* .ctx = */ new llama_sampler_logit_bias {
("logit-bias"),
/* .n_vocab = */ n_vocab,
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
/* .to_search = */ {},
@@ -3407,7 +3596,7 @@ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
if (smpl->iface == &llama_sampler_chain_i) {
const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
const uint32_t seed = llama_sampler_get_seed(*it);
const uint32_t seed = llama_sampler_get_seed(it->ptr);
if (seed != LLAMA_DEFAULT_SEED) {
return seed;
}


@@ -14,7 +14,16 @@ struct llama_grammar;
struct llama_sampler_chain {
llama_sampler_chain_params params;
std::vector<struct llama_sampler *> samplers;
// has .backend_init() been called?
bool is_init = false;
struct info {
bool is_backend;
llama_sampler * ptr;
};
std::vector<info> samplers;
// timing


@@ -1028,10 +1028,9 @@ struct server_context_impl {
return false;
}
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get()));
llama_sampler * backend_chain = common_sampler_chain_backend(slot.smpl.get());
llama_set_sampler(ctx, slot.id, backend_chain);
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
}
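The chain is now logged after llama_set_sampler rather than before, presumably because the '+'/'-' support markers in the sampler names are only filled in once the context has run backend initialization on the chain.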
// initialize draft batch