server: move server-context to its own cpp|h (#17595)
* git mv * add server-context.h * add server-context.h * clean up headers * cont : cleanup * also expose server_response_reader (to be used by CLI) * fix windows build * decouple server_routes and server_http --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
8c32d9d96d
commit
ab49f094d2
|
|
@ -21,6 +21,8 @@ set(TARGET_SRCS
|
|||
server-queue.h
|
||||
server-common.cpp
|
||||
server-common.h
|
||||
server-context.cpp
|
||||
server-context.h
|
||||
)
|
||||
set(PUBLIC_ASSETS
|
||||
index.html.gz
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,83 @@
|
|||
#include "server-http.h"
|
||||
#include "server-task.h"
|
||||
#include "server-queue.h"
|
||||
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
|
||||
struct server_context_impl; // private implementation
|
||||
|
||||
// Public facade over the server's model/slot state. The actual state lives in
// server_context_impl (pimpl idiom), keeping this header light and decoupling
// consumers from the implementation's heavy includes.
struct server_context {
    std::unique_ptr<server_context_impl> impl;

    // ctor/dtor must be out-of-line: server_context_impl is incomplete here,
    // so unique_ptr's destructor cannot be instantiated in this header
    server_context();
    ~server_context();

    // initialize slots and server-related data
    void init();

    // load the model and initialize llama_context
    // returns true on success
    bool load_model(const common_params & params);

    // this function will block main thread until termination
    void start_loop();

    // terminate main loop (will unblock start_loop)
    void terminate();

    // get the underlying llama_context
    llama_context * get_llama_context() const;

    // get the underlying queue_tasks and queue_results
    // used by CLI application
    std::pair<server_queue &, server_response &> get_queues();
};
|
||||
|
||||
|
||||
// forward declarations
|
||||
struct server_res_generator;
|
||||
|
||||
struct server_routes {
|
||||
server_routes(const common_params & params, server_context & ctx_server, std::function<bool()> is_ready = []() { return true; })
|
||||
: params(params), ctx_server(*ctx_server.impl), is_ready(is_ready) {
|
||||
init_routes();
|
||||
}
|
||||
|
||||
void init_routes();
|
||||
// handlers using lambda function, so that they can capture `this` without `std::bind`
|
||||
server_http_context::handler_t get_health;
|
||||
server_http_context::handler_t get_metrics;
|
||||
server_http_context::handler_t get_slots;
|
||||
server_http_context::handler_t post_slots;
|
||||
server_http_context::handler_t get_props;
|
||||
server_http_context::handler_t post_props;
|
||||
server_http_context::handler_t get_api_show;
|
||||
server_http_context::handler_t post_infill;
|
||||
server_http_context::handler_t post_completions;
|
||||
server_http_context::handler_t post_completions_oai;
|
||||
server_http_context::handler_t post_chat_completions;
|
||||
server_http_context::handler_t post_anthropic_messages;
|
||||
server_http_context::handler_t post_anthropic_count_tokens;
|
||||
server_http_context::handler_t post_apply_template;
|
||||
server_http_context::handler_t get_models;
|
||||
server_http_context::handler_t post_tokenize;
|
||||
server_http_context::handler_t post_detokenize;
|
||||
server_http_context::handler_t post_embeddings;
|
||||
server_http_context::handler_t post_embeddings_oai;
|
||||
server_http_context::handler_t post_rerank;
|
||||
server_http_context::handler_t get_lora_adapters;
|
||||
server_http_context::handler_t post_lora_adapters;
|
||||
private:
|
||||
// TODO: move these outside of server_routes?
|
||||
std::unique_ptr<server_res_generator> handle_slots_save(const server_http_req & req, int id_slot);
|
||||
std::unique_ptr<server_res_generator> handle_slots_restore(const server_http_req & req, int id_slot);
|
||||
std::unique_ptr<server_res_generator> handle_slots_erase(const server_http_req &, int id_slot);
|
||||
std::unique_ptr<server_res_generator> handle_embeddings_impl(const server_http_req & req, task_response_type res_type);
|
||||
|
||||
const common_params & params;
|
||||
server_context_impl & ctx_server;
|
||||
std::function<bool()> is_ready;
|
||||
};
|
||||
|
|
@ -266,3 +266,86 @@ void server_response::terminate() {
|
|||
running = false;
|
||||
condition_results.notify_all();
|
||||
}
|
||||
|
||||
//
|
||||
// server_response_reader
|
||||
//
|
||||
|
||||
// Record the tasks' ids and enqueue the tasks for processing.
// NOTE(review): waiters are registered with queue_results *before* the tasks
// are posted — presumably so a result cannot be published before anyone is
// waiting for it; keep this ordering.
void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
    id_tasks = server_task::get_list_id(tasks);
    queue_results.add_waiting_tasks(tasks);
    queue_tasks.post(std::move(tasks));
}
|
||||
|
||||
bool server_response_reader::has_next() const {
|
||||
return !cancelled && received_count < id_tasks.size();
|
||||
}
|
||||
|
||||
// return nullptr if should_stop() is true before receiving a result
// note: if one error is received, it will stop further processing and return error result
server_task_result_ptr server_response_reader::next(const std::function<bool()> & should_stop) {
    while (true) {
        // poll with a timeout so should_stop() can be re-checked periodically
        // even while no result is available
        server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, polling_interval_seconds);
        if (result == nullptr) {
            // timeout, check stop condition
            if (should_stop()) {
                SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n");
                return nullptr;
            }
        } else {
            if (result->is_error()) {
                stop(); // cancel remaining tasks
                SRV_DBG("%s", "received error result, stopping further processing\n");
                return result;
            }
            if (result->is_stop()) {
                // only final (stop) results complete a task; intermediate
                // results do not advance received_count
                received_count++;
            }
            return result;
        }
    }

    // should not reach here
}
|
||||
|
||||
// Collect one final result per posted task, placed at the slot given by the
// result's index. Returns early with is_terminated set if should_stop() fires
// before all results arrive, or with `error` set on the first error result.
server_response_reader::batch_response server_response_reader::wait_for_all(const std::function<bool()> & should_stop) {
    batch_response batch_res;
    // pre-size so each result can be stored at its own index
    batch_res.results.resize(id_tasks.size());
    while (has_next()) {
        auto res = next(should_stop);
        if (res == nullptr) {
            // aborted by should_stop() before all results were received
            batch_res.is_terminated = true;
            return batch_res;
        }
        if (res->is_error()) {
            // next() has already cancelled the remaining tasks
            batch_res.error = std::move(res);
            return batch_res;
        }
        const size_t idx = res->get_index();
        GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
        GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
        batch_res.results[idx] = std::move(res);
    }
    return batch_res;
}
|
||||
|
||||
// Stop receiving results; if some tasks are still unfinished, post CANCEL
// tasks for them at the front of the queue. Idempotent: once `cancelled` is
// set, subsequent calls only re-run the (harmless) waiting-id cleanup.
void server_response_reader::stop() {
    queue_results.remove_waiting_task_ids(id_tasks);
    // NOTE(review): `!cancelled` is redundant — has_next() already returns
    // false when cancelled — but kept as written for explicitness.
    if (has_next() && !cancelled) {
        // if tasks is not finished yet, cancel them
        cancelled = true;
        std::vector<server_task> cancel_tasks;
        cancel_tasks.reserve(id_tasks.size());
        for (const auto & id_task : id_tasks) {
            SRV_WRN("cancel task, id_task = %d\n", id_task);
            server_task task(SERVER_TASK_TYPE_CANCEL);
            task.id_target = id_task;
            queue_results.remove_waiting_task_id(id_task);
            cancel_tasks.push_back(std::move(task));
        }
        // push to beginning of the queue, so it has highest priority
        queue_tasks.post(std::move(cancel_tasks), true);
    } else {
        SRV_DBG("%s", "all tasks already finished, no need to cancel\n");
    }
}
|
||||
|
|
|
|||
|
|
@ -108,3 +108,39 @@ public:
|
|||
// terminate the waiting loop
|
||||
void terminate();
|
||||
};
|
||||
|
||||
// utility class to make working with server_queue and server_response easier
// it provides a generator-like API for server responses
// support polling connection state and aggregating multiple results
struct server_response_reader {
    std::unordered_set<int> id_tasks;  // ids of tasks posted via post_tasks()
    server_queue & queue_tasks;        // where new/cancel tasks are posted
    server_response & queue_results;   // where task results are received from
    size_t received_count = 0;         // number of final (stop) results received so far
    bool cancelled = false;            // set once stop() has posted cancel tasks
    int polling_interval_seconds;      // timeout between should_stop() checks in next()

    // should_stop function will be called each polling_interval_seconds
    server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
        : queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
    // destructor cancels any tasks that are still in flight (RAII cleanup)
    ~server_response_reader() {
        stop();
    }

    // register the tasks as waiting for results, then enqueue them
    void post_tasks(std::vector<server_task> && tasks);
    // true while more results are expected and the reader is not cancelled
    bool has_next() const;

    // return nullptr if should_stop() is true before receiving a result
    // note: if one error is received, it will stop further processing and return error result
    server_task_result_ptr next(const std::function<bool()> & should_stop);

    // aggregated outcome of waiting for every posted task
    struct batch_response {
        bool is_terminated = false; // if true, indicates that processing was stopped before all results were received
        std::vector<server_task_result_ptr> results; // one slot per task, placed by result index
        server_task_result_ptr error; // nullptr if no error
    };
    // aggregate multiple results
    batch_response wait_for_all(const std::function<bool()> & should_stop);

    // cancel any unfinished tasks and stop receiving results
    void stop();
};
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue