server: support multiple generations from one prompt (OAI "n" option) (#17775)
* backend support
* server: support multiple generations from one prompt (OAI "n" option)
* fix invalid batch
* format oai
* clean up
* disable ctx shift
* add test
* update comments
* fix style
* add n_cmpl to docs [no ci]
* allow using both n_cmpl and n
This commit is contained in:
parent 09c7c50e64
commit c42712b056
@@ -493,6 +493,8 @@ Note for `multimodal_data` in JSON object prompts. This should be an array of st
 `n_keep`: Specify the number of tokens from the prompt to retain when the context size is exceeded and tokens need to be discarded. The number excludes the BOS token.
 By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to retain all tokens from the prompt.

+`n_cmpl`: Number of completions to generate from the current prompt. If the input has multiple prompts, the output will have N prompts times `n_cmpl` entries.
+
 `stream`: Allows receiving each predicted token in real-time instead of waiting for the completion to finish (uses a different response format). To enable this, set to `true`.

 `stop`: Specify a JSON array of stopping strings.
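As a usage illustration (not part of the diff): a minimal Python client could request several completions of one prompt through the new `n_cmpl` field roughly as follows; the host, port, prompt, and response handling are assumptions for the sketch.

import requests  # assumes a llama-server instance is listening locally

# Ask the /completion endpoint for two completions of the same prompt.
resp = requests.post(
    "http://localhost:8080/completion",   # placeholder host/port
    json={
        "prompt": "Write a haiku about autumn:",
        "n_predict": 32,
        "n_cmpl": 2,   # new field introduced by this commit
    },
)
resp.raise_for_status()
body = resp.json()

# For n_cmpl > 1 the non-OAI endpoint returns an array of results
# (see the response re-formatting hunk further down); normalize both shapes.
results = body if isinstance(body, list) else [body]
for r in results:
    print(r["content"])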
@@ -494,6 +494,18 @@ int32_t server_tokens::process_chunk(
     return 0;
 }

+server_tokens server_tokens::clone() const {
+    server_tokens res;
+    res.has_mtmd = has_mtmd;
+    res.tokens = tokens;
+    for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
+        size_t idx = it->first;
+        const mtmd::input_chunk_ptr & chunk = it->second;
+        res.map_idx_to_media[idx] = mtmd::input_chunk_ptr(mtmd_input_chunk_copy(chunk.get()));
+    }
+    return res;
+}
+
 //
 // tokenizer and input processing utils
 //
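The clone above has to duplicate the owned multimodal chunks instead of sharing pointers. A rough Python analogue of that deep copy, purely illustrative and not part of the codebase:

import copy
from dataclasses import dataclass, field

@dataclass
class MediaChunk:          # stand-in for an mtmd input chunk
    data: bytes

@dataclass
class ServerTokens:        # stand-in for server_tokens
    tokens: list = field(default_factory=list)
    map_idx_to_media: dict = field(default_factory=dict)

    def clone(self) -> "ServerTokens":
        res = ServerTokens(tokens=list(self.tokens))
        # each media chunk gets its own copy, mirroring mtmd_input_chunk_copy()
        for idx, chunk in self.map_idx_to_media.items():
            res.map_idx_to_media[idx] = copy.deepcopy(chunk)
        return res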
@@ -745,12 +757,6 @@ json oaicompat_completion_params_parse(const json & body) {
         llama_params["stop"] = json_value(body, "stop", json::array());
     }

-    // Handle "n" field
-    int n_choices = json_value(body, "n", 1);
-    if (n_choices != 1) {
-        throw std::runtime_error("Only one completion choice is allowed");
-    }
-
     // Handle "echo" field
     if (json_value(body, "echo", false)) {
         throw std::runtime_error("Only no echo is supported");
@@ -1049,12 +1055,6 @@ json oaicompat_chat_params_parse(
         llama_params["chat_parser"] = chat_params.parser;
     }

-    // Handle "n" field
-    int n_choices = json_value(body, "n", 1);
-    if (n_choices != 1) {
-        throw std::invalid_argument("Only one completion choice is allowed");
-    }
-
     // Handle "logprobs" field
     // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
     if (json_value(body, "logprobs", false)) {
@@ -215,6 +215,8 @@ public:
             llama_pos pos,
             int32_t seq_id,
             size_t & n_tokens_out) const;
+
+    server_tokens clone() const;
 };


@@ -35,7 +35,8 @@ constexpr int HTTP_POLLING_SECONDS = 1;
 // state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
 enum slot_state {
     SLOT_STATE_IDLE,
-    SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
+    SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt
+    SLOT_STATE_STARTED,    // after assigning a task and about to process prompt
     SLOT_STATE_PROCESSING_PROMPT,
     SLOT_STATE_DONE_PROMPT,
     SLOT_STATE_GENERATING,
@@ -254,6 +255,15 @@ struct server_slot {
         generated_token_probs.push_back(token);
     }

+    // note: a slot can also be either a parent or a child
+    bool is_parent() const {
+        return is_processing() && task->n_children > 0;
+    }
+
+    bool is_child() const {
+        return is_processing() && task->id_parent >= 0;
+    }
+
     void release() {
         if (is_processing()) {
             GGML_ASSERT(task);
@@ -383,6 +393,17 @@ struct server_slot {

         return res;
     }
+
+    void copy_state_to(server_slot & other) const {
+        llama_memory_seq_rm(llama_get_memory(ctx), other.id, 0, -1);
+        llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, 0, -1);
+        other.n_decoded = n_decoded;
+        other.n_remaining = n_remaining;
+        other.i_batch = i_batch;
+        other.n_prompt_tokens_cache = n_prompt_tokens_cache;
+        other.n_prompt_tokens_processed = n_prompt_tokens_processed;
+        other.prompt = prompt.clone();
+    }
 };


@@ -1022,7 +1043,9 @@ struct server_context_impl {

         slot.task = std::make_unique<const server_task>(std::move(task));

-        slot.state = SLOT_STATE_STARTED;
+        slot.state = slot.is_child()
+            ? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
+            : SLOT_STATE_STARTED;

         SLT_INF(slot, "%s", "processing task\n");

@@ -1684,6 +1707,12 @@ struct server_context_impl {
                     GGML_ABORT("not supported by multimodal");
                 }

+                if (slot.is_parent() || slot.is_child()) {
+                    send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER);
+                    slot.release();
+                    continue;
+                }
+
                 // Shift context
                 int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;

@@ -2308,6 +2337,26 @@ struct server_context_impl {
         n_batch = llama_n_batch(ctx);

         for (auto & slot : slots) {
+            // may need to copy state to other slots
+            if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
+                std::vector<server_slot *> child_slots;
+                for (auto & other : slots) {
+                    if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
+                        child_slots.push_back(&other);
+                    }
+                }
+
+                // we can only proceed if all child slots are having the correct tasks
+                if (child_slots.size() == slot.task->n_children) {
+                    // copy state to the child slots
+                    for (auto & child : child_slots) {
+                        SLT_INF(slot, "copying state to child %d\n", child->id);
+                        slot.copy_state_to(*child);
+                        child->state = SLOT_STATE_DONE_PROMPT;
+                    }
+                }
+            }
+
             // optionally send prompt processing progress
             if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
                 if (slot.task->params.stream && slot.task->params.return_progress) {
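Conceptually, the parent slot processes the shared prompt once; only when every child slot has picked up its task is the parent's state copied into them, after which all slots sample independently. A toy Python model of that scheduling step (names and structure are illustrative assumptions, not the server's actual API):

from dataclasses import dataclass, field

IDLE, WAIT_OTHER, DONE_PROMPT, GENERATING = range(4)

@dataclass
class Slot:
    id: int
    state: int = IDLE
    task_id: int = -1
    id_parent: int = -1     # task id of the parent, -1 if not a child
    n_children: int = 0
    kv_state: dict = field(default_factory=dict)

def sync_children(slots):
    for parent in slots:
        if parent.state != DONE_PROMPT or parent.n_children == 0:
            continue
        children = [s for s in slots
                    if s.state == WAIT_OTHER and s.id_parent == parent.task_id]
        # only proceed once every child has been assigned its task
        if len(children) == parent.n_children:
            for child in children:
                child.kv_state = dict(parent.kv_state)  # stands in for the KV-cache sequence copy
                child.state = DONE_PROMPT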
@@ -2593,11 +2642,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
         }
         tasks.reserve(inputs.size());
         states.reserve(inputs.size());
+        int idx = 0;
         for (size_t i = 0; i < inputs.size(); i++) {
             server_task task = server_task(type);

             task.id = ctx_server.queue_tasks.get_new_id();
-            task.index = i;
+            task.index = idx++;

             task.tokens = std::move(inputs[i]);
             task.params = server_task::params_from_json_cmpl(
@@ -2612,6 +2662,18 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             task.params.oaicompat_model = ctx_server.model_name;
             states.push_back(task.params.oaicompat_chat_syntax);

+            if (task.params.n_cmpl > 1) {
+                task.n_children = task.params.n_cmpl - 1;
+                for (size_t j = 0; j < task.n_children; j++) {
+                    server_task child = task.create_child(
+                        task.id,
+                        ctx_server.queue_tasks.get_new_id(),
+                        idx++);
+                    states.push_back(child.params.oaicompat_chat_syntax);
+                    tasks.push_back(std::move(child));
+                }
+            }
+
             tasks.push_back(std::move(task));
         }

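The running idx therefore assigns one result index per completion, interleaving each parent with its children, so N prompts with n_cmpl completions yield N * n_cmpl entries. A small illustrative Python loop (a sketch of the fan-out above, not server code):

def fan_out(n_prompts, n_cmpl):
    """Yield (result_index, prompt_number, is_child) in the order tasks are created."""
    out = []
    idx = 0
    for i in range(n_prompts):
        out.append((idx, i, False))      # parent task for prompt i
        idx += 1
        for _ in range(n_cmpl - 1):
            out.append((idx, i, True))   # child task reusing prompt i
            idx += 1
    return out

# two prompts with n_cmpl = 3 -> six result entries, indices 0..5
print(fan_out(2, 3))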
@@ -2638,8 +2700,21 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
                 GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
                 arr.push_back(res->to_json());
             }
-            // if single request, return single object instead of array
-            res->ok(arr.size() == 1 ? arr[0] : arr);
+            GGML_ASSERT(!arr.empty() && "empty results");
+            if (arr.size() == 1) {
+                // if single request, return single object instead of array
+                res->ok(arr[0]);
+            } else if (res_type == TASK_RESPONSE_TYPE_OAI_CHAT || res_type == TASK_RESPONSE_TYPE_OAI_CMPL) {
+                // if multiple results in OAI format, we need to re-format them
+                json & choices = arr[0]["choices"];
+                for (size_t i = 1; i < arr.size(); i++) {
+                    choices.push_back(std::move(arr[i]["choices"][0]));
+                }
+                res->ok(arr[0]);
+            } else {
+                // multi-results, non-OAI compat
+                res->ok(arr);
+            }
         }
     } else {
         // in streaming mode, the first error must be treated as non-stream response
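For OAI-compatible responses the merge step boils down to taking the first result and appending the single choice carried by every other result into its choices array. A hedged Python sketch over plain dicts (field names follow the OAI format, the helper itself is illustrative):

def merge_oai_results(results):
    """Collapse per-task OAI responses into one response with a combined 'choices' list."""
    assert results, "empty results"
    merged = results[0]
    for extra in results[1:]:
        # each extra result carries exactly one choice, already tagged with its own index
        merged["choices"].append(extra["choices"][0])
    return merged

# example with two fake chat results generated from the same prompt
r0 = {"choices": [{"index": 0, "message": {"role": "assistant", "content": "A"}}]}
r1 = {"choices": [{"index": 1, "message": {"role": "assistant", "content": "B"}}]}
print(merge_oai_results([r0, r1]))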
@@ -175,6 +175,7 @@ task_params server_task::params_from_json_cmpl(
     params.n_indent = json_value(data, "n_indent", defaults.n_indent);
     params.n_keep = json_value(data, "n_keep", defaults.n_keep);
     params.n_discard = json_value(data, "n_discard", defaults.n_discard);
+    params.n_cmpl = json_value(data, "n_cmpl", json_value(data, "n", 1));
     //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
     params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
     params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
@@ -453,6 +454,10 @@ task_params server_task::params_from_json_cmpl(
         }
     }

+    if (params.n_cmpl > params_base.n_parallel) {
+        throw std::runtime_error("n_cmpl cannot be greater than the number of slots, please increase -np");
+    }
+
     return params;
 }

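Two behaviours fall out of this parsing code: `n_cmpl` takes precedence over the OAI `n` spelling, and the value is capped by the number of parallel slots (`-np`). A compact Python restatement (parameter names come from the diff; the helper itself is illustrative):

def parse_n_cmpl(body, n_parallel):
    """Resolve the requested number of completions, preferring 'n_cmpl' over OAI's 'n'."""
    n_cmpl = body.get("n_cmpl", body.get("n", 1))
    if n_cmpl > n_parallel:
        raise ValueError("n_cmpl cannot be greater than the number of slots, please increase -np")
    return n_cmpl

# e.g. with 4 slots: {"n": 3} -> 3, {"n_cmpl": 2, "n": 5} -> 2, {"n": 8} -> error
assert parse_n_cmpl({"n": 3}, 4) == 3
assert parse_n_cmpl({"n_cmpl": 2, "n": 5}, 4) == 2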
@@ -664,7 +669,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat() {

     json choice {
         {"finish_reason", finish_reason},
-        {"index", 0},
+        {"index", index},
         {"message", msg.to_json_oaicompat<json>()},
     };

@@ -1064,7 +1069,7 @@ json server_task_result_cmpl_partial::to_json_oaicompat_chat() {
         {"choices", json::array({
             json {
                 {"finish_reason", nullptr},
-                {"index", 0},
+                {"index", index},
                 {"delta", delta},
             },
         })},
@@ -53,6 +53,7 @@ struct task_params {
     int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
     int32_t n_predict = -1; // new tokens to predict
     int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
+    int32_t n_cmpl = 1; // number of completions to generate from this prompt

     int64_t t_max_prompt_ms = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
@@ -89,6 +90,10 @@ struct server_task {
     int id_target = -1;
     int id_slot = -1;

+    // used by parallel sampling (multiple completions from same prompt)
+    size_t n_children = 0; // number of tasks reusing this prompt
+    int id_parent = -1;
+
     // used by SERVER_TASK_TYPE_INFERENCE
     task_params params;
     server_tokens tokens;
@@ -130,6 +135,17 @@ struct server_task {
         }
         return ids;
     }
+
+    server_task create_child(int id_parent, int id_child, int idx) const {
+        server_task copy;
+        copy.id = id_child;
+        copy.index = idx;
+        copy.id_parent = id_parent;
+        copy.params = params;
+        copy.type = type;
+        copy.tokens = tokens.clone();
+        return copy;
+    }
 };

 struct result_timings {
@@ -466,6 +482,14 @@ struct server_prompt {
     int n_tokens() const {
         return tokens.size();
     }
+
+    server_prompt clone() const {
+        return server_prompt {
+            tokens.clone(),
+            data,
+            checkpoints
+        };
+    }
 };

 struct server_prompt_cache {
@@ -477,3 +477,22 @@ def test_return_progress(n_batch, batch_count, reuse_cache):
     assert last_progress["total"] > 0
     assert last_progress["processed"] == last_progress["total"]
     assert total_batch_count == batch_count
+
+
+def test_chat_completions_multiple_choices():
+    global server
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": 8,
+        "n": 2,
+        "messages": [
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+    })
+    assert res.status_code == 200
+    assert len(res.body["choices"]) == 2
+    for choice in res.body["choices"]:
+        assert "assistant" == choice["message"]["role"]
+        assert match_regex("Suddenly", choice["message"]["content"])
+        assert choice["finish_reason"] == "length"