From ff7b0bf6320db65df659a13784844d0dd380c6b6 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 29 Nov 2025 23:09:53 +0200
Subject: [PATCH] llama : call backend_init once

---
 include/llama.h       |  2 +-
 src/llama-context.cpp | 47 +++++++++++++++++++++++++++++--------------
 src/llama-context.h   |  2 +-
 src/llama-graph.cpp   | 10 +--------
 src/llama-graph.h     |  3 ---
 5 files changed, 35 insertions(+), 29 deletions(-)

diff --git a/include/llama.h b/include/llama.h
index 24cd5be4a5..57fe4bd127 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1217,7 +1217,7 @@ extern "C" {
         llama_sampler_context_t ctx;
     };
 
-    LLAMA_API void llama_set_backend_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
+    LLAMA_API bool llama_set_backend_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
 
     // mirror of llama_sampler_i:
     LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 2b4dc58a43..79d1d633d1 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -68,14 +68,11 @@ llama_context::llama_context(
         for (size_t i = 0; i < params.n_samplers; ++i) {
             const auto & config = params.samplers[i];
 
-            const int n_samplers = llama_sampler_chain_n(config.sampler);
-            if (n_samplers <= 0) {
-                continue;
+            if (set_backend_sampler(config.seq_id, config.sampler)) {
+                const int n_samplers = llama_sampler_chain_n(config.sampler);
+
+                LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers);
             }
-
-            sampling.samplers[config.seq_id] = config.sampler;
-
-            LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers);
         }
     }
 
@@ -912,14 +909,35 @@ void llama_context::set_warmup(bool value) {
     cparams.warmup = value;
 }
 
-void llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
+bool llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
     LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
 
-    if (sampler != nullptr && llama_sampler_chain_n(sampler) > 0) {
+    const bool can_offload =
+        sampler &&
+        sampler->iface->backend_init &&
+        sampler->iface->backend_apply &&
+        llama_sampler_chain_n(sampler) > 0;
+
+    if (sampler && can_offload) {
+        ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output());
+        sampler->iface->backend_init(sampler, buft);
+
         sampling.samplers[seq_id] = sampler;
-    } else {
-        sampling.samplers.erase(seq_id);
+
+        return true;
     }
+
+    if (sampler && !can_offload) {
+        LLAMA_LOG_WARN("%s: sampler '%s' cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler));
+
+        sampling.samplers.erase(seq_id);
+
+        return false;
+    }
+
+    sampling.samplers.erase(seq_id);
+
+    return true;
 }
 
 void llama_context::set_adapter_lora(
@@ -1910,7 +1928,7 @@ llm_graph_params llama_context::graph_params(
                   llm_graph_result * res,
               const llama_ubatch & ubatch,
       const llama_memory_context_i * mctx,
-                    llm_graph_type gtype) const {
+                    llm_graph_type   gtype) const {
     return {
         /*.arch        =*/ model.arch,
         /*.hparams     =*/ model.hparams,
@@ -1919,7 +1937,6 @@ llm_graph_params llama_context::graph_params(
         /*.gtype       =*/ gtype,
         /*.sched       =*/ sched.get(),
         /*.backend_cpu =*/ backend_cpu,
-        /*.dev_out     =*/ model.dev_output(),
         /*.cvec        =*/ &cvec,
         /*.loras       =*/ &loras,
         /*.mctx        =*/ mctx,
@@ -2980,8 +2997,8 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
     return ctx->get_embeddings_seq(seq_id);
 }
 
-void llama_set_backend_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
-    ctx->set_backend_sampler(seq_id, smpl);
+bool llama_set_backend_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
+    return ctx->set_backend_sampler(seq_id, smpl);
 }
 
 llama_token llama_get_backend_sampled_token_ith(llama_context * ctx, int32_t i) {
diff --git a/src/llama-context.h b/src/llama-context.h
index 2940d337a8..9b568d98b8 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -221,7 +221,7 @@ public:
     // reserve a graph with a dummy ubatch of the specified size
     ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
 
-    void set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler);
+    bool set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler);
 
 private:
     llm_graph_params graph_params(
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 3d88dcd296..f396feeded 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -609,7 +609,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     rope_type        (hparams.rope_type),
     sched            (params.sched),
    backend_cpu      (params.backend_cpu),
-    dev_out          (params.dev_out),
     cvec             (params.cvec),
     loras            (params.loras),
     mctx             (params.mctx),
@@ -2075,8 +2074,6 @@ void llm_graph_context::build_sampling() const {
 
     const int64_t n_vocab = logits_t->ne[0];
 
-    ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(dev_out);
-
     std::unordered_map<llama_seq_id, llama_sampler *> active_samplers;
 
     for (const auto & [seq_id, sampler] : samplers) {
@@ -2085,13 +2082,8 @@ void llm_graph_context::build_sampling() const {
         if (it == seq_to_logit_row.end()) {
             continue;
         }
-        const int32_t row_idx = it->second;
 
-        // Allow GPU sampler to create input tensors by implementing init_ggml.
-        // TODO: this should not be done here
-        if (sampler->iface->backend_init != nullptr) {
-            sampler->iface->backend_init(sampler, buft);
-        }
+        const int32_t row_idx = it->second;
 
         active_samplers[seq_id] = sampler;
 
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 6797d78a20..9090eca028 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -428,7 +428,6 @@ struct llm_graph_params {
 
     ggml_backend_sched_t sched;
     ggml_backend_t backend_cpu;
-    ggml_backend_dev_t dev_out;
 
     const llama_adapter_cvec * cvec;
     const llama_adapter_loras * loras;
@@ -617,8 +616,6 @@ struct llm_graph_context {
 
     ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
-    ggml_backend_dev_t dev_out;
-
     const llama_adapter_cvec * cvec;
     const llama_adapter_loras * loras;
     const llama_memory_context_i * mctx;
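
Note (not part of the patch): with this change, llama_set_backend_sampler() reports whether the sampler chain was actually offloaded, and backend_init is invoked once when the sampler is set on the context rather than on every build_sampling() call. A minimal caller-side sketch of the intended usage follows, assuming a llama_context and a llama_sampler chain were created beforehand with the usual llama.cpp APIs; the helper name try_enable_backend_sampling is illustrative, not part of the API.

    // illustrative helper, not part of llama.cpp: attach a sampler chain to
    // sequence 0 for backend sampling and report whether it was offloaded
    #include "llama.h"

    #include <cstdio>

    static bool try_enable_backend_sampling(llama_context * ctx, llama_sampler * smpl) {
        // returns false when the chain cannot be offloaded, e.g. when the
        // sampler does not implement backend_init/backend_apply or is empty
        const bool offloaded = llama_set_backend_sampler(ctx, /*seq_id =*/ 0, smpl);

        if (!offloaded) {
            // the context did not keep the sampler - keep sampling on the CPU
            fprintf(stderr, "backend sampling unavailable, falling back to CPU sampling\n");
        }

        return offloaded;
    }

Passing a nullptr sampler still clears any previously set backend sampler for that sequence and returns true.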