server: delegate result_state creation to server_task (#17835)
* server: delegate result_state creation to server_task
* remove unused states
* add more docs
parent 68522c678d
commit 951520ddb0
@@ -42,7 +42,15 @@ graph TD
     server_response --> server_routes
 ```
 
-TODO: mention about how batching is handled by `server_slot`
+### Batching
+
+The server context maintains a single batch shared across all slots. When `update_slots()` is invoked, the system iterates through all active slots to populate this batch. For each slot, either a generated token from the previous decoding step or available prompt tokens are added to the batch.
+
+Batching constraints apply: slots can only be batched together if they share compatible configurations. For instance, slots using a specific LoRA adapter can be batched with each other, but not with slots using a different LoRA adapter or no adapter at all.
+
+Once the batch reaches capacity or all slots have been processed, `llama_decode` is called to execute the inference. This operation represents the primary computational bottleneck in `update_slots()`.
+
+Following decoding, the system either retrieves embeddings or samples the next token using `common_sampler_sample`. If a slot has remaining prompt tokens to process, it yields until the next `update_slots()` iteration.
 
 ### Thread Management
 
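A minimal, self-contained sketch of the batching behavior described above, for illustration only; `slot_t`, `batch_t`, and `update_slots_once` are simplified stand-ins and do not match the real `server_slot`/`llama_batch`/`llama_decode` signatures:

```cpp
#include <string>
#include <utility>
#include <vector>

// simplified stand-ins; the real server uses server_slot, llama_batch, llama_decode
struct slot_t {
    int              id;
    std::string      lora;            // slots with different adapters cannot share a batch
    std::vector<int> prompt_tokens;   // prompt tokens still waiting to be decoded
    int              last_token = -1; // token sampled in the previous decode step
};

struct batch_t {
    std::vector<std::pair<int, int>> entries; // (slot id, token)
    size_t capacity = 512;
};

// one iteration of an update_slots()-style loop: gather tokens from every
// compatible active slot into the single shared batch, then decode once
void update_slots_once(std::vector<slot_t> & slots, batch_t & batch) {
    batch.entries.clear();
    const std::string * batch_lora = nullptr; // adapter used by this batch, once known

    for (slot_t & slot : slots) {
        if (batch_lora && *batch_lora != slot.lora) {
            continue; // incompatible configuration: defer this slot to a later batch
        }
        const size_t before = batch.entries.size();

        if (!slot.prompt_tokens.empty()) {
            // prompt phase: feed as many prompt tokens as fit; leftovers wait
            // for the next update_slots() iteration
            while (!slot.prompt_tokens.empty() && batch.entries.size() < batch.capacity) {
                batch.entries.push_back({slot.id, slot.prompt_tokens.front()});
                slot.prompt_tokens.erase(slot.prompt_tokens.begin());
            }
        } else if (slot.last_token != -1) {
            // generation phase: add the token produced by the previous decode step
            batch.entries.push_back({slot.id, slot.last_token});
        }

        if (batch.entries.size() > before) {
            batch_lora = &slot.lora; // lock the batch to this slot's configuration
        }
        if (batch.entries.size() >= batch.capacity) {
            break;
        }
    }

    // llama_decode would run here on the whole batch (the main bottleneck),
    // followed by sampling via common_sampler_sample or embedding extraction
}
```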
@@ -62,6 +70,23 @@ Each incoming HTTP request is handled by its own thread managed by the HTTP libr
 - All JSON formatting and chat template logic must stay in the HTTP layer.
 - Avoid passing raw JSON between the HTTP layer and `server_slot`. Instead, parse everything into native C++ types as early as possible.
 
+### Example trace of a request
+
+Here is an example trace of an API request for text completion:
+
+- A request arrives at the HTTP layer.
+- The request is routed to the corresponding handler inside `server_routes`. In this case, `handle_completions_impl` is invoked.
+- The handler parses the input request, constructs a new `server_task`, and passes it to `server_res_generator`.
+- `server_res_generator` creates a new `task_result_state` for each task:
+  - `task_result_state` stays in the HTTP layer, responsible for keeping track of the current state of the response (e.g., parsing tool calls or thinking messages).
+- `server_task` is moved into `server_queue` inside `server_context`.
+- `server_context` launches the task by moving it into an available slot (see `launch_slot_with_task()`).
+- `update_slots()` processes the task as described in the "Batching" section above.
+- Results may be sent using `send_partial_response` or `send_final_response`, which creates a new `server_task_result` and pushes it to the response queue.
+- At the same time, `server_res_generator` listens to the response queue and retrieves this response.
+- As the response is stateless, `server_res_generator` calls `response->update()` to update the response with the current state.
+- `server_res_generator` then calls `response->to_json()` and passes the response to the HTTP layer.
+
 ### Testing
 
 `llama-server` includes an automated test suite based on `pytest`.
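The last steps of the trace rely on results being stateless while the per-request state lives next to the HTTP handler. A toy model of that pattern, with `toy_state`/`toy_result` invented for the example (the real structs are `task_result_state` and `server_task_result`):

```cpp
#include <iostream>
#include <string>
#include <vector>

// toy model only; the real structs carry many more fields
struct toy_state {
    std::string generated_text; // accumulated across streamed chunks
};

struct toy_result {
    std::string chunk;            // new piece of text; stateless on its own
    bool        is_final = false;

    // the caller applies its own state, mirroring response->update()
    void update(toy_state & state) const {
        state.generated_text += chunk;
    }
};

int main() {
    toy_state state; // lives on the "HTTP thread", like task_result_state
    std::vector<toy_result> stream = {
        {"Hello", false}, {", ", false}, {"world", true}, // what the response queue delivers
    };
    for (const toy_result & res : stream) {
        res.update(state);
        if (res.is_final) {
            std::cout << state.generated_text << "\n"; // prints "Hello, world"
        }
    }
    return 0;
}
```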
@@ -2589,6 +2589,10 @@ struct server_context_impl {
     int get_slot_n_ctx() {
         return slots.back().n_ctx;
     }
+
+    server_response_reader get_response_reader() {
+        return server_response_reader(queue_tasks, queue_results, HTTP_POLLING_SECONDS);
+    }
 };
 
 //
@@ -2618,8 +2622,8 @@ llama_context * server_context::get_llama_context() const {
     return impl->ctx;
 }
 
-std::pair<server_queue &, server_response &> server_context::get_queues() {
-    return { impl->queue_tasks, impl->queue_results };
+server_response_reader server_context::get_response_reader() {
+    return impl->get_response_reader();
 }
 
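For context, the two hunks above replace the exposed queue pair with a small by-value reader handle. A minimal sketch of that accessor pattern under simplified, invented types (`context`, `reader`, `task_queue`, `result_queue`):

```cpp
#include <cstdio>
#include <queue>

// invented stand-ins; the real types are server_context, server_response_reader,
// server_queue and server_response
struct task_queue   { std::queue<int> q; };
struct result_queue { std::queue<int> q; };

struct reader {
    task_queue   & tasks;
    result_queue & results;
    int            polling_interval_seconds;

    reader(task_queue & t, result_queue & r, int poll)
        : tasks(t), results(r), polling_interval_seconds(poll) {}
};

struct context {
    task_queue   tasks;
    result_queue results;

    // analogous to server_context::get_response_reader(): the handle is cheap to
    // return by value and stays valid as long as the context outlives it
    reader get_reader() { return reader(tasks, results, /*polling seconds*/ 1); }
};

int main() {
    context ctx;
    reader  rd = ctx.get_reader(); // each request (or CLI call) gets its own reader
    rd.tasks.q.push(42);
    std::printf("queued tasks: %zu\n", ctx.tasks.q.size()); // 1 -- same underlying queue
    return 0;
}
```

This mirrors the design choice in the diff: handlers and the CLI no longer touch `queue_tasks`/`queue_results` directly, only the reader handle obtained from the context.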
@@ -2628,7 +2632,7 @@ std::pair<server_queue &, server_response &> server_context::get_queues() {
 struct server_res_generator : server_http_res {
     server_response_reader rd;
     server_res_generator(server_context_impl & ctx_server)
-        : rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS) {}
+        : rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) {}
     void ok(const json & response_data) {
         status = 200;
         data = safe_json_to_str(response_data);
@@ -2661,9 +2665,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
     try {
         std::vector<server_task> tasks;
 
-        // tracking generation state and partial tool calls
-        std::vector<task_result_state> states;
-
         const auto & prompt = data.at("prompt");
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
         //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
@@ -2679,7 +2680,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
         tasks.reserve(inputs.size());
-        states.reserve(inputs.size());
         int idx = 0;
         for (size_t i = 0; i < inputs.size(); i++) {
             server_task task = server_task(type);
@@ -2698,7 +2698,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             task.params.res_type = res_type;
             task.params.oaicompat_cmpl_id = completion_id;
             task.params.oaicompat_model = ctx_server.model_name;
-            states.push_back(task.params.oaicompat_chat_syntax);
 
             if (task.params.n_cmpl > 1) {
                 task.n_children = task.params.n_cmpl - 1;
@@ -2707,7 +2706,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
                         task.id,
                         ctx_server.queue_tasks.get_new_id(),
                         idx++);
-                    states.push_back(child.params.oaicompat_chat_syntax);
                     tasks.push_back(std::move(child));
                 }
             }
@@ -2715,7 +2713,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             tasks.push_back(std::move(task));
         }
 
-        rd.set_states(std::move(states));
         rd.post_tasks(std::move(tasks));
     } catch (const std::exception & e) {
         res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
@@ -3445,7 +3442,7 @@ void server_routes::init_routes() {
 
     // create and queue the task
     json responses = json::array();
-    server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
+    server_response_reader rd = ctx_server.get_response_reader();
     {
         std::vector<server_task> tasks;
         tasks.reserve(documents.size());
@@ -3705,7 +3702,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_embeddings_impl(cons
 
     // create and queue the task
     json responses = json::array();
-    server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
+    server_response_reader rd = ctx_server.get_response_reader();
     {
         std::vector<server_task> tasks;
         for (size_t i = 0; i < tokenized_prompts.size(); i++) {
@@ -31,9 +31,8 @@ struct server_context {
     // get the underlaying llama_context
     llama_context * get_llama_context() const;
 
-    // get the underlaying queue_tasks and queue_results
-    // used by CLI application
-    std::pair<server_queue &, server_response &> get_queues();
+    // get a new response reader, used by CLI application
+    server_response_reader get_response_reader();
 };
 
@@ -271,12 +271,21 @@ void server_response::terminate() {
 // server_response_reader
 //
 
-void server_response_reader::set_states(std::vector<task_result_state> && states) {
-    this->states = std::move(states);
+void server_response_reader::post_task(server_task && task) {
+    GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
+    id_tasks.insert(task.id);
+    states.push_back(task.create_state());
+    queue_results.add_waiting_task_id(task.id);
+    queue_tasks.post(std::move(task));
 }
 
 void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
+    GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
     id_tasks = server_task::get_list_id(tasks);
+    states.reserve(tasks.size());
+    for (size_t i = 0; i < tasks.size(); i++) {
+        states.push_back(tasks[i].create_state());
+    }
     queue_results.add_waiting_tasks(tasks);
     queue_tasks.post(std::move(tasks));
 }
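A toy illustration of the invariant the new `GGML_ASSERT`s enforce: a reader is populated at most once, and it derives exactly one state per task at posting time. All types below are invented stand-ins for `server_task`, `task_result_state`, and `server_response_reader`:

```cpp
#include <cassert>
#include <set>
#include <vector>

// invented stand-ins for the real server types
struct toy_task  { int id; int syntax; };
struct toy_state { int syntax; };

struct toy_reader {
    std::set<int>          id_tasks;
    std::vector<toy_state> states;

    void post_tasks(std::vector<toy_task> && tasks) {
        assert(id_tasks.empty() && "post_tasks() can only be called once per reader");
        states.reserve(tasks.size());
        for (const toy_task & t : tasks) {
            id_tasks.insert(t.id);
            states.push_back(toy_state{t.syntax}); // ~ tasks[i].create_state()
        }
        // ...the real implementation then registers the ids with the result
        // queue and moves the tasks into the task queue
    }
};

int main() {
    toy_reader rd;
    rd.post_tasks({{1, 0}, {2, 0}});
    assert(rd.states.size() == 2 && rd.id_tasks.size() == 2); // one state per task
    // rd.post_tasks({{3, 0}}); // would trip the assert: the reader was already used
    return 0;
}
```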
@@ -129,13 +129,13 @@ struct server_response_reader {
     std::vector<task_result_state> states;
 
     // should_stop function will be called each polling_interval_seconds
-    server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
-        : queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
+    server_response_reader(server_queue & queue_tasks, server_response & queue_results, int polling_interval_seconds)
+        : queue_tasks(queue_tasks), queue_results(queue_results), polling_interval_seconds(polling_interval_seconds) {}
     ~server_response_reader() {
         stop();
     }
 
-    void set_states(std::vector<task_result_state> && states);
+    void post_task(server_task && tasks);
     void post_tasks(std::vector<server_task> && tasks);
     bool has_next() const;
@@ -85,6 +85,25 @@ struct task_params {
     json to_json(bool only_metrics = false) const;
 };
 
+// struct for tracking the state of a task (e.g., for streaming)
+struct task_result_state {
+    // tracking diffs for partial tool calls
+    std::vector<common_chat_msg_diff> diffs;
+    common_chat_syntax oaicompat_chat_syntax;
+    common_chat_msg chat_msg;
+    std::string generated_text; // append new chunks of generated text here
+    std::vector<std::string> generated_tool_call_ids;
+
+    task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
+        : oaicompat_chat_syntax(oaicompat_chat_syntax) {}
+
+    // parse partial tool calls and update the internal state
+    common_chat_msg update_chat_msg(
+        const std::string & text_added,
+        bool is_partial,
+        std::vector<common_chat_msg_diff> & diffs);
+};
+
 struct server_task {
     int id = -1; // to be filled by server_queue
     int index = -1; // used when there are multiple prompts (batch request)
@@ -149,6 +168,12 @@ struct server_task {
         copy.tokens = tokens.clone();
         return copy;
     }
+
+    // the task will be moved into queue, then onto slots
+    // however, the state must be kept by caller (e.g., HTTP thread)
+    task_result_state create_state() const {
+        return task_result_state(params.oaicompat_chat_syntax);
+    }
 };
 
 struct result_timings {
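The comment added above captures the key ordering constraint: the state has to be derived from the task before the task is moved into the queue. A small standalone sketch of that idea, with all names invented for the example:

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// invented stand-ins for server_task / task_result_state / server_queue
struct chat_syntax { std::string format; };
struct state       { chat_syntax syntax; };

struct task {
    int         id;
    chat_syntax syntax;

    // copies only what the caller needs to keep, mirroring create_state()
    state create_state() const { return state{syntax}; }
};

int main() {
    std::vector<task> queue; // stands in for the task queue

    task  t{7, {"chatml"}};
    state s = t.create_state();    // kept on the caller (HTTP) thread
    queue.push_back(std::move(t)); // the task itself is handed off to the worker side

    assert(s.syntax.format == "chatml"); // the state remains usable for streaming updates
    return 0;
}
```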
@@ -180,25 +205,6 @@ struct result_prompt_progress {
     json to_json() const;
 };
 
-// struct for tracking the state of a task (e.g., for streaming)
-struct task_result_state {
-    // tracking diffs for partial tool calls
-    std::vector<common_chat_msg_diff> diffs;
-    common_chat_syntax oaicompat_chat_syntax;
-    common_chat_msg chat_msg;
-    std::string generated_text; // append new chunks of generated text here
-    std::vector<std::string> generated_tool_call_ids;
-
-    task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
-        : oaicompat_chat_syntax(oaicompat_chat_syntax) {}
-
-    // parse partial tool calls and update the internal state
-    common_chat_msg update_chat_msg(
-        const std::string & text_added,
-        bool is_partial,
-        std::vector<common_chat_msg_diff> & diffs);
-};
-
 struct server_task_result {
     int id = -1;
     int id_slot = -1;