llama : reserve on sampler changes

This commit is contained in:
Georgi Gerganov 2026-01-12 16:21:00 +02:00
parent 5260bb79c0
commit 0c0d0fdc30
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
7 changed files with 57 additions and 20 deletions

View File

@ -1172,7 +1172,6 @@ common_init_result::common_init_result(common_params & params) :
pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
}
// TODO: temporarily gated behind a flag
if (params.sampling.backend_sampling) {
cparams.samplers = pimpl->samplers_seq_config.data();
cparams.n_samplers = pimpl->samplers_seq_config.size();

View File

@ -81,7 +81,6 @@ int main(int argc, char ** argv) {
sampler_configs.push_back({ i, smpl });
}
// TODO: temporarily gated behind a flag
if (params.sampling.backend_sampling) {
ctx_params.samplers = sampler_configs.data();
ctx_params.n_samplers = sampler_configs.size();

View File

@ -1255,7 +1255,6 @@ extern "C" {
// [EXPERIMENTAL]
// attach a sampler to the context
// note: prefer initializing the context with llama_context_params.samplers when possible
// note: changing the samplers of a context can cause graph reallocations and degraded performance
LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
// mirror of llama_sampler_i:

View File

@ -340,7 +340,7 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
}
reserve();
sched_reserve();
if (!cparams.flash_attn) {
if (ggml_is_quantized(params.type_v)) {
@ -380,7 +380,13 @@ llama_context::~llama_context() {
ggml_opt_free(opt_ctx);
}
void llama_context::reserve() {
void llama_context::sched_reserve() {
if (!sched_need_reserve) {
return;
}
sched_need_reserve = false;
LLAMA_LOG_INFO("%s: reserving ...\n", __func__);
synchronize();
@ -408,10 +414,8 @@ void llama_context::reserve() {
}
}
cross.v_embd.clear();
// avoid reserving graphs with zero outputs - assume one output per sequence
n_outputs = n_seqs;
const int n_outputs = n_seqs;
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
@ -983,7 +987,7 @@ void llama_context::set_embeddings(bool value) {
cparams.embeddings = value;
// TODO: not sure yet if we want to reserve here
//reserve();
//sched_need_reserve = true;
}
void llama_context::set_causal_attn(bool value) {
@ -995,17 +999,27 @@ void llama_context::set_causal_attn(bool value) {
cparams.causal_attn = value;
reserve();
sched_need_reserve = true;
}
void llama_context::set_warmup(bool value) {
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
if (cparams.warmup == value) {
return;
}
cparams.warmup = value;
sched_need_reserve = true;
}
bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
LLAMA_LOG_ERROR("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
if (!sampler && sampling.samplers.count(seq_id) == 0) {
return true;
}
LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
const bool can_offload =
sampler &&
@ -1024,12 +1038,18 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
sampling.samplers[seq_id] = sampler;
sched_need_reserve = true;
return true;
}
if (sampler && !can_offload) {
LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id);
if (sampling.samplers.count(seq_id) > 0) {
sched_need_reserve = true;
}
sampling.samplers.erase(seq_id);
return false;
@ -1053,7 +1073,7 @@ void llama_context::set_adapter_lora(
loras[adapter] = scale;
reserve();
sched_need_reserve = true;
}
bool llama_context::rm_adapter_lora(
@ -1064,7 +1084,7 @@ bool llama_context::rm_adapter_lora(
if (it != loras.end()) {
loras.erase(it);
reserve();
sched_need_reserve = true;
return true;
}
@ -1081,7 +1101,7 @@ void llama_context::clear_adapter_lora() {
loras.clear();
reserve();
sched_need_reserve = true;
}
bool llama_context::apply_adapter_cvec(
@ -1196,6 +1216,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
// TODO: this clear of the buffer can easily be forgotten - need something better
embd_seq.clear();
sched_reserve();
n_queued_tokens += n_tokens;
// reserve output buffer
@ -1235,7 +1257,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
// extract logits
if (logits && t_logits) {
if (logits && t_logits) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
GGML_ASSERT(backend_res != nullptr);
GGML_ASSERT(logits != nullptr);
@ -1509,6 +1531,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
embd_seq.clear();
output_swaps.clear();
sched_reserve();
bool did_optimize = false;
// handle any pending shifts/copies

View File

@ -40,14 +40,13 @@ struct llama_context {
~llama_context();
// reserve a new backend scheduler
// recommended to call whenver the context changes in such a way that the compute graph is modified.
// for example:
// reserve a new backend scheduler (if needed)
// for example, when:
// - changing loras
// - changing samplers
// - changing attention type
// - etc.
void reserve();
void sched_reserve();
void synchronize();
@ -323,6 +322,8 @@ private:
ggml_backend_sched_ptr sched;
bool sched_need_reserve = true;
ggml_backend_t backend_cpu = nullptr;
std::vector<ggml_backend_ptr> backends;

View File

@ -1182,7 +1182,7 @@ private:
SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
// initialize samplers
{
if (task.uses_sampling()) {
slot.smpl.reset(common_sampler_init(model, task.params.sampling));
if (slot.smpl == nullptr) {
@ -1211,6 +1211,8 @@ private:
}
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
} else {
slot.smpl.reset();
}
// initialize draft batch
@ -2593,6 +2595,12 @@ private:
llama_set_embeddings(ctx, slot_batched->need_embd());
}
for (auto & slot : slots) {
if (!slot.is_processing() || !slot.smpl) {
llama_set_sampler(ctx, slot.id, nullptr);
}
}
if (batch.n_tokens == 0) {
SRV_WRN("%s", "no tokens to decode\n");
}
@ -2727,6 +2735,8 @@ private:
continue; // continue loop of slots
}
GGML_ASSERT(slot.task->uses_sampling());
// prompt evaluated for next-token prediction
slot.state = SLOT_STATE_GENERATING;
} else if (slot.state != SLOT_STATE_GENERATING) {

View File

@ -156,6 +156,11 @@ struct server_task {
return tokens.size();
}
bool uses_sampling() const {
return type != SERVER_TASK_TYPE_EMBEDDING &&
type != SERVER_TASK_TYPE_RERANK;
}
static task_params params_from_json_cmpl(
const llama_vocab * vocab,
const common_params & params_base,