llama : call backend_init once

parent d8d98bb4bb
commit ff7b0bf632
@@ -1217,7 +1217,7 @@ extern "C" {
         llama_sampler_context_t ctx;
     };
 
-    LLAMA_API void llama_set_backend_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
+    LLAMA_API bool llama_set_backend_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
 
     // mirror of llama_sampler_i:
     LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
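With the return type changed from void to bool, a caller can now tell whether the sampler chain was actually accepted for backend sampling. A minimal usage sketch (hypothetical names: assumes an initialized llama_context * lctx, a configured llama_sampler * chain, and <cstdio> included for the fallback message):

    // enable backend sampling for sequence 0; the call returns false when the
    // chain is empty or lacks the backend hooks, so keep CPU sampling instead
    if (!llama_set_backend_sampler(lctx, /*seq_id=*/ 0, chain)) {
        fprintf(stderr, "backend sampling unavailable, using CPU sampling\n");
    }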
@@ -68,14 +68,11 @@ llama_context::llama_context(
     for (size_t i = 0; i < params.n_samplers; ++i) {
         const auto & config = params.samplers[i];
 
-        const int n_samplers = llama_sampler_chain_n(config.sampler);
-        if (n_samplers <= 0) {
-            continue;
+        if (set_backend_sampler(config.seq_id, config.sampler)) {
+            const int n_samplers = llama_sampler_chain_n(config.sampler);
+
+            LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers);
         }
-
-        sampling.samplers[config.seq_id] = config.sampler;
-
-        LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers);
     }
 }
 
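The constructor now stores a sampler only when set_backend_sampler() reports success, instead of registering it unconditionally. The same pattern generalizes to multiple sequences; a hedged sketch built only on the public API above (register_backend_samplers is a hypothetical helper):

    // register one sampler chain per sequence, returning how many chains
    // were accepted for backend sampling
    static int register_backend_samplers(llama_context * lctx, llama_sampler ** chains, int n_seq) {
        int n_ok = 0;
        for (int s = 0; s < n_seq; ++s) {
            if (llama_set_backend_sampler(lctx, (llama_seq_id) s, chains[s])) {
                n_ok++;
            }
        }
        return n_ok;
    }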
@@ -912,14 +909,35 @@ void llama_context::set_warmup(bool value) {
     cparams.warmup = value;
 }
 
-void llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
+bool llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
     LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
 
-    if (sampler != nullptr && llama_sampler_chain_n(sampler) > 0) {
+    const bool can_offload =
+        sampler &&
+        sampler->iface->backend_init &&
+        sampler->iface->backend_apply &&
+        llama_sampler_chain_n(sampler) > 0;
+
+    if (sampler && can_offload) {
+        ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output());
+        sampler->iface->backend_init(sampler, buft);
+
         sampling.samplers[seq_id] = sampler;
-    } else {
-        sampling.samplers.erase(seq_id);
+
+        return true;
     }
+
+    if (sampler && !can_offload) {
+        LLAMA_LOG_WARN("%s: sampler '%s' cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler));
+
+        sampling.samplers.erase(seq_id);
+
+        return false;
+    }
+
+    sampling.samplers.erase(seq_id);
+
+    return true;
 }
 
 void llama_context::set_adapter_lora(
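The eligibility test requires both backend callbacks and a non-empty chain; factored out on its own it might read as below. This is a sketch mirroring the can_offload condition above; the backend_init/backend_apply members of llama_sampler_i are specific to this branch:

    // a sampler qualifies for offload only if it provides both backend hooks
    // and its chain contains at least one sampler
    static bool sampler_can_offload(const llama_sampler * smpl) {
        return smpl != nullptr
            && smpl->iface->backend_init  != nullptr
            && smpl->iface->backend_apply != nullptr
            && llama_sampler_chain_n(smpl) > 0;
    }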
@@ -1910,7 +1928,7 @@ llm_graph_params llama_context::graph_params(
         llm_graph_result * res,
         const llama_ubatch & ubatch,
         const llama_memory_context_i * mctx,
-        llm_graph_type gtype) const {
+        llm_graph_type gtype) const {
     return {
         /*.arch        =*/ model.arch,
         /*.hparams     =*/ model.hparams,

@@ -1919,7 +1937,6 @@ llm_graph_params llama_context::graph_params(
         /*.gtype       =*/ gtype,
         /*.sched       =*/ sched.get(),
         /*.backend_cpu =*/ backend_cpu,
-        /*.dev_out     =*/ model.dev_output(),
         /*.cvec        =*/ &cvec,
         /*.loras       =*/ &loras,
         /*.mctx        =*/ mctx,
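Because backend_init now runs once inside set_backend_sampler, the output device no longer has to be threaded through the graph parameters: the buffer type is resolved at registration time instead. The two relevant lines, as they appear in the set_backend_sampler hunk above:

    ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output());
    sampler->iface->backend_init(sampler, buft);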
@@ -2980,8 +2997,8 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
     return ctx->get_embeddings_seq(seq_id);
 }
 
-void llama_set_backend_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
-    ctx->set_backend_sampler(seq_id, smpl);
+bool llama_set_backend_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
+    return ctx->set_backend_sampler(seq_id, smpl);
 }
 
 llama_token llama_get_backend_sampled_token_ith(llama_context * ctx, int32_t i) {
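The trailing context also shows llama_get_backend_sampled_token_ith, which reads back a token sampled on the backend. A hedged sketch (assumes lctx as above, a successful llama_decode call, and that output row 0 belongs to a sequence with a registered backend sampler):

    // fetch the backend-sampled token for output row 0
    llama_token tok = llama_get_backend_sampled_token_ith(lctx, 0);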
@@ -221,7 +221,7 @@ public:
     // reserve a graph with a dummy ubatch of the specified size
     ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
 
-    void set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler);
+    bool set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler);
 
 private:
     llm_graph_params graph_params(
@@ -609,7 +609,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     rope_type   (hparams.rope_type),
     sched       (params.sched),
     backend_cpu (params.backend_cpu),
-    dev_out     (params.dev_out),
     cvec        (params.cvec),
     loras       (params.loras),
     mctx        (params.mctx),
@@ -2075,8 +2074,6 @@ void llm_graph_context::build_sampling() const {
 
     const int64_t n_vocab = logits_t->ne[0];
 
-    ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(dev_out);
-
     std::unordered_map<llama_seq_id, llama_sampler*> active_samplers;
 
     for (const auto & [seq_id, sampler] : samplers) {

@@ -2085,13 +2082,8 @@ void llm_graph_context::build_sampling() const {
         if (it == seq_to_logit_row.end()) {
             continue;
         }
-        const int32_t row_idx = it->second;
-
-        // Allow GPU sampler to create input tensors by implementing init_ggml.
-        // TODO: this should not be done here
-        if (sampler->iface->backend_init != nullptr) {
-            sampler->iface->backend_init(sampler, buft);
-        }
+        const int32_t row_idx = it->second;
 
         active_samplers[seq_id] = sampler;
 
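build_sampling() previously re-ran backend_init on every graph build; with initialization moved to set_backend_sampler, any sampler found in the samplers map can be assumed initialized. A defensive check one might add inside the loop (hypothetical; GGML_ASSERT is the project's assert macro):

    // every registered sampler was vetted in set_backend_sampler, so both
    // backend hooks must be present by the time the graph is built
    GGML_ASSERT(sampler->iface->backend_init  != nullptr);
    GGML_ASSERT(sampler->iface->backend_apply != nullptr);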
@@ -428,7 +428,6 @@ struct llm_graph_params {
 
     ggml_backend_sched_t sched;
     ggml_backend_t backend_cpu;
-    ggml_backend_dev_t dev_out;
 
     const llama_adapter_cvec * cvec;
     const llama_adapter_loras * loras;
@@ -617,8 +616,6 @@ struct llm_graph_context {
 
     ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
-    ggml_backend_dev_t dev_out;
-
     const llama_adapter_cvec * cvec;
     const llama_adapter_loras * loras;
     const llama_memory_context_i * mctx;