From 6958d41366873bd6a9e22be16ccdd5c4724d3d5a Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Thu, 4 Dec 2025 17:29:08 +0200
Subject: [PATCH] sampling : check backend support during init
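
The backend sampler chain is no longer split up front in common/sampling based
on whether a sampler implements .backend_apply(). Instead, .backend_init() now
returns whether the backend supports all ggml ops required by the sampler, the
chain records a per-sampler is_backend flag, and llama_sampler_chain_apply()
transparently runs the unsupported remainder of the chain on the CPU. Sampler
names are prefixed with "+" / "-" in printouts to reflect the result of the
support check.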
---
 common/common.cpp               |   3 +-
 common/sampling.cpp             |  54 ++--
 common/sampling.h               |   2 +-
 include/llama.h                 |  13 +-
 src/llama-context.cpp           |   2 +
 src/llama-sampling.cpp          | 457 ++++++++++++++++++++++----------
 src/llama-sampling.h            |  11 +-
 tools/server/server-context.cpp |   5 +-
 8 files changed, 369 insertions(+), 178 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 5982c549ce..f52c41af76 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1098,8 +1098,7 @@ common_init_result::common_init_result(common_params & params) :
     for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
         pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
 
-        llama_sampler * backend_chain = common_sampler_chain_backend(pimpl->samplers[i].get());
-        pimpl->samplers_seq_config[i] = { i, backend_chain };
+        pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
     }
 
     cparams.samplers = pimpl->samplers_seq_config.data();
diff --git a/common/sampling.cpp b/common/sampling.cpp
index b7dfed547b..3941b5f574 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -106,7 +106,6 @@ struct common_sampler {
 
     struct llama_sampler * grmr;
     struct llama_sampler * chain;
-    struct llama_sampler * chain_backend;
 
     ring_buffer<llama_token> prev;
 
@@ -119,7 +118,6 @@ struct common_sampler {
 
         llama_sampler_reset(grmr);
         llama_sampler_reset(chain);
-        llama_sampler_reset(chain_backend);
     }
 
     void set_logits(struct llama_context * ctx, int idx) {
@@ -247,13 +245,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
     }
 
     auto * result = new common_sampler {
-        /* .params        = */ params,
-        /* .grmr          = */ grmr,
-        /* .chain         = */ llama_sampler_chain_init(lparams),
-        /* .chain_backend = */ llama_sampler_chain_init(lparams),
-        /* .prev          = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
-        /* .cur           = */ {},
-        /* .cur_p         = */ {},
+        /* .params = */ params,
+        /* .grmr   = */ grmr,
+        /* .chain  = */ llama_sampler_chain_init(lparams),
+        /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
+        /* .cur    = */ {},
+        /* .cur_p  = */ {},
     };
 
     std::vector<llama_sampler *> samplers;
@@ -318,15 +315,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             GGML_ASSERT(false && "unknown mirostat version");
     }
 
-    bool is_backend = params.backend_sampling;
-
-    // split in two chains: backend -> CPU
     for (auto * smpl : samplers) {
-        if (!smpl->iface->backend_apply) {
-            is_backend = false;
-        }
-
-        llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, smpl);
+        llama_sampler_chain_add(result->chain, smpl);
     }
 
     return result;
@@ -336,7 +326,6 @@ void common_sampler_free(struct common_sampler * gsmpl) {
     if (gsmpl) {
         llama_sampler_free(gsmpl->grmr);
         llama_sampler_free(gsmpl->chain);
-        llama_sampler_free(gsmpl->chain_backend);
 
         delete gsmpl;
     }
@@ -360,13 +349,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
 
 struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
     return new common_sampler {
-        /* .params        = */ gsmpl->params,
-        /* .grmr          = */ llama_sampler_clone(gsmpl->grmr),
-        /* .chain         = */ llama_sampler_clone(gsmpl->chain),
-        /* .chain_backend = */ llama_sampler_clone(gsmpl->chain_backend),
-        /* .prev          = */ gsmpl->prev,
-        /* .cur           = */ gsmpl->cur,
-        /* .cur_p         = */ gsmpl->cur_p,
+        /* .params = */ gsmpl->params,
+        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
+        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
+        /* .prev   = */ gsmpl->prev,
+        /* .cur    = */ gsmpl->cur,
+        /* .cur_p  = */ gsmpl->cur_p,
     };
 }
 
@@ -415,8 +403,8 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
     }
 }
 
-struct llama_sampler * common_sampler_chain_backend(const struct common_sampler * gsmpl) {
-    return gsmpl->chain_backend;
+struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
+    return gsmpl->chain;
 }
 
 llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
@@ -424,11 +412,13 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     // return that token id directly.
     {
         const llama_token id = llama_get_sampled_token_ith(ctx, idx);
+
        if (id != LLAMA_TOKEN_NULL) {
            LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
            return id;
        }
     }
+
     llama_synchronize(ctx);
 
     // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
@@ -556,16 +546,12 @@ llama_token common_sampler_last(const struct common_sampler * gsmpl) {
 }
 
 std::string common_sampler_print(const struct common_sampler * gsmpl) {
-    std::string result = llama_sampler_chain_n(gsmpl->chain_backend) > 0 ? "*logits " : "logits ";
-
-    for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain_backend); i++) {
-        const auto * smpl = llama_sampler_chain_get(gsmpl->chain_backend, i);
-        result += std::string("-> *") + llama_sampler_name(smpl) + " ";
-    }
+    std::string result = "logits ";
 
     for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
         const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
-        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
+        result += std::string("-> ");
+        result += std::string(llama_sampler_name(smpl)) + " ";
     }
 
     return result;
diff --git a/common/sampling.h b/common/sampling.h
index 04b56dbbed..c7101032f2 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -48,7 +48,7 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
 // arguments can be nullptr to skip printing
 void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
 
-struct llama_sampler * common_sampler_chain_backend(const struct common_sampler * gsmpl);
+struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
 
 // extended sampling implementation:
 //
diff --git a/include/llama.h b/include/llama.h
index f6926b6063..e01d06766d 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -369,7 +369,8 @@ extern "C" {
         // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
         // ref: https://github.com/ggml-org/llama.cpp/pull/14363
 
-        // backend sampler chain configuration (does not keep a reference, so make sure the caller keeps the samplers alive)
+        // backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
+        // note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
         struct llama_sampler_seq_config * samplers;
         size_t n_samplers;
     };
@@ -1193,21 +1194,27 @@ extern "C" {
         struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL
         void                   (*free)  (      struct llama_sampler * smpl); // can be NULL if ctx is NULL
 
-        // backend sampling interface
-        void (*backend_init)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft);
+        // backend sampling interface:
+        // return true if the backend supports all ops needed by the sampler
+        // note: call once per sampler
+        bool (*backend_init)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft);
 
+        // call after .backend_apply()
         void (*backend_accept)(
                 struct llama_sampler * smpl,
                 struct ggml_context * ctx,
                 struct ggml_cgraph * gf,
                 struct ggml_tensor * selected_token);
 
+        // call after .backend_init()
         void (*backend_apply)(
                 struct llama_sampler * smpl,
                 struct ggml_context * ctx,
                 struct ggml_cgraph * gf,
                 struct llama_sampler_data * data);
 
+        // call after .backend_apply(), before the graph is executed
         void (*backend_set_input)(struct llama_sampler * smpl);
     };
 
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 78f12011c4..e04e461858 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -68,6 +68,8 @@ llama_context::llama_context(
     for (size_t i = 0; i < params.n_samplers; ++i) {
         const auto & config = params.samplers[i];
 
+        // TODO: assert this is a llama_sampler_chain instance
+
         if (set_sampler(config.seq_id, config.sampler)) {
             const int n_samplers = llama_sampler_chain_n(config.sampler);
 
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 88008e8c45..e910b6e14e 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -504,11 +504,13 @@ static void llama_sampler_empty_free(struct llama_sampler * smpl) {
     delete (llama_sampler_empty *) smpl->ctx;
 }
 
-static void llama_sampler_empty_backend_init(
+static bool llama_sampler_empty_backend_init(
         struct llama_sampler * smpl,
         ggml_backend_buffer_type_t buft) {
     GGML_UNUSED(smpl);
     GGML_UNUSED(buft);
+
+    return true;
 }
 
 static void llama_sampler_empty_backend_accept(
@@ -559,6 +561,43 @@ struct llama_sampler * llama_sampler_init_empty(const char * name) {
     );
 }
 
+// common backend sampler functionality
+//
+// +name : means that the sampler is supported and will run on the backend
+// -name : means that a ggml operator is not supported by the backend
+//
+struct llama_sampler_backend {
+    llama_sampler_backend(const char * name) : name(name), name_ext(name), is_init(false), support(false) {}
+
+    const char * get_name() {
+        if (!is_init) {
+            return name.c_str();
+        }
+
+        if (support) {
+            name_ext = "+" + name;
+        } else {
+            name_ext = "-" + name;
+        }
+
+        return name_ext.c_str();
+    }
+
+    void init(bool support) {
+        GGML_ASSERT(this->is_init == false);
+
+        this->is_init = true;
+        this->support = support;
+    }
+
+private:
+    std::string name;
+    std::string name_ext;
+
+    bool is_init;
+    bool support;
+};
+
 // sampler chain
 
 static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
@@ -570,8 +609,8 @@ static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token
 
     time_meas tm(chain->t_sample_us, chain->params.no_perf);
 
-    for (auto * smpl : chain->samplers) {
-        llama_sampler_accept(smpl, token);
+    for (auto & smpl : chain->samplers) {
+        llama_sampler_accept(smpl.ptr, token);
     }
 
     chain->n_sample++;
@@ -582,20 +621,28 @@ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_d
 
     time_meas tm(chain->t_sample_us, chain->params.no_perf);
 
-    for (auto * smpl : chain->samplers) {
-        if (smpl->iface->apply == nullptr) {
+    bool is_backend = chain->is_init;
+
+    for (auto & smpl : chain->samplers) {
+        if (is_backend && smpl.is_backend) {
             continue;
         }
 
-        llama_sampler_apply(smpl, cur_p);
+        is_backend = false;
+
+        if (smpl.ptr->iface->apply == nullptr) {
+            continue;
+        }
+
+        llama_sampler_apply(smpl.ptr, cur_p);
     }
 }
 
 static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
     auto * chain = (llama_sampler_chain *) smpl->ctx;
 
-    for (auto * smpl : chain->samplers) {
-        llama_sampler_reset(smpl);
+    for (auto & smpl : chain->samplers) {
+        llama_sampler_reset(smpl.ptr);
     }
 }
 
@@ -604,8 +651,8 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl
 
     auto * result = llama_sampler_chain_init(chain_src->params);
 
-    for (auto * smpl : chain_src->samplers) {
-        llama_sampler_chain_add(result, llama_sampler_clone(smpl));
+    for (const auto & smpl : chain_src->samplers) {
+        llama_sampler_chain_add(result, llama_sampler_clone(smpl.ptr));
     }
 
     return result;
@@ -614,23 +661,44 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl
 static void llama_sampler_chain_free(struct llama_sampler * smpl) {
     auto * chain = (llama_sampler_chain *) smpl->ctx;
 
-    for (auto * smpl : chain->samplers) {
-        llama_sampler_free(smpl);
+    for (auto & smpl : chain->samplers) {
+        llama_sampler_free(smpl.ptr);
     }
 
     delete chain;
 }
 
-static void llama_sampler_chain_backend_init(
+static bool llama_sampler_chain_backend_init(
         struct llama_sampler * smpl,
         ggml_backend_buffer_type_t buft) {
     auto * chain = (llama_sampler_chain *) smpl->ctx;
 
-    for (auto * smpl : chain->samplers) {
-        if (smpl->iface->backend_init) {
-            smpl->iface->backend_init(smpl,buft);
+    GGML_ASSERT(chain->is_init == false && "llama_sampler_chain_backend_init() called twice");
+
+    chain->is_init = true;
+
+    bool res = true;
+
+    for (auto & smpl : chain->samplers) {
+        bool res_cur = true;
+
+        // to be able to run a sampler on the backend, it has to:
+        // - have the .backend_init() API implemented
+        // - return true during .backend_init()
+        if (smpl.ptr->iface->backend_init) {
+            if (!smpl.ptr->iface->backend_init(smpl.ptr, buft)) {
+                res_cur = false;
+            }
+        } else {
+            res_cur = false;
         }
+
+        smpl.is_backend = res_cur;
+
+        res = res && res_cur;
     }
+
+    return res;
 }
 
 static void llama_sampler_chain_backend_accept(
@@ -640,9 +708,13 @@ static void llama_sampler_chain_backend_accept(
         struct ggml_tensor * selected_token) {
     auto * chain = (llama_sampler_chain *) smpl->ctx;
 
-    for (auto * smpl : chain->samplers) {
-        if (smpl->iface->backend_accept) {
-            smpl->iface->backend_accept(smpl, ctx, gf, selected_token);
+    for (auto & smpl : chain->samplers) {
+        if (!smpl.is_backend) {
+            break;
+        }
+
+        if (smpl.ptr->iface->backend_accept) {
+            smpl.ptr->iface->backend_accept(smpl.ptr, ctx, gf, selected_token);
         }
     }
 }
@@ -654,9 +726,15 @@ static void llama_sampler_chain_backend_apply(
         struct llama_sampler_data * data) {
     auto * chain = (llama_sampler_chain *) smpl->ctx;
 
-    for (auto * smpl : chain->samplers) {
-        if (smpl->iface->backend_apply) {
-            smpl->iface->backend_apply(smpl, ctx, gf, data);
+    GGML_ASSERT(chain->is_init && "llama_sampler_chain_backend_init() not called");
+
+    for (auto & smpl : chain->samplers) {
+        if (!smpl.is_backend) {
+            break;
+        }
+
+        if (smpl.ptr->iface->backend_apply) {
+            smpl.ptr->iface->backend_apply(smpl.ptr, ctx, gf, data);
         }
     }
 }
@@ -664,9 +742,13 @@ static void llama_sampler_chain_backend_set_input(struct llama_sampler * smpl) {
     auto * chain = (llama_sampler_chain *) smpl->ctx;
 
-    for (auto * smpl : chain->samplers) {
-        if (smpl->iface->backend_set_input) {
-            smpl->iface->backend_set_input(smpl);
+    for (auto & smpl : chain->samplers) {
+        if (!smpl.is_backend) {
+            break;
+        }
+
+        if (smpl.ptr->iface->backend_set_input) {
+            smpl.ptr->iface->backend_set_input(smpl.ptr);
         }
     }
 }
@@ -689,6 +771,7 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
         /* .iface = */ &llama_sampler_chain_i,
         /* .ctx   = */ new llama_sampler_chain {
             /* .params      = */ params,
+            /* .is_init     = */ false,
             /* .samplers    = */ {},
             /* .t_sample_us = */ 0,
             /* .n_sample    = */ 0,
@@ -698,7 +781,10 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
 
 void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
     auto * p = (llama_sampler_chain *) chain->ctx;
-    p->samplers.push_back(smpl);
+    p->samplers.push_back({
+        /* .is_backend = */ false,
+        /* .ptr        = */ smpl,
+    });
 }
 
 struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
@@ -708,7 +794,7 @@ struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chai
 
         return nullptr;
     }
 
-    return p->samplers[i];
+    return p->samplers[i].ptr;
 }
 
 struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
@@ -718,7 +804,7 @@ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain,
         return nullptr;
     }
 
-    auto * result = p->samplers[i];
+    auto * result = p->samplers[i].ptr;
     p->samplers.erase(p->samplers.begin() + i);
 
     return result;
@@ -749,6 +835,15 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
     }
 }
 
+static bool llama_sampler_greedy_backend_init(
+        struct llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
+    GGML_UNUSED(smpl);
+    GGML_UNUSED(buft);
+
+    return true;
+}
+
 static void llama_sampler_greedy_backend_apply(
         struct llama_sampler * smpl,
         struct ggml_context * ctx,
@@ -768,7 +863,7 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
     /* .reset             = */ nullptr,
     /* .clone             = */ nullptr,
     /* .free              = */ nullptr,
-    /* .backend_init      = */ nullptr,
+    /* .backend_init      = */ llama_sampler_greedy_backend_init,
     /* .backend_accept    = */ nullptr,
     /* .backend_apply     = */ llama_sampler_greedy_backend_apply,
     /* .backend_set_input = */ nullptr,
@@ -783,23 +878,22 @@ struct llama_sampler * llama_sampler_init_greedy() {
 
 // dist
 
-struct llama_sampler_dist {
+struct llama_sampler_dist : public llama_sampler_backend {
     const uint32_t seed;
           uint32_t seed_cur;
 
     std::mt19937 rng;
 
-    // Only required for checking operation support and can be removed later.
-    ggml_backend_dev_t device;
-
+    // backend input
     struct ggml_tensor * inp_uniform;
 
-    ggml_context_ptr inp_ctx;
+    ggml_context_ptr        inp_ctx;
     ggml_backend_buffer_ptr inp_buf;
 };
 
-static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
-    return "dist";
+static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_dist *) smpl->ctx;
+    return sctx->get_name();
 }
 
 static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -898,13 +992,61 @@ static void llama_sampler_dist_free(struct llama_sampler * smpl) {
     delete (llama_sampler_dist *) smpl->ctx;
 }
 
-static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
+static bool llama_sampler_dist_backend_init(
+        struct llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_dist *) smpl->ctx;
-    GGML_ASSERT(sctx->inp_uniform != nullptr);
 
-    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
-    const float rnd = dist(sctx->rng);
-    ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
+    bool res = true;
+
+    // determine backend support
+    {
+        ggml_init_params params = {
+            /*.mem_size   =*/ ggml_tensor_overhead()*8,
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ true,
+        };
+
+        ggml_context_ptr ctx_ptr { ggml_init(params) };
+        if (!ctx_ptr) {
+            throw std::runtime_error(format("failed to create ggml context"));
+        }
+
+        ggml_context * ctx = ctx_ptr.get();
+
+        ggml_tensor * probs = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024*1024);
+        ggml_tensor * op    = ggml_cumsum(ctx, probs);
+
+        auto * device = ggml_backend_buft_get_device(buft);
+        GGML_ASSERT(device);
+
+        if (!ggml_backend_dev_supports_op(device, op)) {
+            res = false;
+        }
+
+        sctx->init(res);
+    }
+
+    if (res) {
+        ggml_init_params params = {
+            /*.mem_size   =*/ ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+
+        sctx->inp_ctx.reset(ggml_init(params));
+
+        // Create the uniform random scalar input tensor. This will be set by
+        // llama_sampler_dist_backend_set_input after this graph is built.
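+        // note: the tensor lives in a context/buffer owned by the sampler, so it
+        // persists across per-decode graph rebuilds - only its value is refreshed
+        // by .backend_set_input()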
+        sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
+        ggml_set_name(sctx->inp_uniform, "uniform");
+        ggml_set_input(sctx->inp_uniform);
+
+        // Allocate all tensors from our context to the backend
+        sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
+    }
+
+    return res;
 }
 
 static void llama_sampler_dist_backend_apply(
@@ -919,10 +1061,6 @@ static void llama_sampler_dist_backend_apply(
     ggml_set_name(probs, "dist_probs");
 
     struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
-    if (sctx->device && !ggml_backend_dev_supports_op(sctx->device, cumsum)) {
-        fprintf(stderr, "Warning: backend does not support cumsum operation required for dist sampling\n");
-        fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
-    }
     ggml_set_name(cumsum, "cumsum");
 
     // The uniform tensor has a random value and we subtract this tensor with
@@ -965,29 +1103,13 @@ static void llama_sampler_dist_backend_apply(
     data->sampled = sampled_token;
 }
 
-static void llama_sampler_dist_backend_init(
-        struct llama_sampler * smpl,
-        ggml_backend_buffer_type_t buft) {
+static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
     auto * sctx = (llama_sampler_dist *) smpl->ctx;
+    GGML_ASSERT(sctx->inp_uniform != nullptr);
 
-    sctx->device = ggml_backend_buft_get_device(buft);
-
-    ggml_init_params params = {
-        /*.mem_size   =*/ ggml_tensor_overhead(),
-        /*.mem_buffer =*/ nullptr,
-        /*.no_alloc   =*/ true,
-    };
-
-    sctx->inp_ctx.reset(ggml_init(params));
-
-    // Create the uniform random scalar input tensor. This will be set by
-    // llama_sampler_dist_backend_set_input after this graph is built.
-    sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
-    ggml_set_name(sctx->inp_uniform, "uniform");
-    ggml_set_input(sctx->inp_uniform);
-
-    // Allocate all tensors from our context to the backend
-    sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
+    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
+    const float rnd = dist(sctx->rng);
+    ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
 }
 
 static struct llama_sampler_i llama_sampler_dist_i = {
@@ -1008,10 +1130,10 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_dist_i,
         /* .ctx   = */ new llama_sampler_dist {
+            ("dist"),
             /* .seed        = */ seed,
             /* .seed_cur    = */ seed_cur,
             /* .rng         = */ std::mt19937(seed_cur),
-            /* .device      = */ nullptr,
             /* .inp_uniform = */ nullptr,
             /* .inp_ctx     = */ nullptr,
             /* .inp_buf     = */ nullptr,
@@ -1021,15 +1143,13 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
 
 // top-k
 
-struct llama_sampler_top_k {
+struct llama_sampler_top_k : public llama_sampler_backend {
     const int32_t k;
-
-    // Only required for checking operation support and can be removed later.
-    ggml_backend_dev_t device;
 };
 
-static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
-    return "top-k";
+static const char * llama_sampler_top_k_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_top_k *) smpl->ctx;
+    return sctx->get_name();
 }
 
 static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -1046,11 +1166,42 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
     delete (llama_sampler_top_k *) smpl->ctx;
 }
 
-static void llama_sampler_top_k_backend_init(
+static bool llama_sampler_top_k_backend_init(
         struct llama_sampler * smpl,
         ggml_backend_buffer_type_t buft) {
-    auto * ctx_data = (llama_sampler_top_k *) smpl->ctx;
-    ctx_data->device = ggml_backend_buft_get_device(buft);
+    auto * sctx = (llama_sampler_top_k *) smpl->ctx;
+
+    bool res = true;
+
+    // determine backend support
+    {
+        ggml_init_params params = {
+            /*.mem_size   =*/ ggml_tensor_overhead()*8,
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ true,
+        };
+
+        ggml_context_ptr ctx_ptr { ggml_init(params) };
+        if (!ctx_ptr) {
+            throw std::runtime_error(format("failed to create ggml context"));
+        }
+
+        ggml_context * ctx = ctx_ptr.get();
+
+        ggml_tensor * logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024*1024);
+        ggml_tensor * op     = ggml_top_k(ctx, logits, sctx->k);
+
+        auto * device = ggml_backend_buft_get_device(buft);
+        GGML_ASSERT(device);
+
+        if (!ggml_backend_dev_supports_op(device, op)) {
+            res = false;
+        }
+
+        sctx->init(res);
+    }
+
+    return res;
 }
 
 static void llama_sampler_top_k_backend_apply(
@@ -1058,26 +1209,17 @@ static void llama_sampler_top_k_backend_apply(
         struct ggml_context * ctx,
         struct ggml_cgraph * gf,
         struct llama_sampler_data * data) {
+    auto * sctx = (llama_sampler_top_k *) smpl->ctx;
 
-    auto * ctx_data = (llama_sampler_top_k *) smpl->ctx;
-
-    struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, ctx_data->k);
+    struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, sctx->k);
     ggml_set_name(top_k, "top_k");
 
-    // top_k is a view of argsort - check if backend supports the underlying argsort operation
-    // by checking the source tensor (which is the argsort result)
-    if (ctx_data->device && top_k->src[0] && !ggml_backend_dev_supports_op(ctx_data->device, top_k->src[0])) {
-        fprintf(stderr, "Warning: backend does not support argsort operation required for top-k sampling\n");
-        fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
-    }
-
-    data->candidates = top_k;
-
     struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
     struct ggml_tensor * top_k_rows  = ggml_get_rows(ctx, logits_rows, top_k);
     ggml_set_name(top_k_rows, "top_k_rows");
 
-    data->logits = ggml_reshape_1d(ctx, top_k_rows, ctx_data->k);
+    data->candidates = top_k;
+    data->logits     = ggml_reshape_1d(ctx, top_k_rows, sctx->k);
 
     GGML_UNUSED(gf);
 }
@@ -1099,29 +1241,30 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
     const bool is_empty = (k <= 0);
 
     if (is_empty) {
-        return llama_sampler_init_empty("top-k?");
+        return llama_sampler_init_empty("?top-k");
     }
 
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_top_k_i,
         /* .ctx   = */ new llama_sampler_top_k {
-            /* .k      = */ k,
-            /* .device = */ nullptr,
+            ("top-k"),
+            /* .k = */ k,
         }
     );
 }
 
 // top-p
 
-struct llama_sampler_top_p {
+struct llama_sampler_top_p : public llama_sampler_backend {
     const float  p;
     const size_t min_keep;
 
     std::vector<llama_token_data> buf_sort;
 };
 
-static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
-    return "top-p";
+static const char * llama_sampler_top_p_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_top_p *) smpl->ctx;
+    return sctx->get_name();
 }
 
 static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -1188,11 +1331,15 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
     delete (llama_sampler_top_p *) smpl->ctx;
 }
 
-static void llama_sampler_top_p_backend_init(
+static bool llama_sampler_top_p_backend_init(
         struct llama_sampler * smpl,
         ggml_backend_buffer_type_t buft) {
-    GGML_UNUSED(smpl);
     GGML_UNUSED(buft);
+
+    auto * sctx = (llama_sampler_top_p *) smpl->ctx;
+    sctx->init(true);
+
+    return true;
 }
 
 static void llama_sampler_top_p_backend_apply(
@@ -1287,12 +1434,13 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
     const bool is_empty = p >= 1.0f;
 
     if (is_empty) {
-        return llama_sampler_init_empty("top-p?");
+        return llama_sampler_init_empty("?top-p");
     }
 
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_top_p_i,
         /* .ctx   = */ new llama_sampler_top_p {
+            ("top-p"),
            /* .p        = */ p,
            /* .min_keep = */ min_keep,
            /* .buf_sort = */ {},
@@ -1302,13 +1450,14 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
 
 // min-p
 
-struct llama_sampler_min_p {
+struct llama_sampler_min_p : public llama_sampler_backend {
     const float  p;
     const size_t min_keep;
 };
 
-static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
-    return "min-p";
+static const char * llama_sampler_min_p_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_min_p *) smpl->ctx;
+    return sctx->get_name();
 }
 
 static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -1374,11 +1523,16 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
     delete (llama_sampler_min_p *) smpl->ctx;
 }
 
-static void llama_sampler_min_p_backend_init(
+static bool llama_sampler_min_p_backend_init(
         struct llama_sampler * smpl,
         ggml_backend_buffer_type_t buft) {
-    GGML_UNUSED(smpl);
     GGML_UNUSED(buft);
+
+    auto * sctx = (llama_sampler_min_p *) smpl->ctx;
+
+    sctx->init(true);
+
+    return true;
 }
 
 static void llama_sampler_min_p_backend_apply(
@@ -1441,12 +1595,13 @@ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
     const bool is_empty = (p <= 0.0f);
 
     if (is_empty) {
-        return llama_sampler_init_empty("min-p?");
+        return llama_sampler_init_empty("?min-p");
     }
 
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_min_p_i,
         /* .ctx   = */ new llama_sampler_min_p {
+            ("min-p"),
             /* .p        = */ p,
             /* .min_keep = */ min_keep,
         }
@@ -1550,7 +1705,7 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
     const bool is_empty = (p >= 1.0f);
 
     if (is_empty) {
-        return llama_sampler_init_empty("typical?");
+        return llama_sampler_init_empty("?typical");
     }
 
     return llama_sampler_init(
@@ -1564,12 +1719,13 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
 
 // temp
 
-struct llama_sampler_temp {
+struct llama_sampler_temp : public llama_sampler_backend {
     const float temp;
 };
 
-static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
-    return "temp";
+static const char * llama_sampler_temp_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_temp *) smpl->ctx;
+    return sctx->get_name();
 }
 
 static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -1616,13 +1772,25 @@ static void llama_sampler_backend_temp_sampling(
     GGML_UNUSED(gf);
 }
 
+static bool llama_sampler_temp_backend_init(
+        struct llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
+    GGML_UNUSED(buft);
+
+    auto * sctx = (llama_sampler_temp *) smpl->ctx;
+
+    sctx->init(true);
+
+    return true;
+}
+
 static void llama_sampler_temp_backend_apply(
         struct llama_sampler * smpl,
         struct ggml_context * ctx,
         struct ggml_cgraph * gf,
         struct llama_sampler_data * data) {
-    auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
-    llama_sampler_backend_temp_sampling(ctx, gf, data, ctx_data->temp);
+    auto * sctx = (llama_sampler_temp *) smpl->ctx;
+    llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
 }
 
 static struct llama_sampler_i llama_sampler_temp_i = {
@@ -1632,7 +1800,7 @@ static struct llama_sampler_i llama_sampler_temp_i = {
     /* .reset             = */ nullptr,
     /* .clone             = */ llama_sampler_temp_clone,
     /* .free              = */ llama_sampler_temp_free,
-    /* .backend_init      = */ nullptr,
+    /* .backend_init      = */ llama_sampler_temp_backend_init,
     /* .backend_accept    = */ nullptr,
     /* .backend_apply     = */ llama_sampler_temp_backend_apply,
     /* .backend_set_input = */ nullptr,
@@ -1642,12 +1810,13 @@ struct llama_sampler * llama_sampler_init_temp(float temp) {
     const bool is_empty = temp == 1.0f;
 
     if (is_empty) {
-        return llama_sampler_init_empty("temp?");
+        return llama_sampler_init_empty("?temp");
    }
 
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_temp_i,
         /* .ctx   = */ new llama_sampler_temp {
+            ("temp"),
             /*.temp = */ temp,
         }
     );
@@ -1655,14 +1824,15 @@ struct llama_sampler * llama_sampler_init_temp(float temp) {
 
 // temp-ext
 
-struct llama_sampler_temp_ext {
+struct llama_sampler_temp_ext : public llama_sampler_backend {
     const float temp;
     const float delta;
     const float exponent;
 };
 
-static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
-    return "temp-ext";
+static const char * llama_sampler_temp_ext_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
+    return sctx->get_name();
 }
 
 static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -1745,22 +1915,34 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
     delete (llama_sampler_temp_ext *) smpl->ctx;
 }
 
+static bool llama_sampler_temp_ext_backend_init(
+        struct llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
+    GGML_UNUSED(buft);
+
+    auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
+
+    sctx->init(true);
+
+    return true;
+}
+
 static void llama_sampler_temp_ext_backend_apply(
         struct llama_sampler * smpl,
         struct ggml_context * ctx,
         struct ggml_cgraph * gf,
         struct llama_sampler_data * data) {
-    auto * ctx_data = (llama_sampler_temp_ext *) smpl->ctx;
+    auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
 
     // Revert to standard temperature scaling if delta or temp are non-positive.
-    if (ctx_data->delta <= 0.0f || ctx_data->temp <= 0.0f) {
-        llama_sampler_backend_temp_sampling(ctx, gf, data, ctx_data->temp);
+    if (sctx->delta <= 0.0f || sctx->temp <= 0.0f) {
+        llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
         return;
     }
 
     // Calculate min_temp, max_temp, and max_entropy.
-    const float min_temp = std::max(0.0f, ctx_data->temp - ctx_data->delta);
-    const float max_temp = ctx_data->temp + ctx_data->delta;
+    const float min_temp = std::max(0.0f, sctx->temp - sctx->delta);
+    const float max_temp = sctx->temp + sctx->delta;
     const float max_entropy = logf(data->logits->ne[0]);
 
     // Calculate the probabilities.
@@ -1791,7 +1973,7 @@ static void llama_sampler_temp_ext_backend_apply(
     // Calculate powf(normalized_entropy, exponent) as
     // norm_entropy^exponent = exp(exponent * log(norm_entropy))
     struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy);
-    struct ggml_tensor * scaled_log = ggml_scale(ctx, log_norm_entropy, ctx_data->exponent);
+    struct ggml_tensor * scaled_log = ggml_scale(ctx, log_norm_entropy, sctx->exponent);
     struct ggml_tensor * pow_entropy = ggml_exp(ctx, scaled_log);
     // With pow_entropy computed we can now compute dyn_temp, scaling by
     // (max_temp - min_temp) and then adding min_temp.
@@ -1815,7 +1997,7 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
     /* .reset             = */ nullptr,
     /* .clone             = */ llama_sampler_temp_ext_clone,
     /* .free              = */ llama_sampler_temp_ext_free,
-    /* .backend_init      = */ nullptr,
+    /* .backend_init      = */ llama_sampler_temp_ext_backend_init,
     /* .backend_accept    = */ nullptr,
     /* .backend_apply     = */ llama_sampler_temp_ext_backend_apply,
     /* .backend_set_input = */ nullptr,
@@ -1825,12 +2007,13 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
     const bool is_empty = temp == 1.0f && delta <= 0.0f;
 
     if (is_empty) {
-        return llama_sampler_init_empty("temp-ext?");
+        return llama_sampler_init_empty("?temp-ext");
     }
 
     auto * res = llama_sampler_init(
         /* .iface = */ &llama_sampler_temp_ext_i,
         /* .ctx   = */ new llama_sampler_temp_ext {
+            ("temp-ext"),
             /* .temp     = */ temp,
             /* .delta    = */ delta,
             /* .exponent = */ exponent,
@@ -1931,7 +2114,7 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
     const bool is_empty = (p <= 0.0f || t > 0.5f);
 
     if (is_empty) {
-        return llama_sampler_init_empty("xtc?");
+        return llama_sampler_init_empty("?xtc");
     }
 
     const auto seed_cur = get_rng_seed(seed);
@@ -2492,7 +2675,7 @@ struct llama_sampler * llama_sampler_init_penalties(
     const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
 
     if (is_empty) {
-        return llama_sampler_init_empty("penalties?");
+        return llama_sampler_init_empty("?penalties");
     }
 
     return llama_sampler_init(
@@ -2585,7 +2768,7 @@ struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
     const bool is_empty = (n <= 0.0f);
 
     if (is_empty) {
-        return llama_sampler_init_empty("top-n-sigma?");
+        return llama_sampler_init_empty("?top-n-sigma");
     }
 
     return llama_sampler_init(
@@ -2930,7 +3113,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
     const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
 
     if (!dry_enabled) {
-        return llama_sampler_init_empty("dry?");
+        return llama_sampler_init_empty("?dry");
     }
 
     if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
@@ -3003,7 +3186,7 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
 
 // logit-bias
 
-struct llama_sampler_logit_bias {
+struct llama_sampler_logit_bias : public llama_sampler_backend {
     const int32_t n_vocab;
 
     const std::vector<llama_logit_bias> logit_bias;
 
@@ -3016,8 +3199,9 @@ struct llama_sampler_logit_bias {
     ggml_backend_buffer_ptr inp_buf;
 };
 
-static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
-    return "logit-bias";
+static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
+    return ctx->get_name();
 }
 
 static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -3097,13 +3281,15 @@ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * sm
     ggml_backend_tensor_set(sctx->inp_logit_bias, logit_bias_sparse.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
 }
 
-static void llama_sampler_logit_bias_backend_init(
+static bool llama_sampler_logit_bias_backend_init(
         struct llama_sampler * smpl,
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
 
+    sctx->init(true);
+
     if (sctx->logit_bias.empty()) {
-        return;
+        return true;
     }
 
     ggml_init_params params = {
@@ -3120,6 +3306,8 @@ static void llama_sampler_logit_bias_backend_init(
 
     // Allocate all tensors from our context to the backend
     sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
+
+    return true;
 }
 
 static struct llama_sampler_i llama_sampler_logit_bias_i = {
@@ -3142,12 +3330,13 @@ struct llama_sampler * llama_sampler_init_logit_bias(
     const bool is_empty = n_logit_bias <= 0;
 
     if (is_empty) {
-        return llama_sampler_init_empty("logit-bias?");
+        return llama_sampler_init_empty("?logit-bias");
     }
 
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_logit_bias_i,
         /* .ctx   = */ new llama_sampler_logit_bias {
+            ("logit-bias"),
             /* .n_vocab    = */ n_vocab,
             /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
             /* .to_search  = */ {},
@@ -3407,7 +3596,7 @@ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
     if (smpl->iface == &llama_sampler_chain_i) {
         const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
         for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
-            const uint32_t seed = llama_sampler_get_seed(*it);
+            const uint32_t seed = llama_sampler_get_seed(it->ptr);
             if (seed != LLAMA_DEFAULT_SEED) {
                 return seed;
             }
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index 80ea22ac35..18cae29ece 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -14,7 +14,16 @@ struct llama_grammar;
 
 struct llama_sampler_chain {
     llama_sampler_chain_params params;
 
-    std::vector<llama_sampler *> samplers;
+    // has .backend_init() been called?
+    bool is_init = false;
+
+    struct info {
+        bool is_backend;
+
+        llama_sampler * ptr;
+    };
+
+    std::vector<info> samplers;
 
     // timing
 
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 0387ca991f..3919401b44 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -1028,10 +1028,9 @@ struct server_context_impl {
             return false;
         }
 
-        SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
+        llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get()));
 
-        llama_sampler * backend_chain = common_sampler_chain_backend(slot.smpl.get());
-        llama_set_sampler(ctx, slot.id, backend_chain);
+        SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
     }
 
     // initialize draft batch
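
Usage note (not part of the diff): a minimal sketch of how a sampler chain is
wired into a context under the new init flow, based on the APIs touched above.
The sampler parameters are illustrative; model loading and error handling are
omitted.

    // per-sequence sampler chain - the caller keeps it alive (see llama.h note)
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    llama_sampler_seq_config seq_cfg = { /* .seq_id = */ 0, /* .sampler = */ chain };

    llama_context_params cparams = llama_context_default_params();
    cparams.samplers   = &seq_cfg;
    cparams.n_samplers = 1;

    // context initialization is expected to run .backend_init() once per sampler
    // (llama_sampler_chain_backend_init above); samplers that report support run
    // on the backend, the rest fall back to the CPU path in llama_sampler_chain_apply()
    llama_context * lctx = llama_init_from_model(model, cparams);

After initialization, common_sampler_print() reflects the probe result per
sampler: "+name" runs on the backend, while "-name" means a required ggml op is
unsupported and the sampler runs on the CPU.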