From ffa0d15e869958d92500e777050ad27c56af9f03 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 14 Jan 2026 11:49:32 +0200
Subject: [PATCH] server : task declares needs (embd, logits, sampling)

---
 tools/server/server-context.cpp | 48 ++++++---------------------
 tools/server/server-task.h      | 31 ++++++++++++++++++---
 2 files changed, 36 insertions(+), 43 deletions(-)

diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 6520fa3ff5..9973572f62 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -45,26 +45,6 @@ enum server_state {
     SERVER_STATE_READY,  // Server is ready and model is loaded
 };
 
-static bool server_task_type_need_embd(server_task_type task_type) {
-    switch (task_type) {
-        case SERVER_TASK_TYPE_EMBEDDING:
-        case SERVER_TASK_TYPE_RERANK:
-            return true;
-        default:
-            return false;
-    }
-}
-
-static bool server_task_type_need_logits(server_task_type task_type) {
-    switch (task_type) {
-        case SERVER_TASK_TYPE_COMPLETION:
-        case SERVER_TASK_TYPE_INFILL:
-            return true;
-        default:
-            return false;
-    }
-}
-
 struct server_slot {
     int id;
 
@@ -235,25 +215,13 @@ struct server_slot {
                 (ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size());
     }
 
-    // TODO: move to server_task
-    bool need_embd() const {
-        GGML_ASSERT(task);
-
-        return server_task_type_need_embd(task->type);
-    }
-
-    // TODO: move to server_task
-    bool need_logits() const {
-        GGML_ASSERT(task);
-
-        return server_task_type_need_logits(task->type);
-    }
-
     // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
     // also we cannot split if the pooling would require any past tokens
     bool can_split() const {
+        GGML_ASSERT(task);
+
         return
-            !need_embd() ||
+            !task->need_embd() ||
             (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
     }
 
@@ -1182,7 +1150,7 @@ private:
         SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
 
         // initialize samplers
-        if (task.uses_sampling()) {
+        if (task.need_sampling()) {
             slot.smpl.reset(common_sampler_init(model, task.params.sampling));
 
             if (slot.smpl == nullptr) {
@@ -2163,7 +2131,7 @@ private:
             }
 
             // TODO: support memory-less logits computation
-            if (slot.need_logits() && !llama_get_memory(ctx)) {
+            if (slot.task->need_logits() && !llama_get_memory(ctx)) {
                 send_error(slot, "the current context does not support logits computation, skipping", ERROR_TYPE_SERVER);
                 slot.release();
                 continue;
@@ -2502,7 +2470,7 @@ private:
                         cur_tok,
                         slot.prompt.tokens.pos_next(),
                         { slot.id },
-                        slot.need_embd());
+                        slot.task->need_embd());
 
                     slot.prompt.tokens.push_back(cur_tok);
                     slot.n_prompt_tokens_processed++;
@@ -2592,7 +2560,7 @@ private:
                 slot_batched->lora[alora_disabled_id].scale = alora_scale;
             }
 
-            llama_set_embeddings(ctx, slot_batched->need_embd());
+            llama_set_embeddings(ctx, slot_batched->task->need_embd());
         }
 
         for (auto & slot : slots) {
@@ -2735,7 +2703,7 @@ private:
                 continue; // continue loop of slots
             }
 
-            GGML_ASSERT(slot.task->uses_sampling());
+            GGML_ASSERT(slot.task->need_sampling());
 
             // prompt evaluated for next-token prediction
             slot.state = SLOT_STATE_GENERATING;
diff --git a/tools/server/server-task.h b/tools/server/server-task.h
index 954495006e..97bae920d6 100644
--- a/tools/server/server-task.h
+++ b/tools/server/server-task.h
@@ -156,9 +156,34 @@ struct server_task {
         return tokens.size();
     }
 
-    bool uses_sampling() const {
-        return type != SERVER_TASK_TYPE_EMBEDDING &&
-               type != SERVER_TASK_TYPE_RERANK;
+    bool need_embd() const {
+        switch (type) {
+            case SERVER_TASK_TYPE_EMBEDDING:
+            case SERVER_TASK_TYPE_RERANK:
+                return true;
+            default:
+                return false;
+        }
+    }
+
+    bool need_logits() const {
+        switch (type) {
+            case SERVER_TASK_TYPE_COMPLETION:
+            case SERVER_TASK_TYPE_INFILL:
+                return true;
+            default:
+                return false;
+        }
+    }
+
+    bool need_sampling() const {
+        switch (type) {
+            case SERVER_TASK_TYPE_COMPLETION:
+            case SERVER_TASK_TYPE_INFILL:
+                return true;
+            default:
+                return false;
+        }
     }
 
     static task_params params_from_json_cmpl(