context : reserve new scheduler when graph topology changes (#18547)

* context : reserve new scheduler when graph topology changes

* cont : fix

* cont : fix reserve

* cont : reserve only when changes occur + timing

* context : add comments

* llama : reserve on sampler changes

* common : allow null common_sampler

* server : task declares needs (embd, logits, sampling)

* server : do not init sampler if not needed

* llama : fix need_reserve when unsetting a sampler

* server : consolidate slot reset/clear logic
This commit is contained in:
Georgi Gerganov 2026-01-15 16:39:17 +02:00 committed by GitHub
parent 5c662d21a3
commit 39173bcacb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 328 additions and 216 deletions

View File

@ -1172,7 +1172,6 @@ common_init_result::common_init_result(common_params & params) :
pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
}
// TODO: temporarily gated behind a flag
if (params.sampling.backend_sampling) {
cparams.samplers = pimpl->samplers_seq_config.data();
cparams.n_samplers = pimpl->samplers_seq_config.size();

View File

@ -334,15 +334,21 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
void common_sampler_free(struct common_sampler * gsmpl) {
if (gsmpl) {
llama_sampler_free(gsmpl->grmr);
llama_sampler_free(gsmpl->chain);
delete gsmpl;
if (!gsmpl) {
return;
}
llama_sampler_free(gsmpl->grmr);
llama_sampler_free(gsmpl->chain);
delete gsmpl;
}
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
if (!gsmpl) {
return;
}
const auto tm = gsmpl->tm();
if (gsmpl->grmr && accept_grammar) {
@ -355,6 +361,10 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
}
void common_sampler_reset(struct common_sampler * gsmpl) {
if (!gsmpl) {
return;
}
gsmpl->reset();
}
@ -415,6 +425,10 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
}
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
if (!gsmpl) {
return nullptr;
}
return gsmpl->chain;
}

View File

