diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index d968a94a81..62b12b5068 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -158,7 +158,7 @@ struct server_slot {
     double t_prompt_processing; // ms
     double t_token_generation; // ms
 
-    std::function<void(int)> callback_on_release;
+    std::function<void(int)> callback_on_release;
 
     // Speculative decoding stats
     int32_t n_draft_total = 0; // Total draft tokens generated
@@ -298,17 +298,6 @@ struct server_slot {
         return n_draft_max;
     }
 
-    // note: a slot can also be either a parent or a child
-    // TODO: move to server_task
-    bool is_parent() const {
-        return task->n_children > 0;
-    }
-
-    // TODO: move to server_task
-    bool is_child() const {
-        return task->id_parent >= 0;
-    }
-
    void release() {
        if (is_processing()) {
            GGML_ASSERT(task);
@@ -321,7 +310,7 @@ struct server_slot {
        state = SLOT_STATE_IDLE;
 
        // do not keep context of the child slots - the parent's context is enough
-        if (is_child()) {
+        if (task->is_child()) {
            prompt_clear(false);
        }
 
@@ -805,8 +794,8 @@ private:
 
        SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
 
-        slot.callback_on_release = [this](int) {
-            queue_tasks.pop_deferred_task();
+        slot.callback_on_release = [this](int slot_id) {
+            queue_tasks.pop_deferred_task(slot_id);
        };
 
        slot.reset();
@@ -920,9 +909,9 @@ private:
        return true;
    }
 
-    server_slot * get_slot_by_id(int id) {
+    server_slot * get_slot_by_id(int id_slot) {
        for (server_slot & slot : slots) {
-            if (slot.id == id) {
+            if (slot.id == id_slot) {
                return &slot;
            }
        }
@@ -1196,12 +1185,11 @@ private:
 
        slot.task = std::make_unique<server_task>(std::move(task));
 
-        slot.state = slot.is_child()
+        slot.state = slot.task->is_child()
            ? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
            : SLOT_STATE_STARTED;
 
-        SLT_INF(slot, "processing task, is_child = %d\n", slot.is_child());
-
+        SLT_INF(slot, "processing task, is_child = %d\n", slot.task->is_child());
        return true;
    }
 
@@ -1596,9 +1584,7 @@ private:
 
    // tokenize the input if it's set by CLI, return false on error
    bool tokenize_cli_input(server_task & task) {
-        if (task.cli_input == nullptr) {
-            return true; // nothing to do
-        }
+        GGML_ASSERT(task.cli_input != nullptr);
 
        try {
            auto & opt = oai_parser_opt;
            common_chat_templates_inputs inputs;
@@ -1632,6 +1618,64 @@ private:
        return true;
    }
 
+    std::vector<server_slot *> get_free_slots(size_t n_slots_needed, int exclude_id_slot) {
+        std::vector<server_slot *> free_slots;
+        for (auto & slot : slots) {
+            if (!slot.is_processing() && slot.id != exclude_id_slot) {
+                free_slots.push_back(&slot);
+            }
+            if (free_slots.size() >= n_slots_needed) {
+                break;
+            }
+        }
+        return free_slots;
+    }
+
+    // launch multiple slots for parent + child tasks
+    bool launch_slots_with_parent_task(server_slot & parent_slot, std::vector<server_slot *> & child_slots, server_task && parent_task) {
+        GGML_ASSERT(!parent_slot.is_processing());
+        GGML_ASSERT(parent_task.is_parent());
+        GGML_ASSERT(child_slots.size() == parent_task.child_tasks.size());
+
+        int id_parent = parent_task.id;
+
+        SRV_INF("launching slots for parent task id_task = %d with %zu child tasks\n", id_parent, parent_task.child_tasks.size());
+
+        // helper to release all launched slots; to be called in case of failure
+        auto release_slots = [this, id_parent]() {
+            for (auto & slot : slots) {
+                if (slot.is_processing() && (
+                    slot.task->id == id_parent ||
+                    slot.task->id_parent == id_parent
+                )) {
+                    slot.release();
+                }
+            }
+        };
+
+        // launch all child tasks first
+        size_t idx = 0;
+        GGML_ASSERT(child_slots.size() == parent_task.child_tasks.size());
+        for (auto * slot : child_slots) {
+            int id_child = parent_task.child_tasks[idx].id;
+            if (!launch_slot_with_task(*slot, std::move(parent_task.child_tasks[idx]))) {
+                SRV_ERR("failed to launch slot with child task, id_task = %d\n", id_child);
+                release_slots();
+                return false;
+            }
+            idx++;
+        }
+
+        // finally, launch the parent task
+        if (!launch_slot_with_task(parent_slot, std::move(parent_task))) {
+            SRV_ERR("failed to launch slot with task, id_task = %d\n", id_parent);
+            release_slots();
+            return false;
+        }
+
+        return true;
+    }
+
    void process_single_task(server_task && task) {
        switch (task.type) {
            case SERVER_TASK_TYPE_COMPLETION:
@@ -1639,31 +1683,55 @@ private:
            case SERVER_TASK_TYPE_EMBEDDING:
            case SERVER_TASK_TYPE_RERANK:
                {
-                    if (!tokenize_cli_input(task)) {
-                        break;
+                    // special case: if input is provided via CLI, tokenize it first
+                    // otherwise, no need to tokenize as it's already done inside the HTTP thread
+                    if (task.cli_input != nullptr) {
+                        if (!tokenize_cli_input(task)) {
+                            break;
+                        }
                    }
 
                    const int id_slot = task.id_slot;
+                    const int id_task = task.id;
 
-                    server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
+                    server_slot * slot = id_slot != -1
+                        ? get_slot_by_id(id_slot)
+                        : get_available_slot(task);
+
+                    //
+                    // slot scheduling logic
+                    //
 
                    if (slot == nullptr) {
                        // if no slot is available, we defer this task for processing later
-                        SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
+                        SRV_DBG("no slot is available, defer task, id_task = %d\n", id_task);
                        queue_tasks.defer(std::move(task));
                        break;
                    }
 
                    if (slot->is_processing()) {
                        // if requested slot is unavailable, we defer this task for processing later
-                        SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
+                        SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", id_task);
                        queue_tasks.defer(std::move(task));
                        break;
                    }
 
-                    if (!launch_slot_with_task(*slot, std::move(task))) {
-                        SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
-                        break;
+                    if (task.is_parent()) {
+                        // try getting free slots for all child tasks
+                        size_t n_child_tasks = task.child_tasks.size();
+                        std::vector<server_slot *> child_slots = get_free_slots(n_child_tasks, slot->id);
+                        if (child_slots.size() < n_child_tasks) {
+                            SRV_DBG("not enough free slots for child tasks, n_free = %zu, n_children = %zu, defer task, id_task = %d\n", child_slots.size(), n_child_tasks, id_task);
+                            queue_tasks.defer(std::move(task));
+                            break;
+                        }
+                        if (!launch_slots_with_parent_task(*slot, child_slots, std::move(task))) {
+                            SRV_ERR("failed to launch slot with parent task, id_task = %d\n", id_task);
+                            break; // drop the task
+                        }
+                    } else if (!launch_slot_with_task(*slot, std::move(task))) {
+                        SRV_ERR("failed to launch slot with task, id_task = %d\n", id_task);
+                        break; // drop the task
                    }
                } break;
            case SERVER_TASK_TYPE_CANCEL:
@@ -1932,7 +2000,7 @@ private:
                GGML_ABORT("not supported by multimodal");
            }
 
-            if (slot.is_parent() || slot.is_child()) {
+            if (slot.task->is_parent() || slot.task->is_child()) {
                send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER);
                slot.release();
                continue;
@@ -2079,21 +2147,6 @@ private:
 
        // this slot still has a prompt to be processed
        if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
-            // wait for all children to be launched
-            if (slot.is_parent()) {
-                int n_launched = 0;
-                for (auto & other : slots) {
-                    if (other.is_processing() && other.is_child() && other.task->id_parent == slot.task->id) {
-                        ++n_launched;
-                    }
-                }
-
-                if (n_launched < slot.task->n_children) {
-                    SLT_DBG(slot, "waiting for children to be launched, n_children = %d, n_launched = %d\n", slot.task->n_children, n_launched);
-                    continue;
-                }
-            }
-
            const auto & input_tokens = slot.task->tokens;
 
            // TODO: maybe move branch to outside of this loop in the future
@@ -2647,9 +2700,7 @@ private:
 
        // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too
        for (auto & slot : slots) {
-            if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
-                SLT_INF(slot, "parent task prompt done, n_children = %d\n", slot.task->n_children);
-
+            if (slot.state == SLOT_STATE_DONE_PROMPT && slot.task->is_parent()) {
                std::vector<server_slot *> children;
                for (auto & other : slots) {
                    if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
                        children.push_back(&other);
                    }
                }
 
-                // we can only proceed if all child slots are having the correct tasks
-                if (slot.task->n_children == (int) children.size()) {
-                    // copy state to the child slots
-                    for (auto & child : children) {
-                        SLT_INF(slot, " - copying state to child %d\n", child->id);
+                // all child slots should already be launched by launch_slots_with_parent_task()
+                // copy state to the child slots
+                for (auto & child : children) {
+                    SLT_INF(slot, " - copying state to child %d\n", child->id);
 
-                        GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER);
+                    GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER);
 
-                        slot.copy_state_to(*child);
-                        child->state = SLOT_STATE_DONE_PROMPT;
-                    }
+                    slot.copy_state_to(*child);
+                    child->state = SLOT_STATE_DONE_PROMPT;
                }
            }
        }
@@ -2943,7 +2992,9 @@ std::unique_ptr server_routes::handle_completions_impl(
            // Everything else, including multimodal completions.
            inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
        }
-        tasks.reserve(inputs.size());
+
+        // tasks.reserve(inputs.size()); // TODO: this is inaccurate due to child tasks
+
        for (size_t i = 0; i < inputs.size(); i++) {
            server_task task = server_task(type);
 
@@ -2964,23 +3015,13 @@ std::unique_ptr server_routes::handle_completions_impl(
            // prepare child tasks
            if (task.params.n_cmpl > 1) {
-                task.n_children = task.params.n_cmpl - 1;
-
-                for (int j = 0; j < task.n_children; j++) {
-                    server_task child = task.create_child(task.id, rd.get_new_id());
-
-                    // use different sampling seed for each child
-                    // note: https://github.com/ggml-org/llama.cpp/pull/18700#discussion_r2675115723
-                    if (child.params.sampling.seed != LLAMA_DEFAULT_SEED) {
-                        child.params.sampling.seed += j + 1;
-                    }
-
-                    tasks.push_back(std::move(child));
+                int n_children = task.params.n_cmpl - 1;
+                for (int j = 0; j < n_children; j++) {
+                    task.add_child(task.id, rd.get_new_id());
                }
            }
 
-            // note: the parent task always launches first
-            tasks.insert(tasks.begin(), std::move(task));
+            tasks.push_back(std::move(task));
        }
 
        rd.post_tasks(std::move(tasks));
diff --git a/tools/server/server-queue.cpp b/tools/server/server-queue.cpp
index 9a6ba560a3..a2a026a12c 100644
--- a/tools/server/server-queue.cpp
+++ b/tools/server/server-queue.cpp
@@ -74,11 +74,26 @@ int server_queue::get_new_id() {
    return new_id;
}
 
-void server_queue::pop_deferred_task() {
+void server_queue::pop_deferred_task(int id_slot) {
    std::unique_lock<std::mutex> lock(mutex_tasks);
    if (!queue_tasks_deferred.empty()) {
-        queue_tasks.emplace_front(std::move(queue_tasks_deferred.front()));
-        queue_tasks_deferred.pop_front();
+        // try to find a task that uses the specified slot
+        bool found = false;
+        for (auto it = queue_tasks_deferred.begin(); it != queue_tasks_deferred.end(); ++it) {
+            if (it->id_slot == id_slot) {
+                QUE_DBG("pop deferred task (use slot %d), id_task = %d\n", id_slot, it->id);
+                queue_tasks.emplace_front(std::move(*it));
+                queue_tasks_deferred.erase(it);
+                found = true;
+                break;
+            }
+        }
+
+        // if no task using the slot is found, just pop the first deferred task (default behavior)
+        if (!found) {
+            QUE_DBG("pop deferred task, id_task = %d\n", queue_tasks_deferred.front().id);
+            queue_tasks.emplace_front(std::move(queue_tasks_deferred.front()));
+            queue_tasks_deferred.pop_front();
+        }
    }
    time_last_task = ggml_time_ms();
    condition_tasks.notify_one();
@@ -217,12 +232,12 @@ void server_response::add_waiting_task_id(int id_task) {
    waiting_task_ids.insert(id_task);
}
 
-void server_response::add_waiting_tasks(const std::vector<server_task> & tasks) {
+void server_response::add_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
    std::unique_lock<std::mutex> lock(mutex_results);
 
-    for (const auto & task : tasks) {
-        RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size());
-        waiting_task_ids.insert(task.id);
+    for (const auto & id_task : id_tasks) {
+        RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
+        waiting_task_ids.insert(id_task);
    }
}
 
@@ -327,6 +342,7 @@ void server_response::terminate() {
 
void server_response_reader::post_task(server_task && task, bool front) {
    GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
+    GGML_ASSERT(!task.is_parent() && "not supported, use post_tasks() instead");
    task.index = 0;
    id_tasks.insert(task.id);
    states.push_back(task.create_state());
@@ -338,11 +354,18 @@ void server_response_reader::post_tasks(std::vector<server_task> && tasks, bool
    GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
    id_tasks = server_task::get_list_id(tasks);
    states.reserve(tasks.size());
-    for (size_t i = 0; i < tasks.size(); i++) {
-        tasks[i].index = i;
-        states.push_back(tasks[i].create_state());
+    size_t index = 0;
+    for (auto & task : tasks) {
+        task.index = index++;
+        states.push_back(task.create_state());
+        // also index the child tasks
+        for (auto & child_task : task.child_tasks) {
+            child_task.index = index++;
+            states.push_back(child_task.create_state());
+        }
    }
-    queue_results.add_waiting_tasks(tasks);
+    GGML_ASSERT(states.size() == id_tasks.size());
+    queue_results.add_waiting_task_ids(id_tasks);
 
    queue_tasks.post(std::move(tasks), front);
diff --git a/tools/server/server-queue.h b/tools/server/server-queue.h
index 3798aa299e..164f09b195 100644
--- a/tools/server/server-queue.h
+++ b/tools/server/server-queue.h
@@ -44,7 +44,8 @@ public:
    int get_new_id();
 
    // Call when the state of one slot is changed; it will move one task from deferred to main queue
-    void pop_deferred_task();
+    // prioritize tasks that use the specified slot (otherwise, pop the first deferred task)
+    void pop_deferred_task(int id_slot);
 
    // if sleeping, request exiting sleep state and wait until it is done
    // returns immediately if not sleeping
@@ -124,7 +125,7 @@ public:
 
    // add the id_task to the list of tasks waiting for response
    void add_waiting_task_id(int id_task);
-    void add_waiting_tasks(const std::vector<server_task> & tasks);
+    void add_waiting_task_ids(const std::unordered_set<int> & id_tasks);
 
    // when the request is finished, we can remove task associated with it
    void remove_waiting_task_id(int id_task);
diff --git a/tools/server/server-task.h b/tools/server/server-task.h
index 97bae920d6..11943ee4f8 100644
--- a/tools/server/server-task.h
+++ b/tools/server/server-task.h
@@ -121,8 +121,10 @@ struct server_task {
    int id_slot = -1;
 
    // used by parallel sampling (multiple completions from same prompt)
-    int n_children = 0; // number of tasks reusing this prompt
    int id_parent = -1;
+    // temporary store of child tasks for scheduling
+    // note: accessing the elements is invalid after the task has been moved into a server_slot
+    std::vector<server_task> child_tasks;
 
    // used by SERVER_TASK_TYPE_INFERENCE
    task_params params;
@@ -197,11 +199,14 @@ struct server_task {
        std::unordered_set<int> ids(tasks.size());
        for (size_t i = 0; i < tasks.size(); i++) {
            ids.insert(tasks[i].id);
+            for (auto & child : tasks[i].child_tasks) {
+                ids.insert(child.id);
+            }
        }
        return ids;
    }
 
-    server_task create_child(int id_parent, int id_child) const {
+    void add_child(int id_parent, int id_child) {
        server_task copy;
 
        copy.id = id_child;
@@ -209,8 +214,15 @@
        copy.id_parent = id_parent;
        copy.params = params;
        copy.type = type;
        copy.tokens = tokens.clone();
+        copy.id_slot = -1; // child tasks cannot specify a slot
 
-        return copy;
+        // use a different sampling seed for each child
+        // note: https://github.com/ggml-org/llama.cpp/pull/18700#discussion_r2675115723
+        if (copy.params.sampling.seed != LLAMA_DEFAULT_SEED) {
+            copy.params.sampling.seed += (uint32_t) child_tasks.size() + 1;
+        }
+
+        child_tasks.push_back(std::move(copy));
    }
 
    // the task will be moved into queue, then onto slots
@@ -218,6 +230,14 @@
    task_result_state create_state() const {
        return task_result_state(params.oaicompat_chat_syntax);
    }
+
+    bool is_parent() const {
+        return child_tasks.size() > 0;
+    }
+
+    bool is_child() const {
+        return id_parent != -1;
+    }
};
 
struct result_timings {
diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py
index d0ce01bc6e..d56a930f7c 100644
--- a/tools/server/tests/unit/test_chat_completion.py
+++ b/tools/server/tests/unit/test_chat_completion.py
@@ -491,16 +491,22 @@ def test_return_progress(n_batch, batch_count, reuse_cache):
 def test_chat_completions_multiple_choices():
    global server
    server.start()
-    res = server.make_request("POST", "/chat/completions", data={
-        "max_tokens": 8,
-        "n": 2,
-        "messages": [
-            {"role": "system", "content": "Book"},
-            {"role": "user", "content": "What is the best book"},
-        ],
-    })
-    assert res.status_code == 200
-    assert len(res.body["choices"]) == 2
-    for choice in res.body["choices"]:
-        assert "assistant" == choice["message"]["role"]
-        assert choice["finish_reason"] == "length"
+    # make sure the cache can be reused across multiple choices and multiple requests
+    # ref: https://github.com/ggml-org/llama.cpp/pull/18663
+    for _ in range(2):
+        res = server.make_request("POST", "/chat/completions", data={
+            "max_tokens": 8,
+            "n": 2,
+            "messages": [
+                {"role": "system", "content": "Book"},
+                {"role": "user", "content": "What is the best book"},
+            ],
+            # test forcing the same slot to be used
+            # the scheduler should not be locked up in this case
+            "id_slot": 0,
+        })
+        assert res.status_code == 200
+        assert len(res.body["choices"]) == 2
+        for choice in res.body["choices"]:
+            assert "assistant" == choice["message"]["role"]
+            assert choice["finish_reason"] == "length"