server: move msg diffs tracking to HTTP thread (#17740)

* server: move msg diffs tracking to HTTP thread

* wip

* tool call tests ok

* minor : style

* cont : fix

* move states to server_response_reader

* add safe-guard

* fix

* fix 2

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Xuan-Son Nguyen, 2025-12-04 15:46:08 +01:00, committed by GitHub
parent 817d743cc1
commit c4c10bfb86
5 changed files with 167 additions and 94 deletions
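
The overall shape of the change, summarized from the diffs below: per-request chat-parsing state (accumulated text, partial tool-call IDs, the last parsed message) moves off the server slot into a per-task task_result_state, which the HTTP-side server_response_reader stores and applies to each result via result->update(states[idx]) before the result is serialized. A minimal standalone analogue of that ownership model, using simplified stand-in types (chat_state, partial_result) rather than the real server classes:

#include <cstdio>
#include <string>
#include <vector>

struct chat_state {                 // stand-in for task_result_state
    std::string generated_text;     // accumulated text, appended chunk by chunk
    void append(const std::string & text_added) { generated_text += text_added; }
};

struct partial_result {             // stand-in for server_task_result_cmpl_partial
    size_t index = 0;               // which task this result belongs to
    std::string content;            // text produced since the previous chunk
    bool is_updated = false;

    void update(chat_state & st) {  // runs on the HTTP thread, not in the worker loop
        st.append(content);
        is_updated = true;
    }
    std::string to_json_like() const {
        // the real code asserts that update() ran before serialization
        return is_updated ? "{\"delta\":\"" + content + "\"}" : "<not updated>";
    }
};

int main() {
    std::vector<chat_state> states(1);              // one state per task, owned by the reader
    partial_result chunks[] = { {0, "Hel"}, {0, "lo"} };
    for (auto & r : chunks) {
        r.update(states[r.index]);                  // mirrors result->update(states[idx])
        std::printf("%s\n", r.to_json_like().c_str());
    }
    std::printf("accumulated: %s\n", states[0].generated_text.c_str());
    return 0;
}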


@@ -101,8 +101,6 @@ struct server_slot {
     std::string generated_text;
     llama_tokens generated_tokens;
-    common_chat_msg chat_msg;
     std::vector<completion_token_output> generated_token_probs;
     bool has_next_token = true;
@@ -153,9 +151,6 @@ struct server_slot {
     llama_token sampled;
-    common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
-    std::vector<std::string> generated_tool_call_ids;
     // stats
     size_t n_sent_text = 0; // number of sent text character
@@ -183,13 +178,10 @@ struct server_slot {
         stop = STOP_TYPE_NONE;
         stopping_word = "";
         n_sent_text = 0;
-        chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
         generated_tokens.clear();
         generated_token_probs.clear();
-        chat_msg = {};
         json_schema = json();
-        generated_tool_call_ids.clear();
         // clear speculative decoding stats
         n_draft_total = 0;
@@ -302,23 +294,6 @@ struct server_slot {
         return timings;
     }
-    const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
-        GGML_ASSERT(task);
-        auto previous_msg = chat_msg;
-        SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
-        auto new_msg = common_chat_parse(
-            generated_text,
-            /* is_partial= */ stop != STOP_TYPE_EOS,
-            task->params.oaicompat_chat_syntax);
-        if (!new_msg.empty()) {
-            new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
-            chat_msg = new_msg;
-            diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
-        }
-        return chat_msg;
-    }
     size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
         GGML_ASSERT(task);
@@ -1284,8 +1259,6 @@ struct server_context_impl {
         } else {
             res->content = tkn.text_to_send;
             res->tokens = { tkn.tok };
-            slot.update_chat_msg(res->oaicompat_msg_diffs);
         }
         res->n_decoded = slot.n_decoded;
@@ -1317,8 +1290,14 @@ struct server_context_impl {
         res->id_slot = slot.id;
         res->index = slot.task->index;
-        res->content = slot.generated_text;
-        res->tokens = std::move(slot.generated_tokens);
+        // in stream mode, content and tokens are already in last partial chunk
+        if (slot.task->params.stream) {
+            res->content = "";
+            res->tokens = llama_tokens{};
+        } else {
+            res->content = std::move(slot.generated_text);
+            res->tokens = std::move(slot.generated_tokens);
+        }
         res->timings = slot.get_timings();
         res->prompt = slot.task->tokens.detokenize(ctx, true);
         res->response_fields = std::move(slot.task->params.response_fields);
@@ -1338,7 +1317,6 @@ struct server_context_impl {
         res->res_type = slot.task->params.res_type;
         res->oaicompat_model = slot.task->params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
-        res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
         // populate res.probs_output
         if (slot.task->params.sampling.n_probs > 0) {
@@ -2596,6 +2574,9 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
     try {
         std::vector<server_task> tasks;
+        // tracking generation state and partial tool calls
+        std::vector<task_result_state> states;
         const auto & prompt = data.at("prompt");
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
         //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
@@ -2611,6 +2592,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
         tasks.reserve(inputs.size());
+        states.reserve(inputs.size());
         for (size_t i = 0; i < inputs.size(); i++) {
             server_task task = server_task(type);
@@ -2628,10 +2610,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             task.params.res_type = res_type;
             task.params.oaicompat_cmpl_id = completion_id;
             task.params.oaicompat_model = ctx_server.model_name;
+            states.push_back(task.params.oaicompat_chat_syntax);
             tasks.push_back(std::move(task));
         }
+        rd.set_states(std::move(states));
         rd.post_tasks(std::move(tasks));
     } catch (const std::exception & e) {
         res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
@@ -2657,7 +2641,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             // if single request, return single object instead of array
             res->ok(arr.size() == 1 ? arr[0] : arr);
         }
-        }
     } else {
         // in streaming mode, the first error must be treated as non-stream response
         // this is to match the OAI API behavior
@@ -2676,76 +2659,92 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
         }
         // next responses are streamed
-        if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-            res->data = format_anthropic_sse(first_result->to_json());
-        } else {
-            res->data = format_oai_sse(first_result->to_json()); // to be sent immediately
-        }
+        // to be sent immediately
+        json first_result_json = first_result->to_json();
+        if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+            res->data = format_anthropic_sse(first_result_json);
+        } else {
+            res->data = format_oai_sse(first_result_json);
+        }
         res->status = 200;
         res->content_type = "text/event-stream";
         res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool {
-            if (should_stop()) {
-                SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
-                return false; // should_stop condition met
-            }
-            if (!res_this->data.empty()) {
-                // flush the first chunk
-                output = std::move(res_this->data);
-                res_this->data.clear();
-                return true;
-            }
-            server_response_reader & rd = res_this->rd;
-            // check if there is more data
-            if (!rd.has_next()) {
-                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    // Anthropic doesn't send [DONE], message_stop was already sent
-                    output = "";
-                } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
-                    output = "data: [DONE]\n\n";
-                } else {
-                    output = "";
-                }
-                SRV_DBG("%s", "all results received, terminating stream\n");
-                return false; // no more data, terminate
-            }
-            // receive subsequent results
-            auto result = rd.next(should_stop);
-            if (result == nullptr) {
-                SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
-                return false; // should_stop condition met
-            }
-            // send the results
-            json res_json = result->to_json();
-            if (result->is_error()) {
-                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    output = format_anthropic_sse({
-                        {"event", "error"},
-                        {"data", res_json},
-                    });
-                } else {
-                    output = format_oai_sse(json {{ "error", res_json }});
-                }
-                SRV_DBG("%s", "error received during streaming, terminating stream\n");
-                return false; // terminate on error
-            } else {
-                GGML_ASSERT(
-                    dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
-                    || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
-                );
-                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    output = format_anthropic_sse(res_json);
-                } else {
-                    output = format_oai_sse(res_json);
-                }
-            }
-            // has next data, continue
-            return true;
+            static auto format_error = [](task_response_type res_type, const json & res_json) {
+                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                    return format_anthropic_sse({
+                        {"event", "error"},
+                        {"data", res_json},
+                    });
+                } else {
+                    return format_oai_sse(json {{ "error", res_json }});
+                }
+            };
+            try {
+                if (should_stop()) {
+                    SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
+                    return false; // should_stop condition met
+                }
+                if (!res_this->data.empty()) {
+                    // flush the first chunk
+                    output = std::move(res_this->data);
+                    res_this->data.clear();
+                    return true;
+                }
+                server_response_reader & rd = res_this->rd;
+                // check if there is more data
+                if (!rd.has_next()) {
+                    if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                        // Anthropic doesn't send [DONE], message_stop was already sent
+                        output = "";
+                    } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
+                        output = "data: [DONE]\n\n";
+                    } else {
+                        output = "";
+                    }
+                    SRV_DBG("%s", "all results received, terminating stream\n");
+                    return false; // no more data, terminate
+                }
+                // receive subsequent results
+                auto result = rd.next(should_stop);
+                if (result == nullptr) {
+                    SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
+                    return false; // should_stop condition met
+                }
+                // send the results
+                if (result->is_error()) {
+                    json res_json = result->to_json();
+                    output = format_error(res_type, res_json);
+                    SRV_DBG("%s", "error received during streaming, terminating stream\n");
+                    return false; // terminate on error
+                } else {
+                    GGML_ASSERT(
+                        dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
+                        || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+                    );
+                    json res_json = result->to_json();
+                    if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                        output = format_anthropic_sse(res_json);
+                    } else {
+                        output = format_oai_sse(res_json);
+                    }
+                }
+                // has next data, continue
+                return true;
+            } catch (const std::exception & e) {
+                json error_json = format_error_response(e.what(), ERROR_TYPE_SERVER);
+                output = format_error(res_type, error_json);
+                // terminate on exception
+                return false;
+            }
         };
     }
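
Reduced to its essentials, the res->next callback above follows a simple contract: fill output with one SSE chunk per call, return true while more data is expected, and return false to end the stream, either after the [DONE] marker or after a single error event; the new try/catch turns an exception into that final error event instead of dropping the stream silently. A self-contained sketch of the same contract with illustrative stand-ins (make_sse_chunk and the JSON strings are not llama.cpp APIs):

#include <cstdio>
#include <exception>
#include <functional>
#include <string>
#include <vector>

// OAI-style SSE framing for a single payload (illustrative helper)
static std::string make_sse_chunk(const std::string & payload) {
    return "data: " + payload + "\n\n";
}

int main() {
    std::vector<std::string> chunks = {R"({"delta":"Hel"})", R"({"delta":"lo"})"};
    size_t pos = 0;

    // same shape as res->next: fill `output`, return false to terminate the stream
    std::function<bool(std::string &)> next = [&](std::string & output) -> bool {
        try {
            if (pos >= chunks.size()) {
                output = "data: [DONE]\n\n";   // end-of-stream marker
                return false;                  // no more data, terminate
            }
            output = make_sse_chunk(chunks[pos++]);
            return true;                       // has next data, continue
        } catch (const std::exception & e) {
            // any exception becomes one final error event, then the stream ends
            output = make_sse_chunk(std::string("{\"error\":\"") + e.what() + "\"}");
            return false;                      // terminate on exception
        }
    };

    std::string out;
    bool more = true;
    while (more) {
        more = next(out);
        std::fputs(out.c_str(), stdout);       // stand-in for writing to the HTTP response
    }
    return 0;
}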


@@ -271,6 +271,10 @@ void server_response::terminate() {
 // server_response_reader
 //
+void server_response_reader::set_states(std::vector<task_result_state> && states) {
+    this->states = std::move(states);
+}
 void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
     id_tasks = server_task::get_list_id(tasks);
     queue_results.add_waiting_tasks(tasks);

@@ -298,6 +302,12 @@ server_task_result_ptr server_response_reader::next(const std::function<bool()>
         SRV_DBG("%s", "received error result, stopping further processing\n");
         return result;
     }
+    if (!states.empty()) {
+        // update the generation state if needed
+        size_t idx = result->get_index();
+        GGML_ASSERT(idx < states.size());
+        result->update(states[idx]);
+    }
     if (result->is_stop()) {
         received_count++;
     }
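
The block added to next() is the commit's "add safe-guard": state is only applied when the handler actually registered states (they are set for streaming completions only), and the result's task index is asserted to be in range before the lookup. A toy version of the same pattern, with stand-in types instead of the real result and state classes:

#include <cassert>
#include <cstdio>
#include <vector>

struct toy_state  { int n_updates = 0; };   // stand-in for task_result_state
struct toy_result { size_t index = 0; };    // stand-in for a task result carrying its index

static void apply_state(const toy_result & result, std::vector<toy_state> & states) {
    if (!states.empty()) {                  // states are only registered for streaming completions
        size_t idx = result.index;
        assert(idx < states.size());        // plays the role of GGML_ASSERT(idx < states.size())
        states[idx].n_updates++;            // stand-in for result->update(states[idx])
    }
}

int main() {
    std::vector<toy_state> states(2);       // one state per posted task
    apply_state(toy_result{1}, states);
    std::printf("updates applied to task 1: %d\n", states[1].n_updates);
    return 0;
}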


@@ -7,6 +7,8 @@
 #include <mutex>
 #include <unordered_set>
+// struct for managing server tasks
+// in most cases, use server_response_reader to post new tasks and retrieve results
 struct server_queue {
 private:
     int id = 0;

@@ -67,6 +69,8 @@ private:
     void cleanup_pending_task(int id_target);
 };
+// struct for managing server responses
+// in most cases, use server_response_reader to retrieve results
 struct server_response {
 private:
     bool running = true;

@@ -120,6 +124,10 @@ struct server_response_reader {
     bool cancelled = false;
     int polling_interval_seconds;
+    // tracking generation state and partial tool calls
+    // only used by streaming completions
+    std::vector<task_result_state> states;
     // should_stop function will be called each polling_interval_seconds
     server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
         : queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}

@@ -127,6 +135,7 @@ struct server_response_reader {
         stop();
     }
+    void set_states(std::vector<task_result_state> && states);
     void post_tasks(std::vector<server_task> && tasks);
     bool has_next() const;


@@ -565,6 +565,7 @@ std::vector<unsigned char> completion_token_output::str_to_bytes(const std::stri
 // server_task_result_cmpl_final
 //
 json server_task_result_cmpl_final::to_json() {
+    GGML_ASSERT(is_updated && "update() must be called before to_json()");
     switch (res_type) {
         case TASK_RESPONSE_TYPE_NONE:
             return to_json_non_oaicompat();

@@ -582,8 +583,8 @@
 json server_task_result_cmpl_final::to_json_non_oaicompat() {
     json res = json {
         {"index", index},
-        {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
-        {"tokens", stream ? llama_tokens {} : tokens},
+        {"content", content},
+        {"tokens", tokens},
         {"id_slot", id_slot},
         {"stop", true},
         {"model", oaicompat_model},

@@ -619,7 +620,7 @@ json server_task_result_cmpl_final::to_json_oaicompat() {
     json res = json {
         {"choices", json::array({
             json{
-                {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+                {"text", content},
                 {"index", index},
                 {"logprobs", logprobs},
                 {"finish_reason", finish_reason},

@@ -700,6 +701,25 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat() {
     return res;
 }
+common_chat_msg task_result_state::update_chat_msg(
+        const std::string & text_added,
+        bool is_partial,
+        std::vector<common_chat_msg_diff> & diffs) {
+    generated_text += text_added;
+    auto msg_prv_copy = chat_msg;
+    SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
+    auto new_msg = common_chat_parse(
+        generated_text,
+        is_partial,
+        oaicompat_chat_syntax);
+    if (!new_msg.empty()) {
+        new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
+        chat_msg = new_msg;
+        diffs = common_chat_msg_diff::compute_diffs(msg_prv_copy, new_msg.empty() ? msg_prv_copy : new_msg);
+    }
+    return chat_msg;
+}
 json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
     std::time_t t = std::time(0);
     std::string finish_reason = "length";

@@ -956,6 +976,7 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() {
 // server_task_result_cmpl_partial
 //
 json server_task_result_cmpl_partial::to_json() {
+    GGML_ASSERT(is_updated && "update() must be called before to_json()");
     switch (res_type) {
         case TASK_RESPONSE_TYPE_NONE:
             return to_json_non_oaicompat();


@@ -161,6 +161,25 @@ struct result_prompt_progress {
     json to_json() const;
 };
+// struct for tracking the state of a task (e.g., for streaming)
+struct task_result_state {
+    // tracking diffs for partial tool calls
+    std::vector<common_chat_msg_diff> diffs;
+    common_chat_syntax oaicompat_chat_syntax;
+    common_chat_msg chat_msg;
+    std::string generated_text; // append new chunks of generated text here
+    std::vector<std::string> generated_tool_call_ids;
+    task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
+        : oaicompat_chat_syntax(oaicompat_chat_syntax) {}
+    // parse partial tool calls and update the internal state
+    common_chat_msg update_chat_msg(
+        const std::string & text_added,
+        bool is_partial,
+        std::vector<common_chat_msg_diff> & diffs);
+};
 struct server_task_result {
     int id = -1;
     int id_slot = -1;

@@ -175,6 +194,9 @@ struct server_task_result {
     virtual int get_index() {
         return -1;
     }
+    virtual void update(task_result_state &) {
+        // only used by server_task_result_cmpl_*
+    }
     virtual json to_json() = 0;
     virtual ~server_task_result() = default;
 };

@@ -233,9 +255,10 @@ struct server_task_result_cmpl_final : server_task_result {
     task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
-    common_chat_msg oaicompat_msg;
-    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
+    common_chat_msg oaicompat_msg; // to be populated by update()
+    std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
+    bool is_updated = false;
     virtual int get_index() override {
         return index;

@@ -247,6 +270,11 @@ struct server_task_result_cmpl_final : server_task_result {
     virtual json to_json() override;
+    virtual void update(task_result_state & state) override {
+        is_updated = true;
+        oaicompat_msg = state.update_chat_msg(content, false, oaicompat_msg_diffs);
+    }
     json to_json_non_oaicompat();
     json to_json_oaicompat();

@@ -280,7 +308,8 @@ struct server_task_result_cmpl_partial : server_task_result {
     task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
-    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
+    std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
+    bool is_updated = false;
     virtual int get_index() override {
         return index;

@@ -292,6 +321,11 @@ struct server_task_result_cmpl_partial : server_task_result {
     virtual json to_json() override;
+    virtual void update(task_result_state & state) override {
+        is_updated = true;
+        state.update_chat_msg(content, true, oaicompat_msg_diffs);
+    }
     json to_json_non_oaicompat();
     json to_json_oaicompat();
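
Putting the last two files together: a partial result calls update_chat_msg with is_partial == true for every chunk, the final result calls it once with is_partial == false, and what gets streamed to the client is the diff between the previous parse and the new one. The toy below mimics that flow with stand-in parsing and diffing (toy_parse and toy_diff are illustrative; the real code uses common_chat_parse and common_chat_msg_diff::compute_diffs):

#include <cstdio>
#include <string>
#include <vector>

struct toy_msg { std::string content; };

// stand-in for common_chat_parse: "parsing" just wraps the accumulated text
static toy_msg toy_parse(const std::string & text, bool /*is_partial*/) {
    return { text };
}

// stand-in for compute_diffs: the new suffix since the previous parse
static std::string toy_diff(const toy_msg & prev, const toy_msg & cur) {
    return cur.content.substr(prev.content.size());
}

struct toy_state {                  // stand-in for task_result_state
    std::string generated_text;
    toy_msg     chat_msg;

    // mirrors task_result_state::update_chat_msg(text_added, is_partial, diffs)
    std::string update(const std::string & text_added, bool is_partial) {
        generated_text += text_added;
        toy_msg prev = chat_msg;
        chat_msg = toy_parse(generated_text, is_partial);
        return toy_diff(prev, chat_msg);
    }
};

int main() {
    toy_state state;
    std::vector<std::string> partial_chunks = {"Hel", "lo ", "world"};
    for (const std::string & chunk : partial_chunks) {
        // each partial result appends its chunk and streams only the diff
        std::printf("diff: %s\n", state.update(chunk, /*is_partial=*/true).c_str());
    }
    // the final result runs one more update with is_partial == false
    state.update("", /*is_partial=*/false);
    std::printf("final message: %s\n", state.chat_msg.content.c_str());
    return 0;
}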