llama : reserve on sampler changes

This commit is contained in:
Georgi Gerganov 2026-01-12 16:21:00 +02:00
parent 5260bb79c0
commit 0c0d0fdc30
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
7 changed files with 57 additions and 20 deletions

View File

@ -1172,7 +1172,6 @@ common_init_result::common_init_result(common_params & params) :
pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) }; pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
} }
// TODO: temporarily gated behind a flag
if (params.sampling.backend_sampling) { if (params.sampling.backend_sampling) {
cparams.samplers = pimpl->samplers_seq_config.data(); cparams.samplers = pimpl->samplers_seq_config.data();
cparams.n_samplers = pimpl->samplers_seq_config.size(); cparams.n_samplers = pimpl->samplers_seq_config.size();

View File

@ -81,7 +81,6 @@ int main(int argc, char ** argv) {
sampler_configs.push_back({ i, smpl }); sampler_configs.push_back({ i, smpl });
} }
// TODO: temporarily gated behind a flag
if (params.sampling.backend_sampling) { if (params.sampling.backend_sampling) {
ctx_params.samplers = sampler_configs.data(); ctx_params.samplers = sampler_configs.data();
ctx_params.n_samplers = sampler_configs.size(); ctx_params.n_samplers = sampler_configs.size();

View File

@ -1255,7 +1255,6 @@ extern "C" {
// [EXPERIMENTAL] // [EXPERIMENTAL]
// attach a sampler to the context // attach a sampler to the context
// note: prefer initializing the context with llama_context_params.samplers when possible // note: prefer initializing the context with llama_context_params.samplers when possible
// note: changing the samplers of a context can cause graph reallocations and degraded performance
LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl); LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
// mirror of llama_sampler_i: // mirror of llama_sampler_i:

View File

@ -340,7 +340,7 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
} }
reserve(); sched_reserve();
if (!cparams.flash_attn) { if (!cparams.flash_attn) {
if (ggml_is_quantized(params.type_v)) { if (ggml_is_quantized(params.type_v)) {
@ -380,7 +380,13 @@ llama_context::~llama_context() {
ggml_opt_free(opt_ctx); ggml_opt_free(opt_ctx);
} }
void llama_context::reserve() { void llama_context::sched_reserve() {
if (!sched_need_reserve) {
return;
}
sched_need_reserve = false;
LLAMA_LOG_INFO("%s: reserving ...\n", __func__); LLAMA_LOG_INFO("%s: reserving ...\n", __func__);
synchronize(); synchronize();
@ -408,10 +414,8 @@ void llama_context::reserve() {
} }
} }
cross.v_embd.clear();
// avoid reserving graphs with zero outputs - assume one output per sequence // avoid reserving graphs with zero outputs - assume one output per sequence
n_outputs = n_seqs; const int n_outputs = n_seqs;
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
@ -983,7 +987,7 @@ void llama_context::set_embeddings(bool value) {
cparams.embeddings = value; cparams.embeddings = value;
// TODO: not sure yet if we want to reserve here // TODO: not sure yet if we want to reserve here
//reserve(); //sched_need_reserve = true;
} }
void llama_context::set_causal_attn(bool value) { void llama_context::set_causal_attn(bool value) {
@ -995,17 +999,27 @@ void llama_context::set_causal_attn(bool value) {
cparams.causal_attn = value; cparams.causal_attn = value;
reserve(); sched_need_reserve = true;
} }
void llama_context::set_warmup(bool value) { void llama_context::set_warmup(bool value) {
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
if (cparams.warmup == value) {
return;
}
cparams.warmup = value; cparams.warmup = value;
sched_need_reserve = true;
} }
bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
LLAMA_LOG_ERROR("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); if (!sampler && sampling.samplers.count(seq_id) == 0) {
return true;
}
LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
const bool can_offload = const bool can_offload =
sampler && sampler &&
@ -1024,12 +1038,18 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
sampling.samplers[seq_id] = sampler; sampling.samplers[seq_id] = sampler;
sched_need_reserve = true;
return true; return true;
} }
if (sampler && !can_offload) { if (sampler && !can_offload) {
LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id); LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id);
if (sampling.samplers.count(seq_id) > 0) {
sched_need_reserve = true;
}
sampling.samplers.erase(seq_id); sampling.samplers.erase(seq_id);
return false; return false;
@ -1053,7 +1073,7 @@ void llama_context::set_adapter_lora(
loras[adapter] = scale; loras[adapter] = scale;
reserve(); sched_need_reserve = true;
} }
bool llama_context::rm_adapter_lora( bool llama_context::rm_adapter_lora(
@ -1064,7 +1084,7 @@ bool llama_context::rm_adapter_lora(
if (it != loras.end()) { if (it != loras.end()) {
loras.erase(it); loras.erase(it);
reserve(); sched_need_reserve = true;
return true; return true;
} }
@ -1081,7 +1101,7 @@ void llama_context::clear_adapter_lora() {
loras.clear(); loras.clear();
reserve(); sched_need_reserve = true;
} }
bool llama_context::apply_adapter_cvec( bool llama_context::apply_adapter_cvec(
@ -1196,6 +1216,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
// TODO: this clear of the buffer can easily be forgotten - need something better // TODO: this clear of the buffer can easily be forgotten - need something better
embd_seq.clear(); embd_seq.clear();
sched_reserve();
n_queued_tokens += n_tokens; n_queued_tokens += n_tokens;
// reserve output buffer // reserve output buffer
@ -1235,7 +1257,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
// extract logits // extract logits
if (logits && t_logits) { if (logits && t_logits) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(backend_res != nullptr);
GGML_ASSERT(logits != nullptr); GGML_ASSERT(logits != nullptr);
@ -1509,6 +1531,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
embd_seq.clear(); embd_seq.clear();
output_swaps.clear(); output_swaps.clear();
sched_reserve();
bool did_optimize = false; bool did_optimize = false;
// handle any pending shifts/copies // handle any pending shifts/copies

View File

@ -40,14 +40,13 @@ struct llama_context {
~llama_context(); ~llama_context();
// reserve a new backend scheduler // reserve a new backend scheduler (if needed)
// recommended to call whenver the context changes in such a way that the compute graph is modified. // for example, when:
// for example:
// - changing loras // - changing loras
// - changing samplers // - changing samplers
// - changing attention type // - changing attention type
// - etc. // - etc.
void reserve(); void sched_reserve();
void synchronize(); void synchronize();
@ -323,6 +322,8 @@ private:
ggml_backend_sched_ptr sched; ggml_backend_sched_ptr sched;
bool sched_need_reserve = true;
ggml_backend_t backend_cpu = nullptr; ggml_backend_t backend_cpu = nullptr;
std::vector<ggml_backend_ptr> backends; std::vector<ggml_backend_ptr> backends;

View File

@ -1182,7 +1182,7 @@ private:
SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
// initialize samplers // initialize samplers
{ if (task.uses_sampling()) {
slot.smpl.reset(common_sampler_init(model, task.params.sampling)); slot.smpl.reset(common_sampler_init(model, task.params.sampling));
if (slot.smpl == nullptr) { if (slot.smpl == nullptr) {
@ -1211,6 +1211,8 @@ private:
} }
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str()); SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
} else {
slot.smpl.reset();
} }
// initialize draft batch // initialize draft batch
@ -2593,6 +2595,12 @@ private:
llama_set_embeddings(ctx, slot_batched->need_embd()); llama_set_embeddings(ctx, slot_batched->need_embd());
} }
for (auto & slot : slots) {
if (!slot.is_processing() || !slot.smpl) {
llama_set_sampler(ctx, slot.id, nullptr);
}
}
if (batch.n_tokens == 0) { if (batch.n_tokens == 0) {
SRV_WRN("%s", "no tokens to decode\n"); SRV_WRN("%s", "no tokens to decode\n");
} }
@ -2727,6 +2735,8 @@ private:
continue; // continue loop of slots continue; // continue loop of slots
} }
GGML_ASSERT(slot.task->uses_sampling());
// prompt evaluated for next-token prediction // prompt evaluated for next-token prediction
slot.state = SLOT_STATE_GENERATING; slot.state = SLOT_STATE_GENERATING;
} else if (slot.state != SLOT_STATE_GENERATING) { } else if (slot.state != SLOT_STATE_GENERATING) {

View File

@ -156,6 +156,11 @@ struct server_task {
return tokens.size(); return tokens.size();
} }
bool uses_sampling() const {
return type != SERVER_TASK_TYPE_EMBEDDING &&
type != SERVER_TASK_TYPE_RERANK;
}
static task_params params_from_json_cmpl( static task_params params_from_json_cmpl(
const llama_vocab * vocab, const llama_vocab * vocab,
const common_params & params_base, const common_params & params_base,