server: move msg diffs tracking to HTTP thread (#17740)
* server: move msg diffs tracking to HTTP thread
* wip
* tool call tests ok
* minor : style
* cont : fix
* move states to server_response_reader
* add safe-guard
* fix
* fix 2

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
parent 817d743cc1
commit c4c10bfb86
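The gist of the change, before the diff: partial chat-message parsing (and the tool-call diffs derived from it) used to live in server_slot on the inference thread (chat_msg, chat_format, generated_tool_call_ids, update_chat_msg()). It now lives in a new task_result_state owned by server_response_reader, so the parsing runs on the HTTP thread: the handler registers one state per task via set_states(), and next() applies result->update(states[idx]) before the result is serialized. Below is a minimal, self-contained C++ sketch of that pattern with toy types (toy_state, toy_result, toy_reader are illustrative stand-ins, not the llama.cpp API):

// Minimal sketch of the pattern this commit introduces (toy types, not the
// actual llama.cpp API): per-task parsing state lives with the reader on the
// HTTP thread, and each result is folded into that state as it is dequeued.
#include <cassert>
#include <string>
#include <vector>

struct toy_state {
    std::string generated_text; // accumulates across partial results
};

struct toy_result {
    size_t      index;   // which task this result belongs to
    std::string content; // text added by this result
};

struct toy_reader {
    std::vector<toy_state> states; // one per task, owned by the HTTP thread

    // called for every result read from the inference thread
    void apply(toy_result & res) {
        assert(res.index < states.size());
        // stateful work (here: accumulation; in the real code: chat parsing
        // and tool-call diff computation) happens on the reader's thread
        states[res.index].generated_text += res.content;
    }
};

int main() {
    toy_reader rd;
    rd.states.resize(1);
    toy_result r1 = {0, "Hello, "};
    toy_result r2 = {0, "world"};
    rd.apply(r1);
    rd.apply(r2);
    assert(rd.states[0].generated_text == "Hello, world");
    return 0;
}

Keeping the accumulation on the reader side means the inference thread only ships raw chunks; all stateful post-processing is naturally serialized by the single consumer.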
@@ -101,8 +101,6 @@ struct server_slot {
     std::string generated_text;
     llama_tokens generated_tokens;
 
-    common_chat_msg chat_msg;
-
     std::vector<completion_token_output> generated_token_probs;
 
     bool has_next_token = true;
@@ -153,9 +151,6 @@ struct server_slot {
 
     llama_token sampled;
 
-    common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
-    std::vector<std::string> generated_tool_call_ids;
-
     // stats
     size_t n_sent_text = 0; // number of sent text character
 
@@ -183,13 +178,10 @@ struct server_slot {
         stop = STOP_TYPE_NONE;
         stopping_word = "";
         n_sent_text = 0;
-        chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
 
         generated_tokens.clear();
         generated_token_probs.clear();
-        chat_msg = {};
         json_schema = json();
-        generated_tool_call_ids.clear();
 
         // clear speculative decoding stats
         n_draft_total = 0;
@@ -302,23 +294,6 @@ struct server_slot {
         return timings;
     }
 
-    const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
-        GGML_ASSERT(task);
-
-        auto previous_msg = chat_msg;
-        SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
-        auto new_msg = common_chat_parse(
-            generated_text,
-            /* is_partial= */ stop != STOP_TYPE_EOS,
-            task->params.oaicompat_chat_syntax);
-        if (!new_msg.empty()) {
-            new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
-            chat_msg = new_msg;
-            diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
-        }
-        return chat_msg;
-    }
-
     size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
         GGML_ASSERT(task);
 
@@ -1284,8 +1259,6 @@ struct server_context_impl {
         } else {
             res->content = tkn.text_to_send;
             res->tokens = { tkn.tok };
-
-            slot.update_chat_msg(res->oaicompat_msg_diffs);
         }
 
         res->n_decoded = slot.n_decoded;
@@ -1317,8 +1290,14 @@ struct server_context_impl {
         res->id_slot = slot.id;
 
         res->index = slot.task->index;
-        res->content = slot.generated_text;
-        res->tokens = std::move(slot.generated_tokens);
+        // in stream mode, content and tokens are already in last partial chunk
+        if (slot.task->params.stream) {
+            res->content = "";
+            res->tokens = llama_tokens{};
+        } else {
+            res->content = std::move(slot.generated_text);
+            res->tokens = std::move(slot.generated_tokens);
+        }
         res->timings = slot.get_timings();
         res->prompt = slot.task->tokens.detokenize(ctx, true);
         res->response_fields = std::move(slot.task->params.response_fields);
@@ -1338,7 +1317,6 @@ struct server_context_impl {
         res->res_type = slot.task->params.res_type;
         res->oaicompat_model = slot.task->params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
-        res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
 
         // populate res.probs_output
         if (slot.task->params.sampling.n_probs > 0) {
@@ -2596,6 +2574,9 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
     try {
         std::vector<server_task> tasks;
 
+        // tracking generation state and partial tool calls
+        std::vector<task_result_state> states;
+
         const auto & prompt = data.at("prompt");
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
         //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
@@ -2611,6 +2592,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
         tasks.reserve(inputs.size());
+        states.reserve(inputs.size());
         for (size_t i = 0; i < inputs.size(); i++) {
             server_task task = server_task(type);
 
@@ -2628,10 +2610,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             task.params.res_type = res_type;
             task.params.oaicompat_cmpl_id = completion_id;
             task.params.oaicompat_model = ctx_server.model_name;
+            states.push_back(task.params.oaicompat_chat_syntax);
 
             tasks.push_back(std::move(task));
         }
 
+        rd.set_states(std::move(states));
         rd.post_tasks(std::move(tasks));
     } catch (const std::exception & e) {
         res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
@@ -2657,7 +2641,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             // if single request, return single object instead of array
             res->ok(arr.size() == 1 ? arr[0] : arr);
         }
-
     } else {
         // in streaming mode, the first error must be treated as non-stream response
         // this is to match the OAI API behavior
@@ -2676,76 +2659,92 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
         }
 
         // next responses are streamed
+        // to be sent immediately
+        json first_result_json = first_result->to_json();
         if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-            res->data = format_anthropic_sse(first_result->to_json());
+            res->data = format_anthropic_sse(first_result_json);
         } else {
-            res->data = format_oai_sse(first_result->to_json()); // to be sent immediately
+            res->data = format_oai_sse(first_result_json);
         }
         res->status = 200;
         res->content_type = "text/event-stream";
         res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool {
-            if (should_stop()) {
-                SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
-                return false; // should_stop condition met
-            }
-
-            if (!res_this->data.empty()) {
-                // flush the first chunk
-                output = std::move(res_this->data);
-                res_this->data.clear();
-                return true;
-            }
-
-            server_response_reader & rd = res_this->rd;
-
-            // check if there is more data
-            if (!rd.has_next()) {
+            static auto format_error = [](task_response_type res_type, const json & res_json) {
                 if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    // Anthropic doesn't send [DONE], message_stop was already sent
-                    output = "";
-                } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
-                    output = "data: [DONE]\n\n";
-                } else {
-                    output = "";
-                }
-                SRV_DBG("%s", "all results received, terminating stream\n");
-                return false; // no more data, terminate
-            }
-
-            // receive subsequent results
-            auto result = rd.next(should_stop);
-            if (result == nullptr) {
-                SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
-                return false; // should_stop condition met
-            }
-
-            // send the results
-            json res_json = result->to_json();
-            if (result->is_error()) {
-                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    output = format_anthropic_sse({
+                    return format_anthropic_sse({
                         {"event", "error"},
                         {"data", res_json},
                     });
                 } else {
-                    output = format_oai_sse(json {{ "error", res_json }});
+                    return format_oai_sse(json {{ "error", res_json }});
                 }
-                SRV_DBG("%s", "error received during streaming, terminating stream\n");
-                return false; // terminate on error
-            } else {
-                GGML_ASSERT(
-                    dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
-                    || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
-                );
-                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    output = format_anthropic_sse(res_json);
-                } else {
-                    output = format_oai_sse(res_json);
-                }
-            }
-
-            // has next data, continue
-            return true;
+            };
+
+            try {
+                if (should_stop()) {
+                    SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
+                    return false; // should_stop condition met
+                }
+
+                if (!res_this->data.empty()) {
+                    // flush the first chunk
+                    output = std::move(res_this->data);
+                    res_this->data.clear();
+                    return true;
+                }
+
+                server_response_reader & rd = res_this->rd;
+
+                // check if there is more data
+                if (!rd.has_next()) {
+                    if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                        // Anthropic doesn't send [DONE], message_stop was already sent
+                        output = "";
+                    } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
+                        output = "data: [DONE]\n\n";
+                    } else {
+                        output = "";
+                    }
+                    SRV_DBG("%s", "all results received, terminating stream\n");
+                    return false; // no more data, terminate
+                }
+
+                // receive subsequent results
+                auto result = rd.next(should_stop);
+                if (result == nullptr) {
+                    SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
+                    return false; // should_stop condition met
+                }
+
+                // send the results
+                if (result->is_error()) {
+                    json res_json = result->to_json();
+                    output = format_error(res_type, res_json);
+                    SRV_DBG("%s", "error received during streaming, terminating stream\n");
+                    return false; // terminate on error
+                } else {
+                    GGML_ASSERT(
+                        dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
+                        || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+                    );
+                    json res_json = result->to_json();
+                    if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                        output = format_anthropic_sse(res_json);
+                    } else {
+                        output = format_oai_sse(res_json);
+                    }
+                }
+
+                // has next data, continue
+                return true;
+
+            } catch (const std::exception & e) {
+                json error_json = format_error_response(e.what(), ERROR_TYPE_SERVER);
+                output = format_error(res_type, error_json);
+
+                // terminate on exception
+                return false;
+            }
         };
     }
@@ -271,6 +271,10 @@ void server_response::terminate() {
 // server_response_reader
 //
 
+void server_response_reader::set_states(std::vector<task_result_state> && states) {
+    this->states = std::move(states);
+}
+
 void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
     id_tasks = server_task::get_list_id(tasks);
     queue_results.add_waiting_tasks(tasks);
@@ -298,6 +302,12 @@ server_task_result_ptr server_response_reader::next(const std::function<bool()>
             SRV_DBG("%s", "received error result, stopping further processing\n");
             return result;
         }
+        if (!states.empty()) {
+            // update the generation state if needed
+            size_t idx = result->get_index();
+            GGML_ASSERT(idx < states.size());
+            result->update(states[idx]);
+        }
         if (result->is_stop()) {
             received_count++;
         }
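The bounds check above is the "add safe-guard" item from the commit message: a result's index must address a valid per-task state before update() runs, and readers that never call set_states() (states empty, e.g. non-chat endpoints) skip the update entirely. A toy sketch of that control flow, with hypothetical types standing in for server_task_result and task_result_state:

// Toy illustration of the safeguard added in server_response_reader::next():
// a result's index must address a valid per-task state before update() runs.
// Toy types only; the real code uses GGML_ASSERT and server_task_result.
#include <cassert>
#include <cstddef>
#include <vector>

struct state { int n_updates = 0; };

struct result {
    size_t index;
    void update(state & s) { s.n_updates++; }
};

// mirrors the shape of the new block in next(): skip when no states were set,
// otherwise bounds-check and update
void apply_state(std::vector<state> & states, result & r) {
    if (!states.empty()) {
        assert(r.index < states.size()); // safeguard from this commit
        r.update(states[r.index]);
    }
}

int main() {
    std::vector<state> states(2);
    result r{1};
    apply_state(states, r);
    assert(states[1].n_updates == 1);

    std::vector<state> none; // no states configured for this reader
    apply_state(none, r);    // no-op, no crash
    return 0;
}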
@@ -7,6 +7,8 @@
 #include <mutex>
 #include <unordered_set>
 
+// struct for managing server tasks
+// in most cases, use server_response_reader to post new tasks and retrieve results
 struct server_queue {
 private:
     int id = 0;
@@ -67,6 +69,8 @@ private:
     void cleanup_pending_task(int id_target);
 };
 
+// struct for managing server responses
+// in most cases, use server_response_reader to retrieve results
 struct server_response {
 private:
     bool running = true;
@@ -120,6 +124,10 @@ struct server_response_reader {
     bool cancelled = false;
     int polling_interval_seconds;
+
+    // tracking generation state and partial tool calls
+    // only used by streaming completions
+    std::vector<task_result_state> states;
 
     // should_stop function will be called each polling_interval_seconds
     server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
         : queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
@@ -127,6 +135,7 @@ struct server_response_reader {
         stop();
     }
 
+    void set_states(std::vector<task_result_state> && states);
     void post_tasks(std::vector<server_task> && tasks);
     bool has_next() const;
 
@@ -565,6 +565,7 @@ std::vector<unsigned char> completion_token_output::str_to_bytes(const std::stri
 // server_task_result_cmpl_final
 //
 json server_task_result_cmpl_final::to_json() {
+    GGML_ASSERT(is_updated && "update() must be called before to_json()");
     switch (res_type) {
         case TASK_RESPONSE_TYPE_NONE:
             return to_json_non_oaicompat();
@@ -582,8 +583,8 @@ json server_task_result_cmpl_final::to_json() {
 json server_task_result_cmpl_final::to_json_non_oaicompat() {
     json res = json {
         {"index", index},
-        {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
-        {"tokens", stream ? llama_tokens {} : tokens},
+        {"content", content},
+        {"tokens", tokens},
         {"id_slot", id_slot},
         {"stop", true},
         {"model", oaicompat_model},
@@ -619,7 +620,7 @@ json server_task_result_cmpl_final::to_json_oaicompat() {
     json res = json {
         {"choices", json::array({
             json{
-                {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+                {"text", content},
                 {"index", index},
                 {"logprobs", logprobs},
                 {"finish_reason", finish_reason},
@@ -700,6 +701,25 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat() {
     return res;
 }
 
+common_chat_msg task_result_state::update_chat_msg(
+        const std::string & text_added,
+        bool is_partial,
+        std::vector<common_chat_msg_diff> & diffs) {
+    generated_text += text_added;
+    auto msg_prv_copy = chat_msg;
+    SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
+    auto new_msg = common_chat_parse(
+        generated_text,
+        is_partial,
+        oaicompat_chat_syntax);
+    if (!new_msg.empty()) {
+        new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
+        chat_msg = new_msg;
+        diffs = common_chat_msg_diff::compute_diffs(msg_prv_copy, new_msg.empty() ? msg_prv_copy : new_msg);
+    }
+    return chat_msg;
+}
+
 json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
     std::time_t t = std::time(0);
     std::string finish_reason = "length";
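The new task_result_state::update_chat_msg is incremental: each call appends the newly generated chunk, re-parses the accumulated text with common_chat_parse, and reports only what changed since the previous parse via common_chat_msg_diff::compute_diffs. A toy model of that append-reparse-diff cycle, with a plain string copy standing in for the real parser:

// Toy model of the incremental parse in task_result_state::update_chat_msg:
// append the new chunk, re-parse the whole text, and report only what changed
// since the last parse. The "parser" here is a plain string copy; the real
// code calls common_chat_parse() and common_chat_msg_diff::compute_diffs().
#include <cassert>
#include <string>

struct toy_state {
    std::string generated_text;
    std::string chat_msg; // last successfully parsed message

    // returns the text that is new since the previous call (the "diff")
    std::string update(const std::string & text_added) {
        generated_text += text_added;
        std::string previous = chat_msg;
        chat_msg = generated_text;               // stand-in for parsing
        return chat_msg.substr(previous.size()); // stand-in for diffs
    }
};

int main() {
    toy_state st;
    assert(st.update("Hel") == "Hel");
    assert(st.update("lo") == "lo"); // only the newly added part is emitted
    assert(st.chat_msg == "Hello");
    return 0;
}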
@@ -956,6 +976,7 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() {
 // server_task_result_cmpl_partial
 //
 json server_task_result_cmpl_partial::to_json() {
+    GGML_ASSERT(is_updated && "update() must be called before to_json()");
     switch (res_type) {
         case TASK_RESPONSE_TYPE_NONE:
             return to_json_non_oaicompat();
@@ -161,6 +161,25 @@ struct result_prompt_progress {
     json to_json() const;
 };
 
+// struct for tracking the state of a task (e.g., for streaming)
+struct task_result_state {
+    // tracking diffs for partial tool calls
+    std::vector<common_chat_msg_diff> diffs;
+    common_chat_syntax oaicompat_chat_syntax;
+    common_chat_msg chat_msg;
+    std::string generated_text; // append new chunks of generated text here
+    std::vector<std::string> generated_tool_call_ids;
+
+    task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
+        : oaicompat_chat_syntax(oaicompat_chat_syntax) {}
+
+    // parse partial tool calls and update the internal state
+    common_chat_msg update_chat_msg(
+        const std::string & text_added,
+        bool is_partial,
+        std::vector<common_chat_msg_diff> & diffs);
+};
+
 struct server_task_result {
     int id = -1;
     int id_slot = -1;
@@ -175,6 +194,9 @@ struct server_task_result {
     virtual int get_index() {
         return -1;
     }
+    virtual void update(task_result_state &) {
+        // only used by server_task_result_cmpl_*
+    }
     virtual json to_json() = 0;
     virtual ~server_task_result() = default;
 };
@@ -233,9 +255,10 @@ struct server_task_result_cmpl_final : server_task_result {
     task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
-    common_chat_msg oaicompat_msg;
+    common_chat_msg oaicompat_msg; // to be populated by update()
 
-    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
+    std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
+    bool is_updated = false;
 
     virtual int get_index() override {
         return index;
@@ -247,6 +270,11 @@ struct server_task_result_cmpl_final : server_task_result {
 
     virtual json to_json() override;
 
+    virtual void update(task_result_state & state) override {
+        is_updated = true;
+        oaicompat_msg = state.update_chat_msg(content, false, oaicompat_msg_diffs);
+    }
+
     json to_json_non_oaicompat();
 
     json to_json_oaicompat();
@@ -280,7 +308,8 @@ struct server_task_result_cmpl_partial : server_task_result {
     task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
-    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
+    std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
+    bool is_updated = false;
 
     virtual int get_index() override {
         return index;
@@ -292,6 +321,11 @@ struct server_task_result_cmpl_partial : server_task_result {
 
     virtual json to_json() override;
 
+    virtual void update(task_result_state & state) override {
+        is_updated = true;
+        state.update_chat_msg(content, true, oaicompat_msg_diffs);
+    }
+
     json to_json_non_oaicompat();
 
     json to_json_oaicompat();
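Both result types now pair update() with an is_updated flag, and to_json() asserts that flag, so any path that forgets to run update() fails loudly in debug builds instead of silently serializing an empty oaicompat_msg. A toy version of that ordering guard (toy_final_result is illustrative, not the real class):

// Toy version of the is_updated guard added to both result types: to_json()
// asserts that update() ran first.
#include <cassert>
#include <string>

struct toy_final_result {
    bool is_updated = false;
    std::string msg;

    void update() {          // stands in for update(task_result_state &)
        is_updated = true;
        msg = "parsed message";
    }

    std::string to_json() {  // stands in for the real json serialization
        assert(is_updated && "update() must be called before to_json()");
        return "{\"content\":\"" + msg + "\"}";
    }
};

int main() {
    toy_final_result res;
    res.update(); // required ordering, enforced by the assert in to_json()
    assert(res.to_json().find("parsed message") != std::string::npos);
    return 0;
}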