server: prevent data race from HTTP threads (#18263)
* server: prevent data race from HTTP threads * fix params * fix default_generation_settings * nits: make handle_completions_impl looks less strange * stricter const * fix GGML_ASSERT(idx < states.size()) * move index to be managed by server_response_reader * http: make sure req & res lifecycle are tied together * fix compile * fix index handling buggy * fix data race for lora endpoint * nits: fix shadow variable * nits: revert redundant changes * nits: correct naming for json_webui_settings
This commit is contained in:
parent
3997c78e33
commit
6ce863c803
|
|
@ -216,7 +216,7 @@ int main(int argc, char ** argv) {
|
||||||
ctx_cli.ctx_server.start_loop();
|
ctx_cli.ctx_server.start_loop();
|
||||||
});
|
});
|
||||||
|
|
||||||
auto inf = ctx_cli.ctx_server.get_info();
|
auto inf = ctx_cli.ctx_server.get_meta();
|
||||||
std::string modalities = "text";
|
std::string modalities = "text";
|
||||||
if (inf.has_inp_image) {
|
if (inf.has_inp_image) {
|
||||||
modalities += ", vision";
|
modalities += ", vision";
|
||||||
|
|
|
||||||
|
|
@ -115,26 +115,14 @@ bool lora_should_clear_cache(
|
||||||
!lora_all_alora(next));
|
!lora_all_alora(next));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<common_adapter_lora_info> parse_lora_request(
|
std::map<int, float> parse_lora_request(const json & data) {
|
||||||
const std::vector<common_adapter_lora_info> & lora_base,
|
std::map<int, float> lora;
|
||||||
const json & data) {
|
|
||||||
std::vector<common_adapter_lora_info> lora(lora_base);
|
|
||||||
int max_idx = lora.size();
|
|
||||||
|
|
||||||
// clear existing value
|
|
||||||
for (auto & entry : lora) {
|
|
||||||
entry.scale = 0.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
// set value
|
// set value
|
||||||
for (const auto & entry : data) {
|
for (const auto & entry : data) {
|
||||||
int id = json_value(entry, "id", -1);
|
int id = json_value(entry, "id", -1);
|
||||||
float scale = json_value(entry, "scale", 0.0f);
|
float scale = json_value(entry, "scale", 0.0f);
|
||||||
if (0 <= id && id < max_idx) {
|
lora[id] = scale;
|
||||||
lora[id].scale = scale;
|
|
||||||
} else {
|
|
||||||
throw std::runtime_error("invalid adapter id");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return lora;
|
return lora;
|
||||||
|
|
@ -1435,7 +1423,7 @@ std::string safe_json_to_str(const json & data) {
|
||||||
|
|
||||||
// TODO: reuse llama_detokenize
|
// TODO: reuse llama_detokenize
|
||||||
template <class Iter>
|
template <class Iter>
|
||||||
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
|
static std::string tokens_to_str(const llama_vocab * ctx, Iter begin, Iter end) {
|
||||||
std::string ret;
|
std::string ret;
|
||||||
for (; begin != end; ++begin) {
|
for (; begin != end; ++begin) {
|
||||||
ret += common_token_to_piece(ctx, *begin);
|
ret += common_token_to_piece(ctx, *begin);
|
||||||
|
|
@ -1445,7 +1433,12 @@ static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens) {
|
std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens) {
|
||||||
return tokens_to_str(ctx, tokens.begin(), tokens.end());
|
auto model = llama_get_model(ctx);
|
||||||
|
return tokens_to_str(llama_model_get_vocab(model), tokens.begin(), tokens.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string tokens_to_str(const llama_vocab * vocab, const llama_tokens & tokens) {
|
||||||
|
return tokens_to_str(vocab, tokens.begin(), tokens.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
// format incomplete utf-8 multibyte character for output
|
// format incomplete utf-8 multibyte character for output
|
||||||
|
|
|
||||||
|
|
@ -107,9 +107,7 @@ bool lora_should_clear_cache(
|
||||||
const std::vector<common_adapter_lora_info> & current,
|
const std::vector<common_adapter_lora_info> & current,
|
||||||
const std::vector<common_adapter_lora_info> & next);
|
const std::vector<common_adapter_lora_info> & next);
|
||||||
|
|
||||||
std::vector<common_adapter_lora_info> parse_lora_request(
|
std::map<int, float> parse_lora_request(const json & data);
|
||||||
const std::vector<common_adapter_lora_info> & lora_base,
|
|
||||||
const json & data);
|
|
||||||
|
|
||||||
bool are_lora_equal(
|
bool are_lora_equal(
|
||||||
const std::vector<common_adapter_lora_info> & l1,
|
const std::vector<common_adapter_lora_info> & l1,
|
||||||
|
|
@ -325,6 +323,7 @@ std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int i
|
||||||
std::string safe_json_to_str(const json & data);
|
std::string safe_json_to_str(const json & data);
|
||||||
|
|
||||||
std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens);
|
std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens);
|
||||||
|
std::string tokens_to_str(const llama_vocab * vocab, const llama_tokens & tokens);
|
||||||
|
|
||||||
// format incomplete utf-8 multibyte character for output
|
// format incomplete utf-8 multibyte character for output
|
||||||
std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token);
|
std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token);
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -9,11 +9,35 @@
|
||||||
|
|
||||||
struct server_context_impl; // private implementation
|
struct server_context_impl; // private implementation
|
||||||
|
|
||||||
struct server_context_info {
|
struct server_context_meta {
|
||||||
std::string build_info;
|
std::string build_info;
|
||||||
std::string model_name;
|
std::string model_name;
|
||||||
|
std::string model_path;
|
||||||
|
bool has_mtmd;
|
||||||
bool has_inp_image;
|
bool has_inp_image;
|
||||||
bool has_inp_audio;
|
bool has_inp_audio;
|
||||||
|
json json_webui_settings;
|
||||||
|
int slot_n_ctx;
|
||||||
|
enum llama_pooling_type pooling_type;
|
||||||
|
|
||||||
|
// chat template
|
||||||
|
std::string chat_template;
|
||||||
|
std::string chat_template_tool_use;
|
||||||
|
|
||||||
|
// tokens
|
||||||
|
std::string bos_token_str;
|
||||||
|
std::string eos_token_str;
|
||||||
|
llama_token fim_pre_token;
|
||||||
|
llama_token fim_sub_token;
|
||||||
|
llama_token fim_mid_token;
|
||||||
|
|
||||||
|
// model meta
|
||||||
|
enum llama_vocab_type model_vocab_type;
|
||||||
|
int32_t model_vocab_n_tokens;
|
||||||
|
int32_t model_n_ctx_train;
|
||||||
|
int32_t model_n_embd_inp;
|
||||||
|
uint64_t model_n_params;
|
||||||
|
uint64_t model_size;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_context {
|
struct server_context {
|
||||||
|
|
@ -33,14 +57,15 @@ struct server_context {
|
||||||
void terminate();
|
void terminate();
|
||||||
|
|
||||||
// get the underlaying llama_context, can return nullptr if sleeping
|
// get the underlaying llama_context, can return nullptr if sleeping
|
||||||
|
// not thread-safe, should only be used from the main thread
|
||||||
llama_context * get_llama_context() const;
|
llama_context * get_llama_context() const;
|
||||||
|
|
||||||
// get a new response reader, used by CLI application
|
// get a new response reader, used by CLI application
|
||||||
server_response_reader get_response_reader();
|
server_response_reader get_response_reader();
|
||||||
|
|
||||||
// get server info
|
// get server metadata (read-only), can only be called after load_model()
|
||||||
// used by CLI application
|
// not thread-safe, should only be used from the main thread
|
||||||
server_context_info get_info() const;
|
server_context_meta get_meta() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -48,13 +73,17 @@ struct server_context {
|
||||||
struct server_res_generator;
|
struct server_res_generator;
|
||||||
|
|
||||||
struct server_routes {
|
struct server_routes {
|
||||||
server_routes(const common_params & params, server_context & ctx_server, std::function<bool()> is_ready = []() { return true; })
|
server_routes(const common_params & params, server_context & ctx_server);
|
||||||
: params(params), ctx_server(*ctx_server.impl), is_ready(is_ready) {
|
|
||||||
init_routes();
|
|
||||||
}
|
|
||||||
|
|
||||||
void init_routes();
|
void init_routes();
|
||||||
|
|
||||||
|
// note: this is not thread-safe and can only when ctx_http.is_ready is false
|
||||||
|
void update_meta(const server_context & ctx_server) {
|
||||||
|
this->meta = std::make_unique<server_context_meta>(ctx_server.get_meta());
|
||||||
|
}
|
||||||
|
|
||||||
// handlers using lambda function, so that they can capture `this` without `std::bind`
|
// handlers using lambda function, so that they can capture `this` without `std::bind`
|
||||||
|
// they won't be called until ctx_http.is_ready is set to true
|
||||||
server_http_context::handler_t get_health;
|
server_http_context::handler_t get_health;
|
||||||
server_http_context::handler_t get_metrics;
|
server_http_context::handler_t get_metrics;
|
||||||
server_http_context::handler_t get_slots;
|
server_http_context::handler_t get_slots;
|
||||||
|
|
@ -78,13 +107,24 @@ struct server_routes {
|
||||||
server_http_context::handler_t get_lora_adapters;
|
server_http_context::handler_t get_lora_adapters;
|
||||||
server_http_context::handler_t post_lora_adapters;
|
server_http_context::handler_t post_lora_adapters;
|
||||||
private:
|
private:
|
||||||
// TODO: move these outside of server_routes?
|
std::unique_ptr<server_res_generator> handle_completions_impl(
|
||||||
|
const server_http_req & req,
|
||||||
|
server_task_type type,
|
||||||
|
const json & data,
|
||||||
|
const std::vector<raw_buffer> & files,
|
||||||
|
task_response_type res_type);
|
||||||
std::unique_ptr<server_res_generator> handle_slots_save(const server_http_req & req, int id_slot);
|
std::unique_ptr<server_res_generator> handle_slots_save(const server_http_req & req, int id_slot);
|
||||||
std::unique_ptr<server_res_generator> handle_slots_restore(const server_http_req & req, int id_slot);
|
std::unique_ptr<server_res_generator> handle_slots_restore(const server_http_req & req, int id_slot);
|
||||||
std::unique_ptr<server_res_generator> handle_slots_erase(const server_http_req &, int id_slot);
|
std::unique_ptr<server_res_generator> handle_slots_erase(const server_http_req &, int id_slot);
|
||||||
std::unique_ptr<server_res_generator> handle_embeddings_impl(const server_http_req & req, task_response_type res_type);
|
std::unique_ptr<server_res_generator> handle_embeddings_impl(const server_http_req & req, task_response_type res_type);
|
||||||
|
|
||||||
|
// using unique_ptr to allow late initialization of const
|
||||||
|
std::unique_ptr<const server_context_meta> meta;
|
||||||
|
|
||||||
const common_params & params;
|
const common_params & params;
|
||||||
server_context_impl & ctx_server;
|
const server_context_impl & ctx_server;
|
||||||
std::function<bool()> is_ready;
|
|
||||||
|
server_queue & queue_tasks;
|
||||||
|
server_response & queue_results;
|
||||||
|
std::unique_ptr<server_res_generator> create_response(bool bypass_sleep = false);
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -177,12 +177,11 @@ bool server_http_context::init(const common_params & params) {
|
||||||
if (!ready) {
|
if (!ready) {
|
||||||
auto tmp = string_split<std::string>(req.path, '.');
|
auto tmp = string_split<std::string>(req.path, '.');
|
||||||
if (req.path == "/" || tmp.back() == "html") {
|
if (req.path == "/" || tmp.back() == "html") {
|
||||||
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
|
|
||||||
res.status = 503;
|
res.status = 503;
|
||||||
} else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") {
|
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
|
||||||
// allow the models endpoint to be accessed during loading
|
|
||||||
return true;
|
|
||||||
} else {
|
} else {
|
||||||
|
// no endpoints is allowed to be accessed when the server is not ready
|
||||||
|
// this is to prevent any data races or inconsistent states
|
||||||
res.status = 503;
|
res.status = 503;
|
||||||
res.set_content(
|
res.set_content(
|
||||||
safe_json_to_str(json {
|
safe_json_to_str(json {
|
||||||
|
|
@ -334,12 +333,16 @@ static std::map<std::string, std::string> get_headers(const httplib::Request & r
|
||||||
return headers;
|
return headers;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void process_handler_response(server_http_res_ptr & response, httplib::Response & res) {
|
// using unique_ptr for request to allow safe capturing in lambdas
|
||||||
|
using server_http_req_ptr = std::unique_ptr<server_http_req>;
|
||||||
|
|
||||||
|
static void process_handler_response(server_http_req_ptr && request, server_http_res_ptr & response, httplib::Response & res) {
|
||||||
if (response->is_stream()) {
|
if (response->is_stream()) {
|
||||||
res.status = response->status;
|
res.status = response->status;
|
||||||
set_headers(res, response->headers);
|
set_headers(res, response->headers);
|
||||||
std::string content_type = response->content_type;
|
std::string content_type = response->content_type;
|
||||||
// convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it
|
// convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it
|
||||||
|
std::shared_ptr<server_http_req> q_ptr = std::move(request);
|
||||||
std::shared_ptr<server_http_res> r_ptr = std::move(response);
|
std::shared_ptr<server_http_res> r_ptr = std::move(response);
|
||||||
const auto chunked_content_provider = [response = r_ptr](size_t, httplib::DataSink & sink) -> bool {
|
const auto chunked_content_provider = [response = r_ptr](size_t, httplib::DataSink & sink) -> bool {
|
||||||
std::string chunk;
|
std::string chunk;
|
||||||
|
|
@ -355,8 +358,9 @@ static void process_handler_response(server_http_res_ptr & response, httplib::Re
|
||||||
}
|
}
|
||||||
return has_next;
|
return has_next;
|
||||||
};
|
};
|
||||||
const auto on_complete = [response = r_ptr](bool) mutable {
|
const auto on_complete = [request = q_ptr, response = r_ptr](bool) mutable {
|
||||||
response.reset(); // trigger the destruction of the response object
|
response.reset(); // trigger the destruction of the response object
|
||||||
|
request.reset(); // trigger the destruction of the request object
|
||||||
};
|
};
|
||||||
res.set_chunked_content_provider(content_type, chunked_content_provider, on_complete);
|
res.set_chunked_content_provider(content_type, chunked_content_provider, on_complete);
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -368,27 +372,29 @@ static void process_handler_response(server_http_res_ptr & response, httplib::Re
|
||||||
|
|
||||||
void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const {
|
void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const {
|
||||||
pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
|
pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
|
||||||
server_http_res_ptr response = handler(server_http_req{
|
server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
|
||||||
get_params(req),
|
get_params(req),
|
||||||
get_headers(req),
|
get_headers(req),
|
||||||
req.path,
|
req.path,
|
||||||
req.body,
|
req.body,
|
||||||
req.is_connection_closed
|
req.is_connection_closed
|
||||||
});
|
});
|
||||||
process_handler_response(response, res);
|
server_http_res_ptr response = handler(*request);
|
||||||
|
process_handler_response(std::move(request), response, res);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
|
void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
|
||||||
pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
|
pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
|
||||||
server_http_res_ptr response = handler(server_http_req{
|
server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
|
||||||
get_params(req),
|
get_params(req),
|
||||||
get_headers(req),
|
get_headers(req),
|
||||||
req.path,
|
req.path,
|
||||||
req.body,
|
req.body,
|
||||||
req.is_connection_closed
|
req.is_connection_closed
|
||||||
});
|
});
|
||||||
process_handler_response(response, res);
|
server_http_res_ptr response = handler(*request);
|
||||||
|
process_handler_response(std::move(request), response, res);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -325,23 +325,25 @@ void server_response::terminate() {
|
||||||
// server_response_reader
|
// server_response_reader
|
||||||
//
|
//
|
||||||
|
|
||||||
void server_response_reader::post_task(server_task && task) {
|
void server_response_reader::post_task(server_task && task, bool front) {
|
||||||
GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
|
GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
|
||||||
|
task.index = 0;
|
||||||
id_tasks.insert(task.id);
|
id_tasks.insert(task.id);
|
||||||
states.push_back(task.create_state());
|
states.push_back(task.create_state());
|
||||||
queue_results.add_waiting_task_id(task.id);
|
queue_results.add_waiting_task_id(task.id);
|
||||||
queue_tasks.post(std::move(task));
|
queue_tasks.post(std::move(task), front);
|
||||||
}
|
}
|
||||||
|
|
||||||
void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
|
void server_response_reader::post_tasks(std::vector<server_task> && tasks, bool front) {
|
||||||
GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
|
GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
|
||||||
id_tasks = server_task::get_list_id(tasks);
|
id_tasks = server_task::get_list_id(tasks);
|
||||||
states.reserve(tasks.size());
|
states.reserve(tasks.size());
|
||||||
for (size_t i = 0; i < tasks.size(); i++) {
|
for (size_t i = 0; i < tasks.size(); i++) {
|
||||||
|
tasks[i].index = i;
|
||||||
states.push_back(tasks[i].create_state());
|
states.push_back(tasks[i].create_state());
|
||||||
}
|
}
|
||||||
queue_results.add_waiting_tasks(tasks);
|
queue_results.add_waiting_tasks(tasks);
|
||||||
queue_tasks.post(std::move(tasks));
|
queue_tasks.post(std::move(tasks), front);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool server_response_reader::has_next() const {
|
bool server_response_reader::has_next() const {
|
||||||
|
|
@ -367,7 +369,7 @@ server_task_result_ptr server_response_reader::next(const std::function<bool()>
|
||||||
}
|
}
|
||||||
if (!states.empty()) {
|
if (!states.empty()) {
|
||||||
// update the generation state if needed
|
// update the generation state if needed
|
||||||
size_t idx = result->get_index();
|
const size_t idx = result->index;
|
||||||
GGML_ASSERT(idx < states.size());
|
GGML_ASSERT(idx < states.size());
|
||||||
result->update(states[idx]);
|
result->update(states[idx]);
|
||||||
}
|
}
|
||||||
|
|
@ -383,6 +385,7 @@ server_task_result_ptr server_response_reader::next(const std::function<bool()>
|
||||||
|
|
||||||
server_response_reader::batch_response server_response_reader::wait_for_all(const std::function<bool()> & should_stop) {
|
server_response_reader::batch_response server_response_reader::wait_for_all(const std::function<bool()> & should_stop) {
|
||||||
batch_response batch_res;
|
batch_response batch_res;
|
||||||
|
batch_res.results.clear();
|
||||||
batch_res.results.resize(id_tasks.size());
|
batch_res.results.resize(id_tasks.size());
|
||||||
while (has_next()) {
|
while (has_next()) {
|
||||||
auto res = next(should_stop);
|
auto res = next(should_stop);
|
||||||
|
|
@ -394,7 +397,7 @@ server_response_reader::batch_response server_response_reader::wait_for_all(cons
|
||||||
batch_res.error = std::move(res);
|
batch_res.error = std::move(res);
|
||||||
return batch_res;
|
return batch_res;
|
||||||
}
|
}
|
||||||
const size_t idx = res->get_index();
|
const size_t idx = res->index;
|
||||||
GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
|
GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
|
||||||
GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
|
GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
|
||||||
batch_res.results[idx] = std::move(res);
|
batch_res.results[idx] = std::move(res);
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@
|
||||||
#include <condition_variable>
|
#include <condition_variable>
|
||||||
#include <deque>
|
#include <deque>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
|
#include <vector>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
|
||||||
// struct for managing server tasks
|
// struct for managing server tasks
|
||||||
|
|
@ -173,8 +174,10 @@ struct server_response_reader {
|
||||||
int get_new_id() {
|
int get_new_id() {
|
||||||
return queue_tasks.get_new_id();
|
return queue_tasks.get_new_id();
|
||||||
}
|
}
|
||||||
void post_task(server_task && task);
|
|
||||||
void post_tasks(std::vector<server_task> && tasks);
|
// if front = true, the task will be posted to the front of the queue (high priority)
|
||||||
|
void post_task(server_task && task, bool front = false);
|
||||||
|
void post_tasks(std::vector<server_task> && tasks, bool front = false);
|
||||||
bool has_next() const;
|
bool has_next() const;
|
||||||
|
|
||||||
// return nullptr if should_stop() is true before receiving a result
|
// return nullptr if should_stop() is true before receiving a result
|
||||||
|
|
|
||||||
|
|
@ -32,8 +32,8 @@ json task_params::to_json(bool only_metrics) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
json lora = json::array();
|
json lora = json::array();
|
||||||
for (size_t i = 0; i < this->lora.size(); ++i) {
|
for (auto & it : this->lora) {
|
||||||
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
|
lora.push_back({{"id", it.first}, {"scale", it.second}});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (only_metrics) {
|
if (only_metrics) {
|
||||||
|
|
@ -145,12 +145,10 @@ json task_params::to_json(bool only_metrics) const {
|
||||||
//
|
//
|
||||||
|
|
||||||
task_params server_task::params_from_json_cmpl(
|
task_params server_task::params_from_json_cmpl(
|
||||||
const llama_context * ctx,
|
const llama_vocab * vocab,
|
||||||
const common_params & params_base,
|
const common_params & params_base,
|
||||||
|
const int n_ctx_slot,
|
||||||
const json & data) {
|
const json & data) {
|
||||||
const llama_model * model = llama_get_model(ctx);
|
|
||||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
|
||||||
|
|
||||||
task_params params;
|
task_params params;
|
||||||
|
|
||||||
// Sampling parameter defaults are loaded from the global server context (but individual requests can still them)
|
// Sampling parameter defaults are loaded from the global server context (but individual requests can still them)
|
||||||
|
|
@ -223,12 +221,12 @@ task_params server_task::params_from_json_cmpl(
|
||||||
|
|
||||||
if (data.contains("lora")) {
|
if (data.contains("lora")) {
|
||||||
if (data.at("lora").is_array()) {
|
if (data.at("lora").is_array()) {
|
||||||
params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
|
params.lora = parse_lora_request(data.at("lora"));
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
|
throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
params.lora = params_base.lora_adapters;
|
params.lora = {};
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: add more sanity checks for the input parameters
|
// TODO: add more sanity checks for the input parameters
|
||||||
|
|
@ -243,11 +241,11 @@ task_params server_task::params_from_json_cmpl(
|
||||||
|
|
||||||
if (params.sampling.penalty_last_n == -1) {
|
if (params.sampling.penalty_last_n == -1) {
|
||||||
// note: should be the slot's context and not the full context, but it's ok
|
// note: should be the slot's context and not the full context, but it's ok
|
||||||
params.sampling.penalty_last_n = llama_n_ctx(ctx);
|
params.sampling.penalty_last_n = n_ctx_slot;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.sampling.dry_penalty_last_n == -1) {
|
if (params.sampling.dry_penalty_last_n == -1) {
|
||||||
params.sampling.dry_penalty_last_n = llama_n_ctx(ctx);
|
params.sampling.dry_penalty_last_n = n_ctx_slot;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.sampling.dry_base < 1.0f) {
|
if (params.sampling.dry_base < 1.0f) {
|
||||||
|
|
@ -1324,6 +1322,30 @@ json server_task_result_slot_erase::to_json() {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// server_task_result_get_lora
|
||||||
|
//
|
||||||
|
|
||||||
|
json server_task_result_get_lora::to_json() {
|
||||||
|
json result = json::array();
|
||||||
|
for (size_t i = 0; i < loras.size(); ++i) {
|
||||||
|
auto & lora = loras[i];
|
||||||
|
json entry = {
|
||||||
|
{"id", i},
|
||||||
|
{"path", lora.info.path},
|
||||||
|
{"scale", lora.info.scale},
|
||||||
|
{"task_name", lora.info.task_name},
|
||||||
|
{"prompt_prefix", lora.info.prompt_prefix},
|
||||||
|
};
|
||||||
|
if (!lora.alora_invocation_tokens.empty()) {
|
||||||
|
entry["alora_invocation_string"] = lora.alora_invocation_string;
|
||||||
|
entry["alora_invocation_tokens"] = lora.alora_invocation_tokens;
|
||||||
|
}
|
||||||
|
result.push_back(std::move(entry));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// server_task_result_apply_lora
|
// server_task_result_apply_lora
|
||||||
//
|
//
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <list>
|
#include <list>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
// TODO: prevent including the whole server-common.h as we only use server_tokens
|
// TODO: prevent including the whole server-common.h as we only use server_tokens
|
||||||
#include "server-common.h"
|
#include "server-common.h"
|
||||||
|
|
@ -23,6 +24,7 @@ enum server_task_type {
|
||||||
SERVER_TASK_TYPE_SLOT_SAVE,
|
SERVER_TASK_TYPE_SLOT_SAVE,
|
||||||
SERVER_TASK_TYPE_SLOT_RESTORE,
|
SERVER_TASK_TYPE_SLOT_RESTORE,
|
||||||
SERVER_TASK_TYPE_SLOT_ERASE,
|
SERVER_TASK_TYPE_SLOT_ERASE,
|
||||||
|
SERVER_TASK_TYPE_GET_LORA,
|
||||||
SERVER_TASK_TYPE_SET_LORA,
|
SERVER_TASK_TYPE_SET_LORA,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -60,7 +62,7 @@ struct task_params {
|
||||||
int64_t t_max_prompt_ms = -1; // TODO: implement
|
int64_t t_max_prompt_ms = -1; // TODO: implement
|
||||||
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
||||||
|
|
||||||
std::vector<common_adapter_lora_info> lora;
|
std::map<int, float> lora; // mapping adapter ID -> scale
|
||||||
|
|
||||||
std::vector<std::string> antiprompt;
|
std::vector<std::string> antiprompt;
|
||||||
std::vector<std::string> response_fields;
|
std::vector<std::string> response_fields;
|
||||||
|
|
@ -105,8 +107,10 @@ struct task_result_state {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_task {
|
struct server_task {
|
||||||
int id = -1; // to be filled by server_queue
|
int id = -1; // to be filled by server_queue
|
||||||
int index = -1; // used when there are multiple prompts (batch request)
|
|
||||||
|
// TODO @ngxson : remove this field and implement a mapping task_id -> idx in the response_reader
|
||||||
|
size_t index = 0; // used when there are multiple prompts (batch request)
|
||||||
|
|
||||||
// used by SERVER_TASK_TYPE_CANCEL
|
// used by SERVER_TASK_TYPE_CANCEL
|
||||||
int id_target = -1;
|
int id_target = -1;
|
||||||
|
|
@ -138,7 +142,7 @@ struct server_task {
|
||||||
bool metrics_reset_bucket = false;
|
bool metrics_reset_bucket = false;
|
||||||
|
|
||||||
// used by SERVER_TASK_TYPE_SET_LORA
|
// used by SERVER_TASK_TYPE_SET_LORA
|
||||||
std::vector<common_adapter_lora_info> set_lora;
|
std::map<int, float> set_lora; // mapping adapter ID -> scale
|
||||||
|
|
||||||
server_task() = default;
|
server_task() = default;
|
||||||
|
|
||||||
|
|
@ -149,9 +153,10 @@ struct server_task {
|
||||||
}
|
}
|
||||||
|
|
||||||
static task_params params_from_json_cmpl(
|
static task_params params_from_json_cmpl(
|
||||||
const llama_context * ctx,
|
const llama_vocab * vocab,
|
||||||
const common_params & params_base,
|
const common_params & params_base,
|
||||||
const json & data);
|
const int n_ctx_slot,
|
||||||
|
const json & data);
|
||||||
|
|
||||||
// utility function
|
// utility function
|
||||||
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
|
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
|
||||||
|
|
@ -162,10 +167,9 @@ struct server_task {
|
||||||
return ids;
|
return ids;
|
||||||
}
|
}
|
||||||
|
|
||||||
server_task create_child(int id_parent, int id_child, int idx) const {
|
server_task create_child(int id_parent, int id_child) const {
|
||||||
server_task copy;
|
server_task copy;
|
||||||
copy.id = id_child;
|
copy.id = id_child;
|
||||||
copy.index = idx;
|
|
||||||
copy.id_parent = id_parent;
|
copy.id_parent = id_parent;
|
||||||
copy.params = params;
|
copy.params = params;
|
||||||
copy.type = type;
|
copy.type = type;
|
||||||
|
|
@ -212,6 +216,10 @@ struct result_prompt_progress {
|
||||||
struct server_task_result {
|
struct server_task_result {
|
||||||
int id = -1;
|
int id = -1;
|
||||||
int id_slot = -1;
|
int id_slot = -1;
|
||||||
|
|
||||||
|
// TODO @ngxson : remove this field and implement a mapping task_id -> idx in the response_reader
|
||||||
|
size_t index = 0; // to be used for batched tasks
|
||||||
|
|
||||||
virtual bool is_error() {
|
virtual bool is_error() {
|
||||||
// only used by server_task_result_error
|
// only used by server_task_result_error
|
||||||
return false;
|
return false;
|
||||||
|
|
@ -220,9 +228,6 @@ struct server_task_result {
|
||||||
// only used by server_task_result_cmpl_*
|
// only used by server_task_result_cmpl_*
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
virtual int get_index() {
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
virtual void update(task_result_state &) {
|
virtual void update(task_result_state &) {
|
||||||
// only used by server_task_result_cmpl_*
|
// only used by server_task_result_cmpl_*
|
||||||
}
|
}
|
||||||
|
|
@ -255,8 +260,6 @@ struct completion_token_output {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_task_result_cmpl_final : server_task_result {
|
struct server_task_result_cmpl_final : server_task_result {
|
||||||
int index = 0;
|
|
||||||
|
|
||||||
std::string content;
|
std::string content;
|
||||||
llama_tokens tokens;
|
llama_tokens tokens;
|
||||||
|
|
||||||
|
|
@ -289,10 +292,6 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
|
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
|
||||||
bool is_updated = false;
|
bool is_updated = false;
|
||||||
|
|
||||||
virtual int get_index() override {
|
|
||||||
return index;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual bool is_stop() override {
|
virtual bool is_stop() override {
|
||||||
return true; // in stream mode, final responses are considered stop
|
return true; // in stream mode, final responses are considered stop
|
||||||
}
|
}
|
||||||
|
|
@ -318,8 +317,6 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_task_result_cmpl_partial : server_task_result {
|
struct server_task_result_cmpl_partial : server_task_result {
|
||||||
int index = 0;
|
|
||||||
|
|
||||||
std::string content;
|
std::string content;
|
||||||
llama_tokens tokens;
|
llama_tokens tokens;
|
||||||
|
|
||||||
|
|
@ -340,10 +337,6 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
|
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
|
||||||
bool is_updated = false;
|
bool is_updated = false;
|
||||||
|
|
||||||
virtual int get_index() override {
|
|
||||||
return index;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual bool is_stop() override {
|
virtual bool is_stop() override {
|
||||||
return false; // in stream mode, partial responses are not considered stop
|
return false; // in stream mode, partial responses are not considered stop
|
||||||
}
|
}
|
||||||
|
|
@ -365,7 +358,6 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_task_result_embd : server_task_result {
|
struct server_task_result_embd : server_task_result {
|
||||||
int index = 0;
|
|
||||||
std::vector<std::vector<float>> embedding;
|
std::vector<std::vector<float>> embedding;
|
||||||
|
|
||||||
int32_t n_tokens;
|
int32_t n_tokens;
|
||||||
|
|
@ -373,10 +365,6 @@ struct server_task_result_embd : server_task_result {
|
||||||
// response formatting
|
// response formatting
|
||||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||||
|
|
||||||
virtual int get_index() override {
|
|
||||||
return index;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual json to_json() override;
|
virtual json to_json() override;
|
||||||
|
|
||||||
json to_json_non_oaicompat();
|
json to_json_non_oaicompat();
|
||||||
|
|
@ -385,20 +373,14 @@ struct server_task_result_embd : server_task_result {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_task_result_rerank : server_task_result {
|
struct server_task_result_rerank : server_task_result {
|
||||||
int index = 0;
|
|
||||||
float score = -1e6;
|
float score = -1e6;
|
||||||
|
|
||||||
int32_t n_tokens;
|
int32_t n_tokens;
|
||||||
|
|
||||||
virtual int get_index() override {
|
|
||||||
return index;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual json to_json() override;
|
virtual json to_json() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_task_result_error : server_task_result {
|
struct server_task_result_error : server_task_result {
|
||||||
int index = 0;
|
|
||||||
error_type err_type = ERROR_TYPE_SERVER;
|
error_type err_type = ERROR_TYPE_SERVER;
|
||||||
std::string err_msg;
|
std::string err_msg;
|
||||||
|
|
||||||
|
|
@ -460,6 +442,17 @@ struct server_task_result_slot_erase : server_task_result {
|
||||||
virtual json to_json() override;
|
virtual json to_json() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct server_task_result_get_lora : server_task_result {
|
||||||
|
struct lora {
|
||||||
|
common_adapter_lora_info info;
|
||||||
|
std::string alora_invocation_string;
|
||||||
|
llama_tokens alora_invocation_tokens;
|
||||||
|
};
|
||||||
|
std::vector<lora> loras;
|
||||||
|
|
||||||
|
virtual json to_json() override;
|
||||||
|
};
|
||||||
|
|
||||||
struct server_task_result_apply_lora : server_task_result {
|
struct server_task_result_apply_lora : server_task_result {
|
||||||
virtual json to_json() override;
|
virtual json to_json() override;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -119,7 +119,7 @@ int main(int argc, char ** argv, char ** envp) {
|
||||||
//
|
//
|
||||||
|
|
||||||
// register API routes
|
// register API routes
|
||||||
server_routes routes(params, ctx_server, [&ctx_http]() { return ctx_http.is_ready.load(); });
|
server_routes routes(params, ctx_server);
|
||||||
|
|
||||||
bool is_router_server = params.model.path.empty();
|
bool is_router_server = params.model.path.empty();
|
||||||
std::optional<server_models_routes> models_routes{};
|
std::optional<server_models_routes> models_routes{};
|
||||||
|
|
@ -252,6 +252,7 @@ int main(int argc, char ** argv, char ** envp) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
routes.update_meta(ctx_server);
|
||||||
ctx_http.is_ready.store(true);
|
ctx_http.is_ready.store(true);
|
||||||
|
|
||||||
LOG_INF("%s: model loaded\n", __func__);
|
LOG_INF("%s: model loaded\n", __func__);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue