llama : call backend_init once

Georgi Gerganov 2025-11-29 23:09:53 +02:00
parent d8d98bb4bb
commit ff7b0bf632
GPG Key ID: 449E073F9DC10735
5 changed files with 35 additions and 29 deletions
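
With this change, llama_set_backend_sampler() returns a bool indicating whether the sampler chain could actually be attached to the backend. Below is a minimal caller-side sketch of checking that result; the helper name, the greedy chain, and the fallback message are illustrative only and not part of this commit.

#include "llama.h"

#include <cstdio>

// Attach a per-sequence sampler chain and report whether it was offloaded.
// Lifetime management of the chain is intentionally left out of this sketch.
static bool try_attach_backend_sampler(llama_context * ctx, llama_seq_id seq_id) {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_greedy());

    if (!llama_set_backend_sampler(ctx, seq_id, chain)) {
        // the chain was rejected (e.g. missing backend_init/backend_apply support)
        // and was not stored, so this sequence keeps sampling on the CPU
        fprintf(stderr, "backend sampler rejected for seq %d\n", (int) seq_id);
        return false;
    }

    return true;
}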


@@ -1217,7 +1217,7 @@ extern "C" {
         llama_sampler_context_t ctx;
     };
 
-    LLAMA_API void llama_set_backend_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
+    LLAMA_API bool llama_set_backend_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
 
     // mirror of llama_sampler_i:
     LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);


@@ -68,14 +68,11 @@ llama_context::llama_context(
     for (size_t i = 0; i < params.n_samplers; ++i) {
         const auto & config = params.samplers[i];
 
-        const int n_samplers = llama_sampler_chain_n(config.sampler);
-        if (n_samplers <= 0) {
-            continue;
+        if (set_backend_sampler(config.seq_id, config.sampler)) {
+            const int n_samplers = llama_sampler_chain_n(config.sampler);
+            LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers);
         }
-
-        sampling.samplers[config.seq_id] = config.sampler;
-
-        LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers);
     }
 }
@@ -912,14 +909,35 @@ void llama_context::set_warmup(bool value) {
     cparams.warmup = value;
 }
 
-void llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
+bool llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
     LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
 
-    if (sampler != nullptr && llama_sampler_chain_n(sampler) > 0) {
+    const bool can_offload =
+        sampler &&
+        sampler->iface->backend_init &&
+        sampler->iface->backend_apply &&
+        llama_sampler_chain_n(sampler) > 0;
+
+    if (sampler && can_offload) {
+        ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output());
+        sampler->iface->backend_init(sampler, buft);
+
         sampling.samplers[seq_id] = sampler;
-    } else {
-        sampling.samplers.erase(seq_id);
+
+        return true;
     }
+
+    if (sampler && !can_offload) {
+        LLAMA_LOG_WARN("%s: sampler '%s' cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler));
+
+        sampling.samplers.erase(seq_id);
+
+        return false;
+    }
+
+    sampling.samplers.erase(seq_id);
+
+    return true;
 }
 
 void llama_context::set_adapter_lora(
@@ -1910,7 +1928,7 @@ llm_graph_params llama_context::graph_params(
             llm_graph_result * res,
             const llama_ubatch & ubatch,
             const llama_memory_context_i * mctx,
             llm_graph_type gtype) const {
     return {
         /*.arch =*/ model.arch,
         /*.hparams =*/ model.hparams,
@@ -1919,7 +1937,6 @@ llm_graph_params llama_context::graph_params(
         /*.gtype =*/ gtype,
         /*.sched =*/ sched.get(),
         /*.backend_cpu =*/ backend_cpu,
-        /*.dev_out =*/ model.dev_output(),
         /*.cvec =*/ &cvec,
         /*.loras =*/ &loras,
         /*.mctx =*/ mctx,
@@ -2980,8 +2997,8 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
     return ctx->get_embeddings_seq(seq_id);
 }
 
-void llama_set_backend_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
-    ctx->set_backend_sampler(seq_id, smpl);
+bool llama_set_backend_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
+    return ctx->set_backend_sampler(seq_id, smpl);
 }
 
 llama_token llama_get_backend_sampled_token_ith(llama_context * ctx, int32_t i) {


@@ -221,7 +221,7 @@ public:
     // reserve a graph with a dummy ubatch of the specified size
     ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
 
-    void set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler);
+    bool set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler);
 
 private:
     llm_graph_params graph_params(


@@ -609,7 +609,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     rope_type (hparams.rope_type),
     sched (params.sched),
     backend_cpu (params.backend_cpu),
-    dev_out (params.dev_out),
     cvec (params.cvec),
     loras (params.loras),
     mctx (params.mctx),
@@ -2075,8 +2074,6 @@ void llm_graph_context::build_sampling() const {
     const int64_t n_vocab = logits_t->ne[0];
 
-    ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(dev_out);
-
     std::unordered_map<llama_seq_id, llama_sampler*> active_samplers;
 
     for (const auto & [seq_id, sampler] : samplers) {
@@ -2085,13 +2082,8 @@
         if (it == seq_to_logit_row.end()) {
             continue;
         }
-        const int32_t row_idx = it->second;
 
-        // Allow GPU sampler to create input tensors by implementing init_ggml.
-        // TODO: this should not be done here
-        if (sampler->iface->backend_init != nullptr) {
-            sampler->iface->backend_init(sampler, buft);
-        }
+        const int32_t row_idx = it->second;
 
         active_samplers[seq_id] = sampler;


@@ -428,7 +428,6 @@ struct llm_graph_params {
     ggml_backend_sched_t sched;
     ggml_backend_t backend_cpu;
-    ggml_backend_dev_t dev_out;
 
     const llama_adapter_cvec * cvec;
     const llama_adapter_loras * loras;
@@ -617,8 +616,6 @@ struct llm_graph_context {
     ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
-    ggml_backend_dev_t dev_out;
-
     const llama_adapter_cvec * cvec;
     const llama_adapter_loras * loras;
 
     const llama_memory_context_i * mctx;
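
Since backend_init() now runs once at registration time, inside set_backend_sampler(), rather than on every graph build, clearing a previously registered sampler goes through the same entry point. A small sketch of the detach path; the helper name is illustrative and the call site is assumed to exist elsewhere.

#include "llama.h"

#include <cassert>

// Detach any backend sampler previously registered for seq_id.
// Clearing always succeeds with this commit, so the call is expected to return true.
static void clear_backend_sampler(llama_context * ctx, llama_seq_id seq_id) {
    const bool cleared = llama_set_backend_sampler(ctx, seq_id, /*smpl =*/ nullptr);
    assert(cleared);
    (void) cleared;
}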