From 39173bcacb67329850b9ff3108dd036eafb680f0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 15 Jan 2026 16:39:17 +0200 Subject: [PATCH] context : reserve new scheduler when graph topology changes (#18547) * context : reserve new scheduler when graph topology changes * cont : fix * cont : fix reserve * cont : reserve only when changes occur + timing * context : add comments * llama : reserve on sampler changes * common : allow null common_sampler * server : task declares needs (embd, logits, sampling) * server : do not init sampler if not needed * llama : fix need_reserve when unsetting a sampler * server : consolidate slot reset/clear logic --- common/common.cpp | 1 - common/sampling.cpp | 24 +- examples/batched/batched.cpp | 1 - include/llama.h | 1 - src/llama-context.cpp | 374 +++++++++++++++++++------------- src/llama-context.h | 10 + src/llama-cparams.h | 2 + tools/server/server-context.cpp | 101 ++++----- tools/server/server-task.h | 30 +++ 9 files changed, 328 insertions(+), 216 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 744f0b4eeb..26250abb6c 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1172,7 +1172,6 @@ common_init_result::common_init_result(common_params & params) : pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) }; } - // TODO: temporarily gated behind a flag if (params.sampling.backend_sampling) { cparams.samplers = pimpl->samplers_seq_config.data(); cparams.n_samplers = pimpl->samplers_seq_config.size(); diff --git a/common/sampling.cpp b/common/sampling.cpp index 8a931d51fc..54f6377a1e 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -334,15 +334,21 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st } void common_sampler_free(struct common_sampler * gsmpl) { - if (gsmpl) { - llama_sampler_free(gsmpl->grmr); - llama_sampler_free(gsmpl->chain); - - delete gsmpl; + if (!gsmpl) { + return; } + + llama_sampler_free(gsmpl->grmr); + llama_sampler_free(gsmpl->chain); + + delete gsmpl; } void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) { + if (!gsmpl) { + return; + } + const auto tm = gsmpl->tm(); if (gsmpl->grmr && accept_grammar) { @@ -355,6 +361,10 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo } void common_sampler_reset(struct common_sampler * gsmpl) { + if (!gsmpl) { + return; + } + gsmpl->reset(); } @@ -415,6 +425,10 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam } struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) { + if (!gsmpl) { + return nullptr; + } + return gsmpl->chain; } diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 6875038770..a8c19a6aba 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -81,7 +81,6 @@ int main(int argc, char ** argv) { sampler_configs.push_back({ i, smpl }); } - // TODO: temporarily gated behind a flag if (params.sampling.backend_sampling) { ctx_params.samplers = sampler_configs.data(); ctx_params.n_samplers = sampler_configs.size(); diff --git a/include/llama.h b/include/llama.h index a25237d20b..bff788f92a 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1256,7 +1256,6 @@ extern "C" { // [EXPERIMENTAL] // attach a sampler to the context // note: prefer initializing the context with llama_context_params.samplers when possible - // note: changing the samplers of a context can cause graph reallocations and 
degraded performance LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl); // mirror of llama_sampler_i: diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 661e7da168..12e40018bb 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -146,6 +146,7 @@ llama_context::llama_context( } cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; + cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO; // with causal attention, the batch size is limited by the context size cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; @@ -155,6 +156,9 @@ llama_context::llama_context( cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; + // intialized later + cparams.pipeline_parallel = false; + { const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE"); graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable; @@ -302,16 +306,6 @@ llama_context::llama_context( LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size()); - const uint32_t n_seqs = cparams.n_seq_max; - const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - - const size_t max_nodes = this->graph_max_nodes(n_tokens); - - LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); - - gf_res_prev.reset(new llm_graph_result(max_nodes)); - gf_res_reserve.reset(new llm_graph_result(max_nodes)); - // TODO: move these checks to ggml_backend_sched // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary bool pipeline_parallel = @@ -340,143 +334,19 @@ llama_context::llama_context( } } - sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload)); + cparams.pipeline_parallel = pipeline_parallel; - if (pipeline_parallel) { + if (cparams.pipeline_parallel) { LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); } - llama_memory_context_ptr mctx; - if (memory) { - LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); - mctx = memory->init_full(); - if (!mctx) { - throw std::runtime_error("failed to initialize memory module"); + sched_reserve(); + + if (!cparams.flash_attn) { + if (ggml_is_quantized(params.type_v)) { + throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention"); } } - - cross.v_embd.clear(); - - // avoid reserving graphs with zero outputs - assume one output per sequence - n_outputs = n_seqs; - - LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); - - // resolve automatic Flash Attention use - if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) { - auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); - if (!gf) { - throw std::runtime_error("failed to split graph for Flash Attention check"); - } - - const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; - bool fa_device_mismatch = false; - for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { - ggml_tensor * n = ggml_graph_node(gf, i); - if (n->op != GGML_OP_FLASH_ATTN_EXT) { - continue; - } - ggml_backend_dev_t device_fa = ggml_backend_get_device( - ggml_backend_sched_get_tensor_backend(sched.get(), n)); - - // TODO: instead of the tensor names, 
use a map to keep track of which (FA) tensors belong to which layer - GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); - const int il = std::stoi(n->name + prefix_len); - ggml_backend_dev_t device_kv = model.dev_layer(il); - if (device_fa != device_kv) { - LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor " - "is assigned to device %s (usually due to missing support)\n", - __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa)); - // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways - fa_device_mismatch = true; - break; - } - } - if (fa_device_mismatch) { - cparams.flash_attn = false; - LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); - if (ggml_is_quantized(params.type_v)) { - throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention"); - } - } else { - cparams.flash_attn = true; - LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__); - } - } - - // reserve worst-case graph - int n_splits_pp = -1; - int n_nodes_pp = -1; - - int n_splits_tg = -1; - int n_nodes_tg = -1; - - // reserve pp (prompt processing) graph first so that buffers are only allocated once - { - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), - model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr); - if (!gf) { - if (pipeline_parallel) { - LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); - sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); - gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); - } - if (!gf) { - throw std::runtime_error("failed to allocate compute pp buffers"); - } - } - - n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); - n_nodes_pp = ggml_graph_n_nodes(gf); - } - - // reserve with tg (token generation) graph to get the number of splits and nodes - { - auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc); - if (!gf) { - throw std::runtime_error("failed to allocate compute tg buffers"); - } - - n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); - n_nodes_tg = ggml_graph_n_nodes(gf); - } - - // reserve again with pp graph to avoid ggml-alloc reallocations during inference - { - // TODO: not sure if the following graph would be worster case for multi-stream KV caches: - // - // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); - // - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc); - if (!gf) { - throw std::runtime_error("failed to allocate compute pp buffers"); - } - } - - for (size_t i = 0; i < backend_ptrs.size(); ++i) { - ggml_backend_t backend = backend_ptrs[i]; - ggml_backend_buffer_type_t buft = backend_buft[i]; - if (!model.hparams.no_alloc) { - backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend); - } - if (backend_buf_exp_size[i] > 1) { - LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, - ggml_backend_buft_name(buft), - backend_buf_exp_size[i] / 1024.0 / 1024.0); - } - } - - if (n_nodes_pp == n_nodes_tg) { - LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); - } else { - LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); - } - - if (n_splits_pp == 
n_splits_tg) { - LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); - } else { - LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); - } } // Initialize the full vocabulary token ids for backend samplers. @@ -510,7 +380,171 @@ llama_context::~llama_context() { ggml_opt_free(opt_ctx); } +void llama_context::sched_reserve() { + if (!sched_need_reserve) { + return; + } + + sched_need_reserve = false; + + LLAMA_LOG_INFO("%s: reserving ...\n", __func__); + + synchronize(); + + const int64_t t_start_us = ggml_time_us(); + + const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + + const size_t max_nodes = this->graph_max_nodes(n_tokens); + + LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); + + gf_res_prev.reset(new llm_graph_result(max_nodes)); + gf_res_reserve.reset(new llm_graph_result(max_nodes)); + + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, cparams.pipeline_parallel, cparams.op_offload)); + + llama_memory_context_ptr mctx; + if (memory) { + LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); + mctx = memory->init_full(); + if (!mctx) { + throw std::runtime_error("failed to initialize memory module"); + } + } + + // avoid reserving graphs with zero outputs - assume one output per sequence + const int n_outputs = n_seqs; + + LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); + + // resolve automatic Flash Attention use + if (cparams.auto_fa) { + auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); + if (!gf) { + throw std::runtime_error("failed to split graph for Flash Attention check"); + } + + const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; + bool fa_device_mismatch = false; + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + ggml_tensor * n = ggml_graph_node(gf, i); + if (n->op != GGML_OP_FLASH_ATTN_EXT) { + continue; + } + ggml_backend_dev_t device_fa = ggml_backend_get_device( + ggml_backend_sched_get_tensor_backend(sched.get(), n)); + + // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer + GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); + const int il = std::stoi(n->name + prefix_len); + ggml_backend_dev_t device_kv = model.dev_layer(il); + if (device_fa != device_kv) { + LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor " + "is assigned to device %s (usually due to missing support)\n", + __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa)); + // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways + fa_device_mismatch = true; + break; + } + } + if (fa_device_mismatch) { + cparams.flash_attn = false; + LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); + } else { + cparams.flash_attn = true; + LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__); + } + + cparams.auto_fa = false; + } + + // reserve worst-case graph + int n_splits_pp = -1; + int n_nodes_pp = -1; + + int n_splits_tg = -1; + int n_nodes_tg = -1; + + // reserve pp (prompt processing) graph first so that buffers are only allocated once + { + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), + model.hparams.no_alloc, model.hparams.no_alloc ? 
backend_buf_exp_size.data() : nullptr); + if (!gf) { + if (cparams.pipeline_parallel) { + LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); + cparams.pipeline_parallel = false; + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); + gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); + } + if (!gf) { + throw std::runtime_error("failed to allocate compute pp buffers"); + } + } + + n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); + n_nodes_pp = ggml_graph_n_nodes(gf); + } + + // reserve with tg (token generation) graph to get the number of splits and nodes + { + auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc); + if (!gf) { + throw std::runtime_error("failed to allocate compute tg buffers"); + } + + n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); + n_nodes_tg = ggml_graph_n_nodes(gf); + } + + // reserve again with pp graph to avoid ggml-alloc reallocations during inference + { + // TODO: not sure if the following graph would be worster case for multi-stream KV caches: + // + // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); + // + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc); + if (!gf) { + throw std::runtime_error("failed to allocate compute pp buffers"); + } + } + + for (size_t i = 0; i < backend_ptrs.size(); ++i) { + ggml_backend_t backend = backend_ptrs[i]; + ggml_backend_buffer_type_t buft = backend_buft[i]; + if (!model.hparams.no_alloc) { + backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend); + } + if (backend_buf_exp_size[i] > 1) { + LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, + ggml_backend_buft_name(buft), + backend_buf_exp_size[i] / 1024.0 / 1024.0); + } + } + + if (n_nodes_pp == n_nodes_tg) { + LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); + } else { + LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); + } + + if (n_splits_pp == n_splits_tg) { + LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); + } else { + LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); + } + + const int64_t t_end_us = ggml_time_us(); + + LLAMA_LOG_INFO("%s: reserve took %.2f ms\n", __func__, (t_end_us - t_start_us)/1000.0); +} + void llama_context::synchronize() { + if (!sched) { + return; + } + ggml_backend_sched_synchronize(sched.get()); // FIXME: if multiple single tokens are evaluated without a synchronization, @@ -951,21 +985,40 @@ void llama_context::set_embeddings(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); cparams.embeddings = value; + + // TODO: not sure yet if we want to reserve here + //sched_need_reserve = true; } void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + if (cparams.causal_attn == value) { + return; + } + cparams.causal_attn = value; + + sched_need_reserve = true; } void llama_context::set_warmup(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + if (cparams.warmup == value) { + return; + } + cparams.warmup = value; + + sched_need_reserve = true; } bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { + if (!sampler && sampling.samplers.count(seq_id) == 0) { + return true; + } + 
LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); const bool can_offload = @@ -985,12 +1038,18 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { sampling.samplers[seq_id] = sampler; + sched_need_reserve = true; + return true; } if (sampler && !can_offload) { LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id); + if (sampling.samplers.count(seq_id) > 0) { + sched_need_reserve = true; + } + sampling.samplers.erase(seq_id); return false; @@ -998,6 +1057,8 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { sampling.samplers.erase(seq_id); + sched_need_reserve = true; + return true; } @@ -1006,16 +1067,27 @@ void llama_context::set_adapter_lora( float scale) { LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale); + if (auto it = loras.find(adapter); it != loras.end()) { + if (it->second == scale) { + return; + } + } + loras[adapter] = scale; + + sched_need_reserve = true; } bool llama_context::rm_adapter_lora( llama_adapter_lora * adapter) { LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter); - auto pos = loras.find(adapter); - if (pos != loras.end()) { - loras.erase(pos); + auto it = loras.find(adapter); + if (it != loras.end()) { + loras.erase(it); + + sched_need_reserve = true; + return true; } @@ -1025,7 +1097,13 @@ bool llama_context::rm_adapter_lora( void llama_context::clear_adapter_lora() { LLAMA_LOG_DEBUG("%s: call\n", __func__); + if (loras.empty()) { + return; + } + loras.clear(); + + sched_need_reserve = true; } bool llama_context::apply_adapter_cvec( @@ -1036,6 +1114,8 @@ bool llama_context::apply_adapter_cvec( int32_t il_end) { LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end); + // TODO: should we reserve? + return cvec.apply(model, data, len, n_embd, il_start, il_end); } @@ -1138,6 +1218,8 @@ int llama_context::encode(const llama_batch & batch_inp) { // TODO: this clear of the buffer can easily be forgotten - need something better embd_seq.clear(); + sched_reserve(); + n_queued_tokens += n_tokens; // reserve output buffer @@ -1177,7 +1259,7 @@ int llama_context::encode(const llama_batch & batch_inp) { auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); // extract logits - if (logits && t_logits) { + if (logits && t_logits) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(logits != nullptr); @@ -1451,6 +1533,8 @@ int llama_context::decode(const llama_batch & batch_inp) { embd_seq.clear(); output_swaps.clear(); + sched_reserve(); + bool did_optimize = false; // handle any pending shifts/copies diff --git a/src/llama-context.h b/src/llama-context.h index b29edf4db2..86decc05fb 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -40,6 +40,14 @@ struct llama_context { ~llama_context(); + // reserve a new backend scheduler (if needed) + // for example, when: + // - changing loras + // - changing samplers + // - changing attention type + // - etc. 
+ void sched_reserve(); + void synchronize(); const llama_model & get_model() const; @@ -314,6 +322,8 @@ private: ggml_backend_sched_ptr sched; + bool sched_need_reserve = true; + ggml_backend_t backend_cpu = nullptr; std::vector backends; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index fcef8fa976..2da3bbd6f9 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -30,10 +30,12 @@ struct llama_cparams { bool causal_attn; bool offload_kqv; bool flash_attn; + bool auto_fa; bool no_perf; bool warmup; bool op_offload; bool kv_unified; + bool pipeline_parallel; enum llama_pooling_type pooling_type; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index af6e053424..d968a94a81 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -45,26 +45,6 @@ enum server_state { SERVER_STATE_READY, // Server is ready and model is loaded }; -static bool server_task_type_need_embd(server_task_type task_type) { - switch (task_type) { - case SERVER_TASK_TYPE_EMBEDDING: - case SERVER_TASK_TYPE_RERANK: - return true; - default: - return false; - } -} - -static bool server_task_type_need_logits(server_task_type task_type) { - switch (task_type) { - case SERVER_TASK_TYPE_COMPLETION: - case SERVER_TASK_TYPE_INFILL: - return true; - default: - return false; - } -} - struct server_slot { int id; @@ -147,6 +127,17 @@ struct server_slot { return res; } + void prompt_clear(bool allow_processing) { + if (!allow_processing) { + GGML_ASSERT(!is_processing()); + } + + SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size()); + + llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1); + prompt.tokens.clear(); + } + std::vector lora; int32_t alora_invocation_start = -1; @@ -196,30 +187,24 @@ struct server_slot { n_draft_total = 0; n_draft_accepted = 0; + task_prev = std::move(task); task.reset(); - task_prev.reset(); + + llama_set_sampler(ctx, id, nullptr); // clear alora start alora_invocation_start = -1; } - // remove cached prompt + tokens - void clear(bool allow_processing) { - if (!allow_processing) { - GGML_ASSERT(!is_processing()); + void init_sampler() const { + common_sampler_reset(smpl.get()); + + if (!task->need_sampling()) { + return; } - SLT_INF(*this, "clearing slot with %zu tokens\n", prompt.tokens.size()); - - llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1); - prompt.tokens.clear(); - } - - void init_sampler() const { const int64_t t_start = ggml_time_us(); - common_sampler_reset(smpl.get()); - int n_text = 0; for (int i = 0; i < (int) prompt.tokens.size(); i++) { @@ -235,25 +220,13 @@ struct server_slot { (ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size()); } - // TODO: move to server_task - bool need_embd() const { - GGML_ASSERT(task); - - return server_task_type_need_embd(task->type); - } - - // TODO: move to server_task - bool need_logits() const { - GGML_ASSERT(task); - - return server_task_type_need_logits(task->type); - } - // if the context does not have a memory module then all embeddings have to be computed within a single ubatch // also we cannot split if the pooling would require any past tokens bool can_split() const { + GGML_ASSERT(task); + return - !need_embd() || + !task->need_embd() || (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); } @@ -349,11 +322,10 @@ struct server_slot { // do not keep context of the child slots - the parent's context is enough if (is_child()) { - clear(false); + prompt_clear(false); } - task_prev = std::move(task); - 
task.reset(); + reset(); callback_on_release(id); } @@ -801,6 +773,7 @@ private: slots.clear(); + // initialize slots for (int i = 0; i < params_base.n_parallel; i++) { server_slot slot; @@ -1049,7 +1022,7 @@ private: ret->prompt_save(*prompt_cache); if (!ret->prompt_load(*prompt_cache, task.tokens)) { - ret->clear(false); + ret->prompt_clear(false); } prompt_cache->update(); @@ -1081,7 +1054,7 @@ private: if (slot.prompt.n_tokens() > 0) { SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size()); - slot.clear(false); + slot.prompt_clear(false); res = true; @@ -1107,8 +1080,6 @@ private: } bool launch_slot_with_task(server_slot & slot, server_task && task) { - slot.reset(); - // process per-request lora adapters if (!task.params.lora.empty()) { auto task_loras = construct_lora_list(task.params.lora); @@ -1182,7 +1153,7 @@ private: SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); // initialize samplers - { + if (task.need_sampling()) { slot.smpl.reset(common_sampler_init(model, task.params.sampling)); if (slot.smpl == nullptr) { @@ -1211,6 +1182,8 @@ private: } SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str()); + } else { + slot.smpl.reset(); } // initialize draft batch @@ -1864,7 +1837,7 @@ private: // Erase token cache const size_t n_erased = slot->prompt.tokens.size(); - slot->clear(false); + slot->prompt_clear(false); auto res = std::make_unique(); res->id = task.id; @@ -2161,7 +2134,7 @@ private: } // TODO: support memory-less logits computation - if (slot.need_logits() && !llama_get_memory(ctx)) { + if (slot.task->need_logits() && !llama_get_memory(ctx)) { send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER); slot.release(); continue; @@ -2421,7 +2394,7 @@ private: if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) { SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0); - slot.clear(true); + slot.prompt_clear(true); // there is no common part left slot.n_prompt_tokens_cache = 0; @@ -2500,7 +2473,7 @@ private: cur_tok, slot.prompt.tokens.pos_next(), { slot.id }, - slot.need_embd()); + slot.task->need_embd()); slot.prompt.tokens.push_back(cur_tok); slot.n_prompt_tokens_processed++; @@ -2590,7 +2563,7 @@ private: slot_batched->lora[alora_disabled_id].scale = alora_scale; } - llama_set_embeddings(ctx, slot_batched->need_embd()); + llama_set_embeddings(ctx, slot_batched->task->need_embd()); } if (batch.n_tokens == 0) { @@ -2648,7 +2621,7 @@ private: // note: it's complicated to keep track of how much of the current batch has been // processed before the error occurred, so we simply clear the entire context - slot.clear(false); + slot.prompt_clear(false); } } @@ -2727,6 +2700,8 @@ private: continue; // continue loop of slots } + GGML_ASSERT(slot.task->need_sampling()); + // prompt evaluated for next-token prediction slot.state = SLOT_STATE_GENERATING; } else if (slot.state != SLOT_STATE_GENERATING) { diff --git a/tools/server/server-task.h b/tools/server/server-task.h index cf08fced63..97bae920d6 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -156,6 +156,36 @@ struct server_task { return tokens.size(); } + bool need_embd() const { + switch (type) { + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: + return true; + default: + return false; + } + } + + bool need_logits() const { + switch (type) { + case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: 
+ return true; + default: + return false; + } + } + + bool need_sampling() const { + switch (type) { + case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: + return true; + default: + return false; + } + } + static task_params params_from_json_cmpl( const llama_vocab * vocab, const common_params & params_base,
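
For readers following the llama-context.cpp changes above, the core idea of the patch is a deferred-reservation flag: setters that can change graph topology (samplers, LoRA adapters, causal/warmup attention flags) only mark the scheduler as stale, and the expensive `ggml_backend_sched` rebuild happens lazily in `sched_reserve()` at the top of the next `encode()`/`decode()` call, and only when something actually changed. Below is a minimal standalone sketch of that pattern, not the actual llama.cpp implementation; the `Scheduler`/`Context` names, the node counts, and the timing are illustrative placeholders.

```cpp
#include <chrono>
#include <cstdio>
#include <memory>

// Stand-in for the backend scheduler; in llama.cpp this is ggml_backend_sched.
struct Scheduler {
    int graph_nodes = 0;
};

// Sketch of the deferred-reservation pattern: mutating setters only flip a
// "need reserve" flag; the rebuild runs lazily at the start of the next decode.
class Context {
public:
    void set_causal_attn(bool value) {
        if (causal_attn == value) {
            return;                 // no change -> no re-reservation
        }
        causal_attn = value;
        need_reserve = true;        // topology changed -> rebuild on next use
    }

    void set_sampler(int seq_id, bool attach) {
        // attaching or detaching a backend sampler changes the graph topology
        (void) seq_id;
        sampler_attached = attach;
        need_reserve = true;
    }

    int decode() {
        reserve();                  // rebuild the scheduler only if needed
        // ... build and execute the compute graph ...
        return 0;
    }

private:
    void reserve() {
        if (!need_reserve) {
            return;
        }
        need_reserve = false;

        const auto t0 = std::chrono::steady_clock::now();

        // in llama.cpp this recreates the ggml_backend_sched and re-runs the
        // worst-case graph reservations (pp and tg graphs)
        sched = std::make_unique<Scheduler>();
        sched->graph_nodes = causal_attn ? 1000 : 800;  // placeholder numbers

        const auto t1 = std::chrono::steady_clock::now();
        std::printf("reserve took %.2f ms\n",
                    std::chrono::duration<double, std::milli>(t1 - t0).count());
    }

    std::unique_ptr<Scheduler> sched;
    bool need_reserve     = true;   // first decode always reserves
    bool causal_attn      = true;
    bool sampler_attached = false;
};

int main() {
    Context ctx;
    ctx.decode();                   // reserves once
    ctx.decode();                   // unchanged state -> no reserve
    ctx.set_sampler(0, true);       // marks the scheduler stale
    ctx.decode();                   // reserves again
}
```

The patch applies the same "skip if unchanged" guard in every setter (`set_causal_attn`, `set_warmup`, `set_sampler`, the LoRA helpers), so repeated calls with identical values never trigger a rebuild.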
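
A hedged caller-side sketch of the `llama_set_sampler` entry point declared in include/llama.h above. It assumes the public sampler-chain helpers from llama.h (`llama_sampler_chain_init`, `llama_sampler_chain_add`, `llama_sampler_init_greedy`); model/context setup, sampler ownership, and error handling are out of scope. Per the header comment, prefer passing samplers via `llama_context_params.samplers` at context creation, since changing samplers afterwards can trigger a scheduler re-reservation on the next decode.

```cpp
// Sketch only: attach a per-sequence backend sampler after context creation,
// then detach it again.
#include "llama.h"
#include <cstdio>

static void attach_greedy_sampler(llama_context * ctx, llama_seq_id seq_id) {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_greedy());

    // returns false if the chain cannot be offloaded to the backend
    if (!llama_set_sampler(ctx, seq_id, chain)) {
        std::fprintf(stderr, "sampler for seq %d stays on the CPU path\n", seq_id);
    }

    // the context only marks its scheduler as stale here; the actual
    // re-reservation happens at the start of the next llama_decode() call
}

static void detach_sampler(llama_context * ctx, llama_seq_id seq_id) {
    // per the patch, unsetting a sampler that was never set is a no-op
    // and does not mark the scheduler for re-reservation
    llama_set_sampler(ctx, seq_id, nullptr);
}
```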