server: improve slots scheduling for n_cmpl (#18789)

* server : make sure child tasks are scheduled to launch with the parent

* fix

* add comment pointing to this PR

* fix

* clean up

* more debug messages

* add pop_deferred_task with specific ID version

* improve the logic

* simple approach

* no double move

* correct return type of launch_slots_with_parent_task
Xuan-Son Nguyen 2026-01-15 17:10:28 +01:00 committed by GitHub
parent 39173bcacb
commit a04c2b06a3
5 changed files with 194 additions and 103 deletions
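In short: with n_cmpl > 1, a request now produces a parent task that carries its child tasks, and the scheduler launches the whole group only when a slot is free for each member, deferring the group as a single unit otherwise. A minimal sketch of that contract, using hypothetical Task/Slot types rather than the real server structures:

// all-or-nothing scheduling: a parent task and its children either all get
// a slot in the same scheduling step, or the whole group is deferred
#include <cstdio>
#include <vector>

struct Task { int id; int n_children; };
struct Slot { int id; bool busy = false; };

bool try_launch(std::vector<Slot> & slots, const Task & t) {
    std::vector<Slot *> free_slots;
    for (auto & s : slots) {
        if (!s.busy) {
            free_slots.push_back(&s);
        }
    }
    const size_t needed = 1 + (size_t) t.n_children; // parent + children
    if (free_slots.size() < needed) {
        std::printf("defer task %d: need %zu slots, have %zu\n", t.id, needed, free_slots.size());
        return false; // caller keeps the task in the deferred queue as one unit
    }
    for (size_t i = 0; i < needed; i++) {
        free_slots[i]->busy = true; // claim a slot for the parent and each child
    }
    return true;
}

int main() {
    std::vector<Slot> slots = { {0}, {1}, {2} };
    Task t { 100, /*n_children=*/3 }; // n_cmpl = 4 needs 4 slots
    try_launch(slots, t);             // only 3 slots exist -> defer
}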

View File

@@ -158,7 +158,7 @@ struct server_slot {
double t_prompt_processing; // ms
double t_token_generation; // ms
- std::function<void(int)> callback_on_release;
+ std::function<void(int /* slot_id */)> callback_on_release;
// Speculative decoding stats
int32_t n_draft_total = 0; // Total draft tokens generated
@@ -298,17 +298,6 @@ struct server_slot {
return n_draft_max;
}
- // note: a slot can also be either a parent or a child
- // TODO: move to server_task
- bool is_parent() const {
- return task->n_children > 0;
- }
- // TODO: move to server_task
- bool is_child() const {
- return task->id_parent >= 0;
- }
void release() {
if (is_processing()) {
GGML_ASSERT(task);
@@ -321,7 +310,7 @@
state = SLOT_STATE_IDLE;
// do not keep context of the child slots - the parent's context is enough
- if (is_child()) {
+ if (task->is_child()) {
prompt_clear(false);
}
@@ -805,8 +794,8 @@ private:
SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
- slot.callback_on_release = [this](int) {
- queue_tasks.pop_deferred_task();
+ slot.callback_on_release = [this](int slot_id) {
+ queue_tasks.pop_deferred_task(slot_id);
};
slot.reset();
@@ -920,9 +909,9 @@ private:
return true;
}
- server_slot * get_slot_by_id(int id) {
+ server_slot * get_slot_by_id(int id_slot) {
for (server_slot & slot : slots) {
- if (slot.id == id) {
+ if (slot.id == id_slot) {
return &slot;
}
}
@@ -1196,12 +1185,11 @@ private:
slot.task = std::make_unique<const server_task>(std::move(task));
- slot.state = slot.is_child()
+ slot.state = slot.task->is_child()
? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
: SLOT_STATE_STARTED;
- SLT_INF(slot, "processing task, is_child = %d\n", slot.is_child());
+ SLT_INF(slot, "processing task, is_child = %d\n", slot.task->is_child());
return true;
}
@@ -1596,9 +1584,7 @@ private:
// tokenize the input if it's set by CLI, return false on error
bool tokenize_cli_input(server_task & task) {
- if (task.cli_input == nullptr) {
- return true; // nothing to do
- }
+ GGML_ASSERT(task.cli_input != nullptr);
try {
auto & opt = oai_parser_opt;
common_chat_templates_inputs inputs;
@@ -1632,6 +1618,64 @@ private:
return true;
}
+ std::vector<server_slot *> get_free_slots(size_t n_slots_needed, int exclude_id_slot) {
+ std::vector<server_slot *> free_slots;
+ for (auto & slot : slots) {
+ if (!slot.is_processing() && slot.id != exclude_id_slot) {
+ free_slots.push_back(&slot);
+ }
+ if (free_slots.size() >= n_slots_needed) {
+ break;
+ }
+ }
+ return free_slots;
+ }
+ // launch multiple slots for parent + child tasks
+ bool launch_slots_with_parent_task(server_slot & parent_slot, std::vector<server_slot *> & child_slots, server_task && parent_task) {
+ GGML_ASSERT(!parent_slot.is_processing());
+ GGML_ASSERT(parent_task.is_parent());
+ GGML_ASSERT(child_slots.size() == parent_task.child_tasks.size());
+ int id_parent = parent_task.id;
+ SRV_INF("launching slots for parent task id_task = %d with %zu child tasks\n", id_parent, parent_task.child_tasks.size());
+ // to be called in case of failure, to release all launched slots
+ auto release_slots = [this, id_parent]() {
+ for (auto & slot : slots) {
+ if (slot.is_processing() && (
+ slot.task->id == id_parent ||
+ slot.task->id_parent == id_parent
+ )) {
+ slot.release();
+ }
+ }
+ };
+ // launch all child tasks first
+ size_t idx = 0;
+ GGML_ASSERT(child_slots.size() == parent_task.child_tasks.size());
+ for (auto * slot : child_slots) {
+ int id_child = parent_task.child_tasks[idx].id;
+ if (!launch_slot_with_task(*slot, std::move(parent_task.child_tasks[idx]))) {
+ SRV_ERR("failed to launch slot with child task, id_task = %d\n", id_child);
+ release_slots();
+ return false;
+ }
+ idx++;
+ }
+ // finally, launch the parent task
+ if (!launch_slot_with_task(parent_slot, std::move(parent_task))) {
+ SRV_ERR("failed to launch slot with task, id_task = %d\n", id_parent);
+ release_slots();
+ return false;
+ }
+ return true;
+ }
void process_single_task(server_task && task) {
switch (task.type) {
case SERVER_TASK_TYPE_COMPLETION:
@@ -1639,31 +1683,55 @@ private:
case SERVER_TASK_TYPE_EMBEDDING:
case SERVER_TASK_TYPE_RERANK:
{
- if (!tokenize_cli_input(task)) {
- break;
- }
+ // special case: if input is provided via CLI, tokenize it first
+ // otherwise, no need to tokenize as it's already done inside the HTTP thread
+ if (task.cli_input != nullptr) {
+ if (!tokenize_cli_input(task)) {
+ break;
+ }
+ }
const int id_slot = task.id_slot;
+ const int id_task = task.id;
- server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
+ server_slot * slot = id_slot != -1
+ ? get_slot_by_id(id_slot)
+ : get_available_slot(task);
+ //
+ // slot scheduling logic
+ //
if (slot == nullptr) {
// if no slot is available, we defer this task for processing later
SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
SRV_DBG("no slot is available, defer task, id_task = %d\n", id_task);
queue_tasks.defer(std::move(task));
break;
}
if (slot->is_processing()) {
// if requested slot is unavailable, we defer this task for processing later
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", id_task);
queue_tasks.defer(std::move(task));
break;
}
if (!launch_slot_with_task(*slot, std::move(task))) {
SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
break;
if (task.is_parent()) {
// try getting free slots for all child tasks
size_t n_child_tasks = task.child_tasks.size();
std::vector<server_slot *> child_slots = get_free_slots(n_child_tasks, slot->id);
if (child_slots.size() < n_child_tasks) {
SRV_DBG("not enough free slots for child tasks, n_free = %zu, n_children = %zu, defer task, id_task = %d\n", child_slots.size(), n_child_tasks, id_task);
queue_tasks.defer(std::move(task));
break;
}
if (!launch_slots_with_parent_task(*slot, child_slots, std::move(task))) {
SRV_ERR("failed to launch slot with parent task, id_task = %d\n", id_task);
break; // drop the task
}
} else if (!launch_slot_with_task(*slot, std::move(task))) {
SRV_ERR("failed to launch slot with task, id_task = %d\n", id_task);
break; // drop the task
}
} break;
case SERVER_TASK_TYPE_CANCEL:
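Aside: launch_slots_with_parent_task() above is all-or-nothing: children are launched first, then the parent, and release_slots() rolls back every slot already claimed if any step fails. A self-contained sketch of that rollback shape, with generic std::function stand-ins rather than the server API:

#include <cstdio>
#include <functional>
#include <vector>

// run each launch step in order; on the first failure, invoke the rollback
// to release everything claimed so far (mirrors the release_slots() lambda)
bool launch_group(const std::vector<std::function<bool()>> & launches,
                  const std::function<void()> & release_all) {
    for (const auto & launch : launches) {
        if (!launch()) {
            release_all();
            return false;
        }
    }
    return true;
}

int main() {
    int claimed = 0;
    std::vector<std::function<bool()>> launches = {
        [&] { claimed++; return true; },
        [&] { return false; }, // second launch fails
    };
    launch_group(launches, [&] { claimed = 0; });
    std::printf("claimed after rollback: %d\n", claimed); // 0
}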
@@ -1932,7 +2000,7 @@ private:
GGML_ABORT("not supported by multimodal");
}
- if (slot.is_parent() || slot.is_child()) {
+ if (slot.task->is_parent() || slot.task->is_child()) {
send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER);
slot.release();
continue;
@@ -2079,21 +2147,6 @@ private:
// this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
- // wait for all children to be launched
- if (slot.is_parent()) {
- int n_launched = 0;
- for (auto & other : slots) {
- if (other.is_processing() && other.is_child() && other.task->id_parent == slot.task->id) {
- ++n_launched;
- }
- }
- if (n_launched < slot.task->n_children) {
- SLT_DBG(slot, "waiting for children to be launched, n_children = %d, n_launched = %d\n", slot.task->n_children, n_launched);
- continue;
- }
- }
const auto & input_tokens = slot.task->tokens;
// TODO: maybe move branch to outside of this loop in the future
@@ -2647,9 +2700,7 @@ private:
// handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too
for (auto & slot : slots) {
- if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
- SLT_INF(slot, "parent task prompt done, n_children = %d\n", slot.task->n_children);
+ if (slot.state == SLOT_STATE_DONE_PROMPT && slot.task->is_parent()) {
std::vector<server_slot *> children;
for (auto & other : slots) {
if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
@@ -2657,17 +2708,15 @@
}
}
- // we can only proceed if all child slots are having the correct tasks
- if (slot.task->n_children == (int) children.size()) {
- // copy state to the child slots
- for (auto & child : children) {
- SLT_INF(slot, " - copying state to child %d\n", child->id);
- GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER);
- slot.copy_state_to(*child);
- child->state = SLOT_STATE_DONE_PROMPT;
- }
- }
+ // all child slots should already be launched by launch_slots_with_parent_task()
+ // copy state to the child slots
+ for (auto & child : children) {
+ SLT_INF(slot, " - copying state to child %d\n", child->id);
+ GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER);
+ slot.copy_state_to(*child);
+ child->state = SLOT_STATE_DONE_PROMPT;
+ }
}
}
@@ -2943,7 +2992,9 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
// Everything else, including multimodal completions.
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
}
- tasks.reserve(inputs.size());
+ // tasks.reserve(inputs.size()); // TODO: this is inaccurate due to child tasks
for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(type);
@@ -2964,23 +3015,13 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
// prepare child tasks
if (task.params.n_cmpl > 1) {
- task.n_children = task.params.n_cmpl - 1;
- for (int j = 0; j < task.n_children; j++) {
- server_task child = task.create_child(task.id, rd.get_new_id());
- // use different sampling seed for each child
- // note: https://github.com/ggml-org/llama.cpp/pull/18700#discussion_r2675115723
- if (child.params.sampling.seed != LLAMA_DEFAULT_SEED) {
- child.params.sampling.seed += j + 1;
- }
- tasks.push_back(std::move(child));
+ int n_children = task.params.n_cmpl - 1;
+ for (int j = 0; j < n_children; j++) {
+ task.add_child(task.id, rd.get_new_id());
}
}
- // note: the parent task always launches first
- tasks.insert(tasks.begin(), std::move(task));
+ tasks.push_back(std::move(task));
}
rd.post_tasks(std::move(tasks));
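The hand-off in the last hunk relies on the group launch: every child is guaranteed to sit in SLOT_STATE_WAIT_OTHER by the time its parent reaches SLOT_STATE_DONE_PROMPT, so the old n_children bookkeeping can go. A simplified, runnable model of that promotion step (stand-in slot struct; the real code also copies the prompt state via copy_state_to()):

#include <cassert>
#include <vector>

// enum values mirror the patch; slot is a simplified stand-in
enum slot_state { SLOT_STATE_IDLE, SLOT_STATE_WAIT_OTHER, SLOT_STATE_DONE_PROMPT };

struct slot {
    int id = -1;
    int id_parent = -1; // -1 means "not a child"
    slot_state state = SLOT_STATE_IDLE;
};

// once the parent finished its shared prompt, promote every waiting child
void promote_children(std::vector<slot> & slots, const slot & parent) {
    assert(parent.state == SLOT_STATE_DONE_PROMPT);
    for (auto & child : slots) {
        if (child.id_parent == parent.id) {
            // guaranteed because parent + children were launched together
            assert(child.state == SLOT_STATE_WAIT_OTHER);
            child.state = SLOT_STATE_DONE_PROMPT;
        }
    }
}

int main() {
    std::vector<slot> slots = {
        { 1, -1, SLOT_STATE_DONE_PROMPT }, // parent
        { 2,  1, SLOT_STATE_WAIT_OTHER },  // child of slot 1
    };
    promote_children(slots, slots[0]);
    assert(slots[1].state == SLOT_STATE_DONE_PROMPT);
}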

View File

@@ -74,11 +74,26 @@ int server_queue::get_new_id() {
return new_id;
}
- void server_queue::pop_deferred_task() {
+ void server_queue::pop_deferred_task(int id_slot) {
std::unique_lock<std::mutex> lock(mutex_tasks);
if (!queue_tasks_deferred.empty()) {
- queue_tasks.emplace_front(std::move(queue_tasks_deferred.front()));
- queue_tasks_deferred.pop_front();
+ // try to find a task that uses the specified slot
+ bool found = false;
+ for (auto it = queue_tasks_deferred.begin(); it != queue_tasks_deferred.end(); ++it) {
+ if (it->id_slot == id_slot) {
+ QUE_DBG("pop deferred task (use slot %d), id_task = %d\n", id_slot, it->id);
+ queue_tasks.emplace_front(std::move(*it));
+ queue_tasks_deferred.erase(it);
+ found = true;
+ break;
+ }
+ }
+ // if no task uses the slot, just pop the first deferred task (default behavior)
+ if (!found) {
+ QUE_DBG("pop deferred task, id_task = %d\n", queue_tasks_deferred.front().id);
+ queue_tasks.emplace_front(std::move(queue_tasks_deferred.front()));
+ queue_tasks_deferred.pop_front();
+ }
}
time_last_task = ggml_time_ms();
condition_tasks.notify_one();
@@ -217,12 +232,12 @@ void server_response::add_waiting_task_id(int id_task) {
waiting_task_ids.insert(id_task);
}
- void server_response::add_waiting_tasks(const std::vector<server_task> & tasks) {
+ void server_response::add_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
std::unique_lock<std::mutex> lock(mutex_results);
- for (const auto & task : tasks) {
- RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size());
- waiting_task_ids.insert(task.id);
+ for (const auto & id_task : id_tasks) {
+ RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
+ waiting_task_ids.insert(id_task);
}
}
@@ -327,6 +342,7 @@ void server_response::terminate() {
void server_response_reader::post_task(server_task && task, bool front) {
GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
+ GGML_ASSERT(!task.is_parent() && "not supported, use post_tasks() instead");
task.index = 0;
id_tasks.insert(task.id);
states.push_back(task.create_state());
@@ -338,11 +354,18 @@ void server_response_reader::post_tasks(std::vector<server_task> && tasks, bool
GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
id_tasks = server_task::get_list_id(tasks);
states.reserve(tasks.size());
- for (size_t i = 0; i < tasks.size(); i++) {
- tasks[i].index = i;
- states.push_back(tasks[i].create_state());
+ size_t index = 0;
+ for (auto & task : tasks) {
+ task.index = index++;
+ states.push_back(task.create_state());
+ // for child tasks
+ for (auto & child_task : task.child_tasks) {
+ child_task.index = index++;
+ states.push_back(child_task.create_state());
+ }
}
- queue_results.add_waiting_tasks(tasks);
+ GGML_ASSERT(states.size() == id_tasks.size());
+ queue_results.add_waiting_task_ids(id_tasks);
queue_tasks.post(std::move(tasks), front);
}
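The slot-aware pop above is what keeps a parent task pinned to a specific id_slot from locking up the scheduler: when that slot is released, its deferred task jumps ahead of the FIFO order. A minimal standalone version (simplified task type; the real method also holds the queue mutex and notifies waiters):

#include <cstdio>
#include <deque>

struct task { int id; int id_slot = -1; }; // -1 = no slot requested

// prefer a deferred task pinned to the slot that was just released,
// otherwise fall back to plain FIFO order; caller ensures non-empty
task pop_deferred(std::deque<task> & deferred, int id_slot) {
    for (auto it = deferred.begin(); it != deferred.end(); ++it) {
        if (it->id_slot == id_slot) {
            task t = *it;
            deferred.erase(it);
            return t;
        }
    }
    task t = deferred.front();
    deferred.pop_front();
    return t;
}

int main() {
    std::deque<task> deferred = { {10, -1}, {11, 0} };
    task t = pop_deferred(deferred, /*id_slot=*/0);
    std::printf("popped id_task = %d\n", t.id); // 11, not 10
}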

View File

@@ -44,7 +44,8 @@ public:
int get_new_id();
// Call when the state of one slot is changed, it will move one task from deferred to main queue
- void pop_deferred_task();
+ // prioritize tasks that use the specified slot (otherwise, pop the first deferred task)
+ void pop_deferred_task(int id_slot);
// if sleeping, request exiting sleep state and wait until it is done
// returns immediately if not sleeping
@@ -124,7 +125,7 @@ public:
// add the id_task to the list of tasks waiting for response
void add_waiting_task_id(int id_task);
- void add_waiting_tasks(const std::vector<server_task> & tasks);
+ void add_waiting_task_ids(const std::unordered_set<int> & id_tasks);
// when the request is finished, we can remove task associated with it
void remove_waiting_task_id(int id_task);

View File

@@ -121,8 +121,10 @@ struct server_task {
int id_slot = -1;
// used by parallel sampling (multiple completions from same prompt)
- int n_children = 0; // number of tasks reusing this prompt
int id_parent = -1;
+ // temporary store of child tasks for scheduling
+ // note: accessing elements is invalid after the task is moved to server_slot
+ std::vector<server_task> child_tasks;
// used by SERVER_TASK_TYPE_INFERENCE
task_params params;
@@ -197,11 +199,14 @@
std::unordered_set<int> ids(tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
ids.insert(tasks[i].id);
+ for (auto & child : tasks[i].child_tasks) {
+ ids.insert(child.id);
+ }
}
return ids;
}
- server_task create_child(int id_parent, int id_child) const {
+ void add_child(int id_parent, int id_child) {
server_task copy;
copy.id = id_child;
@@ -209,8 +214,15 @@
copy.params = params;
copy.type = type;
copy.tokens = tokens.clone();
copy.id_slot = -1; // child tasks cannot specify slot
- return copy;
+ // use different sampling seed for each child
+ // note: https://github.com/ggml-org/llama.cpp/pull/18700#discussion_r2675115723
+ if (copy.params.sampling.seed != LLAMA_DEFAULT_SEED) {
+ copy.params.sampling.seed += (uint32_t)child_tasks.size() + 1;
+ }
+ child_tasks.push_back(std::move(copy));
}
// the task will be moved into queue, then onto slots
@@ -218,6 +230,14 @@
task_result_state create_state() const {
return task_result_state(params.oaicompat_chat_syntax);
}
+ bool is_parent() const {
+ return child_tasks.size() > 0;
+ }
+ bool is_child() const {
+ return id_parent != -1;
+ }
};
struct result_timings {
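The seed rule in add_child() gives every child a distinct but reproducible sampling seed derived from the parent's, while the default (random) seed stays random for all children. A toy illustration; the LLAMA_DEFAULT_SEED value below is an assumption standing in for the constant in llama.h:

#include <cstdint>
#include <cstdio>

// assumed stand-in for the constant from llama.h ("use a random seed")
constexpr uint32_t LLAMA_DEFAULT_SEED = 0xFFFFFFFF;

uint32_t child_seed(uint32_t parent_seed, size_t n_existing_children) {
    if (parent_seed == LLAMA_DEFAULT_SEED) {
        return parent_seed; // every child keeps the random-seed behavior
    }
    // same rule as add_child(): first child gets seed+1, second seed+2, ...
    return parent_seed + (uint32_t) n_existing_children + 1;
}

int main() {
    for (size_t j = 0; j < 3; j++) {
        std::printf("child %zu: seed = %u\n", j, child_seed(42, j)); // 43, 44, 45
    }
}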

View File

@@ -491,16 +491,22 @@ def test_return_progress(n_batch, batch_count, reuse_cache):
def test_chat_completions_multiple_choices():
global server
server.start()
- res = server.make_request("POST", "/chat/completions", data={
- "max_tokens": 8,
- "n": 2,
- "messages": [
- {"role": "system", "content": "Book"},
- {"role": "user", "content": "What is the best book"},
- ],
- })
- assert res.status_code == 200
- assert len(res.body["choices"]) == 2
- for choice in res.body["choices"]:
- assert "assistant" == choice["message"]["role"]
- assert choice["finish_reason"] == "length"
+ # make sure cache can be reused across multiple choices and multiple requests
+ # ref: https://github.com/ggml-org/llama.cpp/pull/18663
+ for _ in range(2):
+ res = server.make_request("POST", "/chat/completions", data={
+ "max_tokens": 8,
+ "n": 2,
+ "messages": [
+ {"role": "system", "content": "Book"},
+ {"role": "user", "content": "What is the best book"},
+ ],
+ # test forcing the same slot to be used
+ # the scheduler should not be locked up in this case
+ "id_slot": 0,
+ })
+ assert res.status_code == 200
+ assert len(res.body["choices"]) == 2
+ for choice in res.body["choices"]:
+ assert "assistant" == choice["message"]["role"]
+ assert choice["finish_reason"] == "length"