server: delegate result_state creation to server_task (#17835)

* server: delegate result_state creation to server_task

* remove unued states

* add more docs
This commit is contained in:
Xuan-Son Nguyen 2025-12-08 17:04:38 +01:00 committed by GitHub
parent 68522c678d
commit 951520ddb0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 76 additions and 40 deletions

View File

@ -42,7 +42,15 @@ graph TD
server_response --> server_routes server_response --> server_routes
``` ```
TODO: mention about how batching is handled by `server_slot` ### Batching
The server context maintains a single batch shared across all slots. When `update_slots()` is invoked, the system iterates through all active slots to populate this batch. For each slot, either a generated token from the previous decoding step or available prompt tokens are added to the batch.
Batching constraints apply: slots can only be batched together if they share compatible configurations. For instance, slots using a specific LoRA adapter can be batched with each other, but not with slots using a different LoRA adapter or no adapter at all.
Once the batch reaches capacity or all slots have been processed, `llama_decode` is called to execute the inference. This operation represents the primary computational bottleneck in `update_slots()`.
Following decoding, the system either retrieves embeddings or samples the next token using `common_sampler_sample`. If a slot has remaining prompt tokens to process, it yields until the next `update_slots()` iteration.
### Thread Management ### Thread Management
@ -62,6 +70,23 @@ Each incoming HTTP request is handled by its own thread managed by the HTTP libr
- All JSON formatting and chat template logic must stay in the HTTP layer. - All JSON formatting and chat template logic must stay in the HTTP layer.
- Avoid passing raw JSON between the HTTP layer and `server_slot`. Instead, parse everything into native C++ types as early as possible. - Avoid passing raw JSON between the HTTP layer and `server_slot`. Instead, parse everything into native C++ types as early as possible.
### Example trace of a request
Here is an example trace of an API request for text completion:
- A request arrives at the HTTP layer.
- The request is routed to the corresponding handler inside `server_routes`. In this case, `handle_completions_impl` is invoked.
- The handler parses the input request, constructs a new `server_task`, and passes it to `server_res_generator`.
- `server_res_generator` creates a new `task_result_state` for each task:
- `task_result_state` stays in the HTTP layer, responsible for keeping track of the current state of the response (e.g., parsing tool calls or thinking messages).
- `server_task` is moved into `server_queue` inside `server_context`.
- `server_context` launches the task by moving it into an available slot (see `launch_slot_with_task()`).
- `update_slot()` processes the task as described in the "Batching" section above.
- Results may be sent using `send_partial_response` or `send_final_response`, which creates a new `server_task_result` and pushes it to the response queue.
- At the same time, `server_res_generator` listens to the response queue and retrieves this response.
- As the response is stateless, `server_res_generator` calls `response->update()` to update the response with the current state.
- `server_res_generator` then calls `response->to_json()` and passes the response to the HTTP layer.
### Testing ### Testing
`llama-server` includes an automated test suite based on `pytest`. `llama-server` includes an automated test suite based on `pytest`.

View File

@ -2589,6 +2589,10 @@ struct server_context_impl {
int get_slot_n_ctx() { int get_slot_n_ctx() {
return slots.back().n_ctx; return slots.back().n_ctx;
} }
server_response_reader get_response_reader() {
return server_response_reader(queue_tasks, queue_results, HTTP_POLLING_SECONDS);
}
}; };
// //
@ -2618,8 +2622,8 @@ llama_context * server_context::get_llama_context() const {
return impl->ctx; return impl->ctx;
} }
std::pair<server_queue &, server_response &> server_context::get_queues() { server_response_reader server_context::get_response_reader() {
return { impl->queue_tasks, impl->queue_results }; return impl->get_response_reader();
} }
@ -2628,7 +2632,7 @@ std::pair<server_queue &, server_response &> server_context::get_queues() {
struct server_res_generator : server_http_res { struct server_res_generator : server_http_res {
server_response_reader rd; server_response_reader rd;
server_res_generator(server_context_impl & ctx_server) server_res_generator(server_context_impl & ctx_server)
: rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS) {} : rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) {}
void ok(const json & response_data) { void ok(const json & response_data) {
status = 200; status = 200;
data = safe_json_to_str(response_data); data = safe_json_to_str(response_data);
@ -2661,9 +2665,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
try { try {
std::vector<server_task> tasks; std::vector<server_task> tasks;
// tracking generation state and partial tool calls
std::vector<task_result_state> states;
const auto & prompt = data.at("prompt"); const auto & prompt = data.at("prompt");
// TODO: this log can become very long, put it behind a flag or think about a more compact format // TODO: this log can become very long, put it behind a flag or think about a more compact format
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str()); //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
@ -2679,7 +2680,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
} }
tasks.reserve(inputs.size()); tasks.reserve(inputs.size());
states.reserve(inputs.size());
int idx = 0; int idx = 0;
for (size_t i = 0; i < inputs.size(); i++) { for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(type); server_task task = server_task(type);
@ -2698,7 +2698,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
task.params.res_type = res_type; task.params.res_type = res_type;
task.params.oaicompat_cmpl_id = completion_id; task.params.oaicompat_cmpl_id = completion_id;
task.params.oaicompat_model = ctx_server.model_name; task.params.oaicompat_model = ctx_server.model_name;
states.push_back(task.params.oaicompat_chat_syntax);
if (task.params.n_cmpl > 1) { if (task.params.n_cmpl > 1) {
task.n_children = task.params.n_cmpl - 1; task.n_children = task.params.n_cmpl - 1;
@ -2707,7 +2706,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
task.id, task.id,
ctx_server.queue_tasks.get_new_id(), ctx_server.queue_tasks.get_new_id(),
idx++); idx++);
states.push_back(child.params.oaicompat_chat_syntax);
tasks.push_back(std::move(child)); tasks.push_back(std::move(child));
} }
} }
@ -2715,7 +2713,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
tasks.push_back(std::move(task)); tasks.push_back(std::move(task));
} }
rd.set_states(std::move(states));
rd.post_tasks(std::move(tasks)); rd.post_tasks(std::move(tasks));
} catch (const std::exception & e) { } catch (const std::exception & e) {
res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
@ -3445,7 +3442,7 @@ void server_routes::init_routes() {
// create and queue the task // create and queue the task
json responses = json::array(); json responses = json::array();
server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS); server_response_reader rd = ctx_server.get_response_reader();
{ {
std::vector<server_task> tasks; std::vector<server_task> tasks;
tasks.reserve(documents.size()); tasks.reserve(documents.size());
@ -3705,7 +3702,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_embeddings_impl(cons
// create and queue the task // create and queue the task
json responses = json::array(); json responses = json::array();
server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS); server_response_reader rd = ctx_server.get_response_reader();
{ {
std::vector<server_task> tasks; std::vector<server_task> tasks;
for (size_t i = 0; i < tokenized_prompts.size(); i++) { for (size_t i = 0; i < tokenized_prompts.size(); i++) {

View File

@ -31,9 +31,8 @@ struct server_context {
// get the underlaying llama_context // get the underlaying llama_context
llama_context * get_llama_context() const; llama_context * get_llama_context() const;
// get the underlaying queue_tasks and queue_results // get a new response reader, used by CLI application
// used by CLI application server_response_reader get_response_reader();
std::pair<server_queue &, server_response &> get_queues();
}; };

View File

@ -271,12 +271,21 @@ void server_response::terminate() {
// server_response_reader // server_response_reader
// //
void server_response_reader::set_states(std::vector<task_result_state> && states) { void server_response_reader::post_task(server_task && task) {
this->states = std::move(states); GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
id_tasks.insert(task.id);
states.push_back(task.create_state());
queue_results.add_waiting_task_id(task.id);
queue_tasks.post(std::move(task));
} }
void server_response_reader::post_tasks(std::vector<server_task> && tasks) { void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
id_tasks = server_task::get_list_id(tasks); id_tasks = server_task::get_list_id(tasks);
states.reserve(tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
states.push_back(tasks[i].create_state());
}
queue_results.add_waiting_tasks(tasks); queue_results.add_waiting_tasks(tasks);
queue_tasks.post(std::move(tasks)); queue_tasks.post(std::move(tasks));
} }

View File

@ -129,13 +129,13 @@ struct server_response_reader {
std::vector<task_result_state> states; std::vector<task_result_state> states;
// should_stop function will be called each polling_interval_seconds // should_stop function will be called each polling_interval_seconds
server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds) server_response_reader(server_queue & queue_tasks, server_response & queue_results, int polling_interval_seconds)
: queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {} : queue_tasks(queue_tasks), queue_results(queue_results), polling_interval_seconds(polling_interval_seconds) {}
~server_response_reader() { ~server_response_reader() {
stop(); stop();
} }
void set_states(std::vector<task_result_state> && states); void post_task(server_task && tasks);
void post_tasks(std::vector<server_task> && tasks); void post_tasks(std::vector<server_task> && tasks);
bool has_next() const; bool has_next() const;

View File

@ -85,6 +85,25 @@ struct task_params {
json to_json(bool only_metrics = false) const; json to_json(bool only_metrics = false) const;
}; };
// struct for tracking the state of a task (e.g., for streaming)
struct task_result_state {
// tracking diffs for partial tool calls
std::vector<common_chat_msg_diff> diffs;
common_chat_syntax oaicompat_chat_syntax;
common_chat_msg chat_msg;
std::string generated_text; // append new chunks of generated text here
std::vector<std::string> generated_tool_call_ids;
task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
: oaicompat_chat_syntax(oaicompat_chat_syntax) {}
// parse partial tool calls and update the internal state
common_chat_msg update_chat_msg(
const std::string & text_added,
bool is_partial,
std::vector<common_chat_msg_diff> & diffs);
};
struct server_task { struct server_task {
int id = -1; // to be filled by server_queue int id = -1; // to be filled by server_queue
int index = -1; // used when there are multiple prompts (batch request) int index = -1; // used when there are multiple prompts (batch request)
@ -149,6 +168,12 @@ struct server_task {
copy.tokens = tokens.clone(); copy.tokens = tokens.clone();
return copy; return copy;
} }
// the task will be moved into queue, then onto slots
// however, the state must be kept by caller (e.g., HTTP thread)
task_result_state create_state() const {
return task_result_state(params.oaicompat_chat_syntax);
}
}; };
struct result_timings { struct result_timings {
@ -180,25 +205,6 @@ struct result_prompt_progress {
json to_json() const; json to_json() const;
}; };
// struct for tracking the state of a task (e.g., for streaming)
struct task_result_state {
// tracking diffs for partial tool calls
std::vector<common_chat_msg_diff> diffs;
common_chat_syntax oaicompat_chat_syntax;
common_chat_msg chat_msg;
std::string generated_text; // append new chunks of generated text here
std::vector<std::string> generated_tool_call_ids;
task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
: oaicompat_chat_syntax(oaicompat_chat_syntax) {}
// parse partial tool calls and update the internal state
common_chat_msg update_chat_msg(
const std::string & text_added,
bool is_partial,
std::vector<common_chat_msg_diff> & diffs);
};
struct server_task_result { struct server_task_result {
int id = -1; int id = -1;
int id_slot = -1; int id_slot = -1;