@ -81,7 +81,6 @@ int main(int argc, char ** argv) {
sampler_configs.push_back({ i, smpl });
}
// TODO: temporarily gated behind a flag
if (params.sampling.backend_sampling) {
ctx_params.samplers = sampler_configs.data();
ctx_params.n_samplers = sampler_configs.size();

View File

@ -1256,7 +1256,6 @@ extern "C" {
// [EXPERIMENTAL]
// attach a sampler to the context
// note: prefer initializing the context with llama_context_params.samplers when possible
// note: changing the samplers of a context can cause graph reallocations and degraded performance
LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
// mirror of llama_sampler_i:

View File

@ -146,6 +146,7 @@ llama_context::llama_context(
}
cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;
// with causal attention, the batch size is limited by the context size
cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
@ -155,6 +156,9 @@ llama_context::llama_context(
cparams.op_offload = params.op_offload;
cparams.kv_unified = params.kv_unified;
// intialized later
cparams.pipeline_parallel = false;
{
const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;
@ -302,16 +306,6 @@ llama_context::llama_context(
LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
const size_t max_nodes = this->graph_max_nodes(n_tokens);
LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
gf_res_prev.reset(new llm_graph_result(max_nodes));
gf_res_reserve.reset(new llm_graph_result(max_nodes));
// TODO: move these checks to ggml_backend_sched
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
bool pipeline_parallel =
@ -340,143 +334,19 @@ llama_context::llama_context(
}
}
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload));
cparams.pipeline_parallel = pipeline_parallel;
if (pipeline_parallel) {
if (cparams.pipeline_parallel) {
LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
}
llama_memory_context_ptr mctx;
if (memory) {
LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
mctx = memory->init_full();
if (!mctx) {
throw std::runtime_error("failed to initialize memory module");
sched_reserve();
if (!cparams.flash_attn) {
if (ggml_is_quantized(params.type_v)) {
throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
}
}
cross.v_embd.clear();
// avoid reserving graphs with zero outputs - assume one output per sequence
n_outputs = n_seqs;
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
// resolve automatic Flash Attention use
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
if (!gf) {
throw std::runtime_error("failed to split graph for Flash Attention check");
}
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
bool fa_device_mismatch = false;
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
ggml_tensor * n = ggml_graph_node(gf, i);
if (n->op != GGML_OP_FLASH_ATTN_EXT) {
continue;
}
ggml_backend_dev_t device_fa = ggml_backend_get_device(
ggml_backend_sched_get_tensor_backend(sched.get(), n));
// TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
const int il = std::stoi(n->name + prefix_len);
ggml_backend_dev_t device_kv = model.dev_layer(il);
if (device_fa != device_kv) {
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
"is assigned to device %s (usually due to missing support)\n",
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
// FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
fa_device_mismatch = true;
break;
}
}
if (fa_device_mismatch) {
cparams.flash_attn = false;
LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
if (ggml_is_quantized(params.type_v)) {
throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
}
} else {
cparams.flash_attn = true;
LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
}
}
// reserve worst-case graph
int n_splits_pp = -1;
int n_nodes_pp = -1;
int n_splits_tg = -1;
int n_nodes_tg = -1;
// reserve pp (prompt processing) graph first so that buffers are only allocated once
{
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
if (!gf) {
if (pipeline_parallel) {
LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
}
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
}
}
n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
n_nodes_pp = ggml_graph_n_nodes(gf);
}
// reserve with tg (token generation) graph to get the number of splits and nodes
{
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc);
if (!gf) {
throw std::runtime_error("failed to allocate compute tg buffers");
}
n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
n_nodes_tg = ggml_graph_n_nodes(gf);
}
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
{
// TODO: not sure if the following graph would be worster case for multi-stream KV caches:
//
// auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
//
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
}
}
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
ggml_backend_t backend = backend_ptrs[i];
ggml_backend_buffer_type_t buft = backend_buft[i];
if (!model.hparams.no_alloc) {
backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend);
}
if (backend_buf_exp_size[i] > 1) {
LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
ggml_backend_buft_name(buft),
backend_buf_exp_size[i] / 1024.0 / 1024.0);
}
}
if (n_nodes_pp == n_nodes_tg) {
LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp);
} else {
LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
}
if (n_splits_pp == n_splits_tg) {
LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
} else {
LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
}
}
// Initialize the full vocabulary token ids for backend samplers.
@ -510,7 +380,171 @@ llama_context::~llama_context() {
ggml_opt_free(opt_ctx);
}
void llama_context::sched_reserve() {
if (!sched_need_reserve) {
return;
}
sched_need_reserve = false;
LLAMA_LOG_INFO("%s: reserving ...\n", __func__);
synchronize();
const int64_t t_start_us = ggml_time_us();
const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
const size_t max_nodes = this->graph_max_nodes(n_tokens);
LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
gf_res_prev.reset(new llm_graph_result(max_nodes));
gf_res_reserve.reset(new llm_graph_result(max_nodes));
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, cparams.pipeline_parallel, cparams.op_offload));
llama_memory_context_ptr mctx;
if (memory) {
LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
mctx = memory->init_full();
if (!mctx) {
throw std::runtime_error("failed to initialize memory module");
}
}
// avoid reserving graphs with zero outputs - assume one output per sequence
const int n_outputs = n_seqs;
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
// resolve automatic Flash Attention use
if (cparams.auto_fa) {
auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
if (!gf) {
throw std::runtime_error("failed to split graph for Flash Attention check");
}
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
bool fa_device_mismatch = false;
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
ggml_tensor * n = ggml_graph_node(gf, i);
if (n->op != GGML_OP_FLASH_ATTN_EXT) {
continue;
}
ggml_backend_dev_t device_fa = ggml_backend_get_device(
ggml_backend_sched_get_tensor_backend(sched.get(), n));
// TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
const int il = std::stoi(n->name + prefix_len);
ggml_backend_dev_t device_kv = model.dev_layer(il);
if (device_fa != device_kv) {
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
"is assigned to device %s (usually due to missing support)\n",
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
// FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
fa_device_mismatch = true;
break;
}
}
if (fa_device_mismatch) {
cparams.flash_attn = false;
LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
} else {
cparams.flash_attn = true;
LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
}
cparams.auto_fa = false;
}
// reserve worst-case graph
int n_splits_pp = -1;
int n_nodes_pp = -1;
int n_splits_tg = -1;
int n_nodes_tg = -1;
// reserve pp (prompt processing) graph first so that buffers are only allocated once
{
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
if (!gf) {
if (cparams.pipeline_parallel) {
LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
cparams.pipeline_parallel = false;
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
}
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
}
}
n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
n_nodes_pp = ggml_graph_n_nodes(gf);
}
// reserve with tg (token generation) graph to get the number of splits and nodes
{
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc);
if (!gf) {
throw std::runtime_error("failed to allocate compute tg buffers");
}
n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
n_nodes_tg = ggml_graph_n_nodes(gf);
}
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
{
// TODO: not sure if the following graph would be worster case for multi-stream KV caches:
//
// auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
//
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
}
}
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
ggml_backend_t backend = backend_ptrs[i];
ggml_backend_buffer_type_t buft = backend_buft[i];
if (!model.hparams.no_alloc) {
backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend);
}
if (backend_buf_exp_size[i] > 1) {
LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
ggml_backend_buft_name(buft),
backend_buf_exp_size[i] / 1024.0 / 1024.0);
}
}
if (n_nodes_pp == n_nodes_tg) {
LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp);
} else {
LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
}
if (n_splits_pp == n_splits_tg) {
LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
} else {
LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
}
const int64_t t_end_us = ggml_time_us();
LLAMA_LOG_INFO("%s: reserve took %.2f ms\n", __func__, (t_end_us - t_start_us)/1000.0);
}
void llama_context::synchronize() {
if (!sched) {
return;
}
ggml_backend_sched_synchronize(sched.get());
// FIXME: if multiple single tokens are evaluated without a synchronization,
@ -951,21 +985,40 @@ void llama_context::set_embeddings(bool value) {
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
cparams.embeddings = value;
// TODO: not sure yet if we want to reserve here
//sched_need_reserve = true;
}
void llama_context::set_causal_attn(bool value) {
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
if (cparams.causal_attn == value) {
return;
}
cparams.causal_attn = value;
sched_need_reserve = true;
}
void llama_context::set_warmup(bool value) {
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
if (cparams.warmup == value) {
return;
}
cparams.warmup = value;
sched_need_reserve = true;
}
bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
if (!sampler && sampling.samplers.count(seq_id) == 0) {
return true;
}
LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
const bool can_offload =
@ -985,12 +1038,18 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
sampling.samplers[seq_id] = sampler;
sched_need_reserve = true;
return true;
}
if (sampler && !can_offload) {
LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id);
if (sampling.samplers.count(seq_id) > 0) {
sched_need_reserve = true;
}
sampling.samplers.erase(seq_id);
return false;
@ -998,6 +1057,8 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
sampling.samplers.erase(seq_id);
sched_need_reserve = true;
return true;
}
@ -1006,16 +1067,27 @@ void llama_context::set_adapter_lora(
float scale) {
LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale);
if (auto it = loras.find(adapter); it != loras.end()) {
if (it->second == scale) {
return;
}
}
loras[adapter] = scale;
sched_need_reserve = true;
}
bool llama_context::rm_adapter_lora(
llama_adapter_lora * adapter) {
LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter);
auto pos = loras.find(adapter);
if (pos != loras.end()) {
loras.erase(pos);
auto it = loras.find(adapter);
if (it != loras.end()) {
loras.erase(it);
sched_need_reserve = true;
return true;
}
@ -1025,7 +1097,13 @@ bool llama_context::rm_adapter_lora(
void llama_context::clear_adapter_lora() {
LLAMA_LOG_DEBUG("%s: call\n", __func__);
if (loras.empty()) {
return;
}
loras.clear();
sched_need_reserve = true;
}
bool llama_context::apply_adapter_cvec(
@ -1036,6 +1114,8 @@ bool llama_context::apply_adapter_cvec(
int32_t il_end) {
LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end);
// TODO: should we reserve?
return cvec.apply(model, data, len, n_embd, il_start, il_end);
}
@ -1138,6 +1218,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
// TODO: this clear of the buffer can easily be forgotten - need something better
embd_seq.clear();
sched_reserve();
n_queued_tokens += n_tokens;
// reserve output buffer
@ -1177,7 +1259,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
// extract logits
if (logits && t_logits) {
if (logits && t_logits) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
GGML_ASSERT(backend_res != nullptr);
GGML_ASSERT(logits != nullptr);
@ -1451,6 +1533,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
embd_seq.clear();
output_swaps.clear();
sched_reserve();
bool did_optimize = false;
// handle any pending shifts/copies

View File

@ -40,6 +40,14 @@ struct llama_context {
~llama_context();
// reserve a new backend scheduler (if needed)
// for example, when:
// - changing loras
// - changing samplers
// - changing attention type
// - etc.
void sched_reserve();
void synchronize();
const llama_model & get_model() const;
@ -314,6 +322,8 @@ private:
ggml_backend_sched_ptr sched;
bool sched_need_reserve = true;
ggml_backend_t backend_cpu = nullptr;
std::vector<ggml_backend_ptr> backends;

View File

@ -30,10 +30,12 @@ struct llama_cparams {
bool causal_attn;
bool offload_kqv;
bool flash_attn;
bool auto_fa;
bool no_perf;
bool warmup;
bool op_offload;
bool kv_unified;
bool pipeline_parallel;
enum llama_pooling_type pooling_type;

View File

@ -45,26 +45,6 @@ enum server_state {
SERVER_STATE_READY, // Server is ready and model is loaded
};
static bool server_task_type_need_embd(server_task_type task_type) {
switch (task_type) {
case SERVER_TASK_TYPE_EMBEDDING:
case SERVER_TASK_TYPE_RERANK:
return true;
default:
return false;
}
}
static bool server_task_type_need_logits(server_task_type task_type) {
switch (task_type) {
case SERVER_TASK_TYPE_COMPLETION:
case SERVER_TASK_TYPE_INFILL:
return true;
default:
return false;
}
}
struct server_slot {
int id;
@ -147,6 +127,17 @@ struct server_slot {
return res;
}
void prompt_clear(bool allow_processing) {
if (!allow_processing) {
GGML_ASSERT(!is_processing());
}
SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size());
llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
prompt.tokens.clear();
}
std::vector<common_adapter_lora_info> lora;
int32_t alora_invocation_start = -1;
@ -196,30 +187,24 @@ struct server_slot {
n_draft_total = 0;
n_draft_accepted = 0;
task_prev = std::move(task);
task.reset();
task_prev.reset();
llama_set_sampler(ctx, id, nullptr);
// clear alora start
alora_invocation_start = -1;
}
// remove cached prompt + tokens
void clear(bool allow_processing) {
if (!allow_processing) {
GGML_ASSERT(!is_processing());
void init_sampler() const {
common_sampler_reset(smpl.get());
if (!task->need_sampling()) {
return;
}
SLT_INF(*this, "clearing slot with %zu tokens\n", prompt.tokens.size());
llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
prompt.tokens.clear();
}
void init_sampler() const {
const int64_t t_start = ggml_time_us();
common_sampler_reset(smpl.get());
int n_text = 0;
for (int i = 0; i < (int) prompt.tokens.size(); i++) {
@ -235,25 +220,13 @@ struct server_slot {
(ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size());
}
// TODO: move to server_task
bool need_embd() const {
GGML_ASSERT(task);
return server_task_type_need_embd(task->type);
}
// TODO: move to server_task
bool need_logits() const {
GGML_ASSERT(task);
return server_task_type_need_logits(task->type);
}
// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
// also we cannot split if the pooling would require any past tokens
bool can_split() const {
GGML_ASSERT(task);
return
!need_embd() ||
!task->need_embd() ||
(llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
}
@ -349,11 +322,10 @@ struct server_slot {
// do not keep context of the child slots - the parent's context is enough
if (is_child()) {
clear(false);
prompt_clear(false);
}
task_prev = std::move(task);
task.reset();
reset();
callback_on_release(id);
}
@ -801,6 +773,7 @@ private:
slots.clear();
// initialize slots
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;
@ -1049,7 +1022,7 @@ private:
ret->prompt_save(*prompt_cache);
if (!ret->prompt_load(*prompt_cache, task.tokens)) {
ret->clear(false);
ret->prompt_clear(false);
}
prompt_cache->update();
@ -1081,7 +1054,7 @@ private:
if (slot.prompt.n_tokens() > 0) {
SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
slot.clear(false);
slot.prompt_clear(false);
res = true;
@ -1107,8 +1080,6 @@ private:
}
bool launch_slot_with_task(server_slot & slot, server_task && task) {
slot.reset();
// process per-request lora adapters
if (!task.params.lora.empty()) {
auto task_loras = construct_lora_list(task.params.lora);
@ -1182,7 +1153,7 @@ private:
SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
// initialize samplers
{
if (task.need_sampling()) {
slot.smpl.reset(common_sampler_init(model, task.params.sampling));
if (slot.smpl == nullptr) {
@ -1211,6 +1182,8 @@ private:
}
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
} else {
slot.smpl.reset();
}
// initialize draft batch
@ -1864,7 +1837,7 @@ private:
// Erase token cache
const size_t n_erased = slot->prompt.tokens.size();
slot->clear(false);
slot->prompt_clear(false);
auto res = std::make_unique<server_task_result_slot_erase>();
res->id = task.id;
@ -2161,7 +2134,7 @@ private:
}
// TODO: support memory-less logits computation
if (slot.need_logits() && !llama_get_memory(ctx)) {
if (slot.task->need_logits() && !llama_get_memory(ctx)) {
send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER);
slot.release();
continue;
@ -2421,7 +2394,7 @@ private:
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
slot.clear(true);
slot.prompt_clear(true);
// there is no common part left
slot.n_prompt_tokens_cache = 0;
@ -2500,7 +2473,7 @@ private:
cur_tok,
slot.prompt.tokens.pos_next(),
{ slot.id },
slot.need_embd());
slot.task->need_embd());
slot.prompt.tokens.push_back(cur_tok);
slot.n_prompt_tokens_processed++;
@ -2590,7 +2563,7 @@ private:
slot_batched->lora[alora_disabled_id].scale = alora_scale;
}
llama_set_embeddings(ctx, slot_batched->need_embd());
llama_set_embeddings(ctx, slot_batched->task->need_embd());
}
if (batch.n_tokens == 0) {
@ -2648,7 +2621,7 @@ private:
// note: it's complicated to keep track of how much of the current batch has been
// processed before the error occurred, so we simply clear the entire context
slot.clear(false);
slot.prompt_clear(false);
}
}
@ -2727,6 +2700,8 @@ private:
continue; // continue loop of slots
}
GGML_ASSERT(slot.task->need_sampling());
// prompt evaluated for next-token prediction
slot.state = SLOT_STATE_GENERATING;
} else if (slot.state != SLOT_STATE_GENERATING) {

View File

@ -156,6 +156,36 @@ struct server_task {
return tokens.size();
}
bool need_embd() const {
switch (type) {
case SERVER_TASK_TYPE_EMBEDDING:
case SERVER_TASK_TYPE_RERANK:
return true;
default:
return false;
}
}
bool need_logits() const {
switch (type) {
case SERVER_TASK_TYPE_COMPLETION:
case SERVER_TASK_TYPE_INFILL:
return true;
default:
return false;
}
}
bool need_sampling() const {
switch (type) {
case SERVER_TASK_TYPE_COMPLETION:
case SERVER_TASK_TYPE_INFILL:
return true;
default:
return false;
}
}
static task_params params_from_json_cmpl(
const llama_vocab * vocab,
const common_params & params_base,