diff --git a/common/common.cpp b/common/common.cpp
index 744f0b4eeb..26250abb6c 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1172,7 +1172,6 @@ common_init_result::common_init_result(common_params & params) :
         pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
     }
 
-    // TODO: temporarily gated behind a flag
     if (params.sampling.backend_sampling) {
         cparams.samplers   = pimpl->samplers_seq_config.data();
         cparams.n_samplers = pimpl->samplers_seq_config.size();
diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp
index 6875038770..a8c19a6aba 100644
--- a/examples/batched/batched.cpp
+++ b/examples/batched/batched.cpp
@@ -81,7 +81,6 @@ int main(int argc, char ** argv) {
         sampler_configs.push_back({ i, smpl });
     }
 
-    // TODO: temporarily gated behind a flag
     if (params.sampling.backend_sampling) {
         ctx_params.samplers   = sampler_configs.data();
         ctx_params.n_samplers = sampler_configs.size();
diff --git a/include/llama.h b/include/llama.h
index 1c17efb9fa..1b4ca42245 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1255,7 +1255,6 @@
     // [EXPERIMENTAL]
     // attach a sampler to the context
     // note: prefer initializing the context with llama_context_params.samplers when possible
-    // note: changing the samplers of a context can cause graph reallocations and degraded performance
     LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
 
     // mirror of llama_sampler_i:
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index d1b02ae71c..929678587e 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -340,7 +340,7 @@ llama_context::llama_context(
         LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
     }
 
-    reserve();
+    sched_reserve();
 
     if (!cparams.flash_attn) {
         if (ggml_is_quantized(params.type_v)) {
@@ -380,7 +380,13 @@ llama_context::~llama_context() {
     ggml_opt_free(opt_ctx);
 }
 
-void llama_context::reserve() {
+void llama_context::sched_reserve() {
+    if (!sched_need_reserve) {
+        return;
+    }
+
+    sched_need_reserve = false;
+
     LLAMA_LOG_INFO("%s: reserving ...\n", __func__);
 
     synchronize();
@@ -408,10 +414,8 @@
         }
     }
 
-    cross.v_embd.clear();
-
     // avoid reserving graphs with zero outputs - assume one output per sequence
-    n_outputs = n_seqs;
+    const int n_outputs = n_seqs;
 
     LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
 
@@ -983,7 +987,7 @@ void llama_context::set_embeddings(bool value) {
     cparams.embeddings = value;
 
     // TODO: not sure yet if we want to reserve here
-    //reserve();
+    //sched_need_reserve = true;
 }
 
 void llama_context::set_causal_attn(bool value) {
@@ -995,17 +999,27 @@
 
     cparams.causal_attn = value;
 
-    reserve();
+    sched_need_reserve = true;
 }
 
 void llama_context::set_warmup(bool value) {
     LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
 
+    if (cparams.warmup == value) {
+        return;
+    }
+
     cparams.warmup = value;
+
+    sched_need_reserve = true;
 }
 
 bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
-    LLAMA_LOG_ERROR("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
+    if (!sampler && sampling.samplers.count(seq_id) == 0) {
+        return true;
+    }
+
+    LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
 
     const bool can_offload =
         sampler &&
@@ -1024,12 +1038,18 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
 
         sampling.samplers[seq_id] = sampler;
 
+        sched_need_reserve = true;
+
         return true;
     }
 
     if (sampler && !can_offload) {
        LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id);
 
+        if (sampling.samplers.count(seq_id) > 0) {
+            sched_need_reserve = true;
+        }
+
         sampling.samplers.erase(seq_id);
 
         return false;
@@ -1053,7 +1073,7 @@ void llama_context::set_adapter_lora(
 
     loras[adapter] = scale;
 
-    reserve();
+    sched_need_reserve = true;
 }
 
 bool llama_context::rm_adapter_lora(
@@ -1064,7 +1084,7 @@ bool llama_context::rm_adapter_lora(
     if (it != loras.end()) {
         loras.erase(it);
 
-        reserve();
+        sched_need_reserve = true;
 
         return true;
     }
@@ -1081,7 +1101,7 @@ void llama_context::clear_adapter_lora() {
 
     loras.clear();
 
-    reserve();
+    sched_need_reserve = true;
 }
 
 bool llama_context::apply_adapter_cvec(
@@ -1196,6 +1216,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
     // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
 
+    sched_reserve();
+
     n_queued_tokens += n_tokens;
 
     // reserve output buffer
@@ -1235,7 +1257,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
 
     // extract logits
-    if (logits && t_logits) { 
+    if (logits && t_logits) {
         ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
         GGML_ASSERT(backend_res != nullptr);
         GGML_ASSERT(logits != nullptr);
@@ -1509,6 +1531,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
     embd_seq.clear();
     output_swaps.clear();
 
+    sched_reserve();
+
     bool did_optimize = false;
 
     // handle any pending shifts/copies
diff --git a/src/llama-context.h b/src/llama-context.h
index 960e4a0782..86decc05fb 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -40,14 +40,13 @@ struct llama_context {
 
     ~llama_context();
 
-    // reserve a new backend scheduler
-    // recommended to call whenver the context changes in such a way that the compute graph is modified.
-    // for example:
+    // reserve a new backend scheduler (if needed)
+    // for example, when:
     //  - changing loras
     //  - changing samplers
     //  - changing attention type
     //  - etc.
-    void reserve();
+    void sched_reserve();
 
     void synchronize();
 
@@ -323,6 +322,8 @@ private:
 
     ggml_backend_sched_ptr sched;
 
+    bool sched_need_reserve = true;
+
     ggml_backend_t backend_cpu = nullptr;
 
     std::vector<ggml_backend_ptr> backends;
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index af6e053424..6520fa3ff5 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -1182,7 +1182,7 @@ private:
         SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
 
         // initialize samplers
-        {
+        if (task.uses_sampling()) {
             slot.smpl.reset(common_sampler_init(model, task.params.sampling));
 
             if (slot.smpl == nullptr) {
@@ -1211,6 +1211,8 @@ private:
             }
 
             SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
+        } else {
+            slot.smpl.reset();
         }
 
         // initialize draft batch
@@ -2593,6 +2595,12 @@ private:
             llama_set_embeddings(ctx, slot_batched->need_embd());
         }
 
+        for (auto & slot : slots) {
+            if (!slot.is_processing() || !slot.smpl) {
+                llama_set_sampler(ctx, slot.id, nullptr);
+            }
+        }
+
         if (batch.n_tokens == 0) {
             SRV_WRN("%s", "no tokens to decode\n");
         }
@@ -2727,6 +2735,8 @@ private:
                     continue; // continue loop of slots
                 }
 
+                GGML_ASSERT(slot.task->uses_sampling());
+
                 // prompt evaluated for next-token prediction
                 slot.state = SLOT_STATE_GENERATING;
             } else if (slot.state != SLOT_STATE_GENERATING) {
diff --git a/tools/server/server-task.h b/tools/server/server-task.h
index cf08fced63..954495006e 100644
--- a/tools/server/server-task.h
+++ b/tools/server/server-task.h
@@ -156,6 +156,11 @@ struct server_task {
         return tokens.size();
     }
 
+    bool uses_sampling() const {
+        return type != SERVER_TASK_TYPE_EMBEDDING &&
+               type != SERVER_TASK_TYPE_RERANK;
+    }
+
     static task_params params_from_json_cmpl(
             const llama_vocab * vocab,
             const common_params & params_base,