422 lines
16 KiB
C++
422 lines
16 KiB
C++
#include "common.h"
|
|
#include "server-http.h"
|
|
#include "server-common.h"
|
|
|
|
#include <cpp-httplib/httplib.h>
|
|
|
|
#include <functional>
|
|
#include <string>
|
|
#include <thread>
|
|
|
|
// auto generated files (see README.md for details)
|
|
#include "index.html.gz.hpp"
|
|
#include "loading.html.hpp"
|
|
|
|
//
|
|
// HTTP implementation using cpp-httplib
|
|
//
|
|
|
|
// Private implementation of server_http_context.
// Holds the cpp-httplib server object so that httplib types stay out of the public header.
class server_http_context::Impl {
  public:
    // created in init(); null until then
    std::unique_ptr<httplib::Server> srv;
};
|
|
|
|
// Allocates the (still empty) implementation object; the actual HTTP server is created in init().
server_http_context::server_http_context() : pimpl(std::make_unique<Impl>()) {
}
|
|
|
|
// Defaulted in this translation unit, after Impl has been defined, so that Impl is a complete
// type at the point where the pimpl member (presumably a std::unique_ptr<Impl>) is destroyed.
server_http_context::~server_http_context() = default;
|
|
|
|
// Logger callback installed on the httplib server: prints one line per finished request
// (method, path, remote address, status) and dumps request/response bodies at debug level.
static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
    // endpoints that are polled regularly; logging them would only spam the log
    static const char * const quiet_paths[] = {
        "/health",
        "/v1/health",
        "/models",
        "/v1/models",
        "/props",
        "/metrics",
    };

    for (const char * quiet_path : quiet_paths) {
        if (req.path == quiet_path) {
            return;
        }
    }

    // reminder: this function is not covered by httplib's exception handler;
    // if anything more complicated is added here, consider wrapping it in try-catch

    SRV_INF("done request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);

    SRV_DBG("request:  %s\n", req.body.c_str());
    SRV_DBG("response: %s\n", res.body.c_str());
}
|
|
|
|
bool server_http_context::init(const common_params & params) {
|
|
path_prefix = params.api_prefix;
|
|
port = params.port;
|
|
hostname = params.hostname;
|
|
|
|
auto & srv = pimpl->srv;
|
|
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
|
|
LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
|
|
srv.reset(
|
|
new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
|
|
);
|
|
} else {
|
|
LOG_INF("Running without SSL\n");
|
|
srv.reset(new httplib::Server());
|
|
}
|
|
#else
|
|
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
|
|
LOG_ERR("Server is built without SSL support\n");
|
|
return false;
|
|
}
|
|
srv.reset(new httplib::Server());
|
|
#endif
|
|
|
|
srv->set_default_headers({{"Server", "llama.cpp"}});
|
|
srv->set_logger(log_server_request);
|
|
srv->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
|
|
// this is fail-safe; exceptions should already handled by `ex_wrapper`
|
|
|
|
std::string message;
|
|
try {
|
|
std::rethrow_exception(ep);
|
|
} catch (const std::exception & e) {
|
|
message = e.what();
|
|
} catch (...) {
|
|
message = "Unknown Exception";
|
|
}
|
|
|
|
res.status = 500;
|
|
res.set_content(message, "text/plain");
|
|
LOG_ERR("got exception: %s\n", message.c_str());
|
|
});
|
|
|
|
srv->set_error_handler([](const httplib::Request &, httplib::Response & res) {
|
|
if (res.status == 404) {
|
|
res.set_content(
|
|
safe_json_to_str(json {
|
|
{"error", {
|
|
{"message", "File Not Found"},
|
|
{"type", "not_found_error"},
|
|
{"code", 404}
|
|
}}
|
|
}),
|
|
"application/json; charset=utf-8"
|
|
);
|
|
}
|
|
// for other error codes, we skip processing here because it's already done by res->error()
|
|
});
|
|
|
|
// set timeouts and change hostname and port
|
|
srv->set_read_timeout (params.timeout_read);
|
|
srv->set_write_timeout(params.timeout_write);
|
|
|
|
if (params.api_keys.size() == 1) {
|
|
auto key = params.api_keys[0];
|
|
std::string substr = key.substr(std::max((int)(key.length() - 4), 0));
|
|
LOG_INF("%s: api_keys: ****%s\n", __func__, substr.c_str());
|
|
} else if (params.api_keys.size() > 1) {
|
|
LOG_INF("%s: api_keys: %zu keys loaded\n", __func__, params.api_keys.size());
|
|
}
|
|
|
|
srv->set_payload_max_length(1024LL * 1024LL * 1024LL); //To allow big B64 embbeding.yeah i know this is a bit hacky, but cpp-httplib doesn't provide a way to set max payload size for specific endpoints, and we need to allow big payload for image embedding endpoint. ideally we should set this limit only for that endpoint, but for now we just set it globally to 1GB which should be enough for most cases.
|
|
|
|
//
|
|
// Middlewares
|
|
//
|
|
|
|
auto middleware_validate_api_key = [api_keys = params.api_keys](const httplib::Request & req, httplib::Response & res) {
|
|
static const std::unordered_set<std::string> public_endpoints = {
|
|
"/health",
|
|
"/v1/health",
|
|
"/models",
|
|
"/v1/models",
|
|
"/api/tags"
|
|
};
|
|
|
|
// If API key is not set, skip validation
|
|
if (api_keys.empty()) {
|
|
return true;
|
|
}
|
|
|
|
// If path is public or is static file, skip validation
|
|
if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") {
|
|
return true;
|
|
}
|
|
|
|
// Check for API key in the Authorization header
|
|
std::string req_api_key = req.get_header_value("Authorization");
|
|
if (req_api_key.empty()) {
|
|
// retry with anthropic header
|
|
req_api_key = req.get_header_value("X-Api-Key");
|
|
}
|
|
|
|
// remove the "Bearer " prefix if needed
|
|
std::string prefix = "Bearer ";
|
|
if (req_api_key.substr(0, prefix.size()) == prefix) {
|
|
req_api_key = req_api_key.substr(prefix.size());
|
|
}
|
|
|
|
// validate the API key
|
|
if (std::find(api_keys.begin(), api_keys.end(), req_api_key) != api_keys.end()) {
|
|
return true; // API key is valid
|
|
}
|
|
|
|
// API key is invalid or not provided
|
|
res.status = 401;
|
|
res.set_content(
|
|
safe_json_to_str(json {
|
|
{"error", {
|
|
{"message", "Invalid API Key"},
|
|
{"type", "authentication_error"},
|
|
{"code", 401}
|
|
}}
|
|
}),
|
|
"application/json; charset=utf-8"
|
|
);
|
|
|
|
LOG_WRN("Unauthorized: Invalid API Key\n");
|
|
|
|
return false;
|
|
};
|
|
|
|
auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) {
|
|
bool ready = is_ready.load();
|
|
if (!ready) {
|
|
auto tmp = string_split<std::string>(req.path, '.');
|
|
if (req.path == "/" || tmp.back() == "html") {
|
|
res.status = 503;
|
|
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
|
|
} else {
|
|
// no endpoints is allowed to be accessed when the server is not ready
|
|
// this is to prevent any data races or inconsistent states
|
|
res.status = 503;
|
|
res.set_content(
|
|
safe_json_to_str(json {
|
|
{"error", {
|
|
{"message", "Loading model"},
|
|
{"type", "unavailable_error"},
|
|
{"code", 503}
|
|
}}
|
|
}),
|
|
"application/json; charset=utf-8"
|
|
);
|
|
}
|
|
return false;
|
|
}
|
|
return true;
|
|
};
|
|
|
|
// register server middlewares
|
|
srv->set_pre_routing_handler([middleware_validate_api_key, middleware_server_state](const httplib::Request & req, httplib::Response & res) {
|
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
// If this is OPTIONS request, skip validation because browsers don't include Authorization header
|
|
if (req.method == "OPTIONS") {
|
|
res.set_header("Access-Control-Allow-Credentials", "true");
|
|
res.set_header("Access-Control-Allow-Methods", "GET, POST");
|
|
res.set_header("Access-Control-Allow-Headers", "*");
|
|
res.set_content("", "text/html"); // blank response, no data
|
|
return httplib::Server::HandlerResponse::Handled; // skip further processing
|
|
}
|
|
if (!middleware_server_state(req, res)) {
|
|
return httplib::Server::HandlerResponse::Handled;
|
|
}
|
|
if (!middleware_validate_api_key(req, res)) {
|
|
return httplib::Server::HandlerResponse::Handled;
|
|
}
|
|
return httplib::Server::HandlerResponse::Unhandled;
|
|
});
|
|
|
|
int n_threads_http = params.n_threads_http;
|
|
if (n_threads_http < 1) {
|
|
// +2 threads for monitoring endpoints
|
|
n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
|
|
}
|
|
LOG_INF("%s: using %d threads for HTTP server\n", __func__, n_threads_http);
|
|
srv->new_task_queue = [n_threads_http] { return new httplib::ThreadPool(n_threads_http); };
|
|
|
|
//
|
|
// Web UI setup
|
|
//
|
|
|
|
if (!params.webui) {
|
|
LOG_INF("Web UI is disabled\n");
|
|
} else {
|
|
// register static assets routes
|
|
if (!params.public_path.empty()) {
|
|
// Set the base directory for serving static files
|
|
bool is_found = srv->set_mount_point(params.api_prefix + "/", params.public_path);
|
|
if (!is_found) {
|
|
LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
|
|
return 1;
|
|
}
|
|
} else {
|
|
// using embedded static index.html
|
|
srv->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) {
|
|
if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
|
|
res.set_content("Error: gzip is not supported by this browser", "text/plain");
|
|
} else {
|
|
res.set_header("Content-Encoding", "gzip");
|
|
// COEP and COOP headers, required by pyodide (python interpreter)
|
|
res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
|
|
res.set_header("Cross-Origin-Opener-Policy", "same-origin");
|
|
res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
|
|
}
|
|
return false;
|
|
});
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool server_http_context::start() {
|
|
// Bind and listen
|
|
|
|
auto & srv = pimpl->srv;
|
|
bool was_bound = false;
|
|
bool is_sock = false;
|
|
if (string_ends_with(std::string(hostname), ".sock")) {
|
|
is_sock = true;
|
|
LOG_INF("%s: setting address family to AF_UNIX\n", __func__);
|
|
srv->set_address_family(AF_UNIX);
|
|
// bind_to_port requires a second arg, any value other than 0 should
|
|
// simply get ignored
|
|
was_bound = srv->bind_to_port(hostname, 8080);
|
|
} else {
|
|
LOG_INF("%s: binding port with default address family\n", __func__);
|
|
// bind HTTP listen port
|
|
if (port == 0) {
|
|
int bound_port = srv->bind_to_any_port(hostname);
|
|
was_bound = (bound_port >= 0);
|
|
if (was_bound) {
|
|
port = bound_port;
|
|
}
|
|
} else {
|
|
was_bound = srv->bind_to_port(hostname, port);
|
|
}
|
|
}
|
|
|
|
if (!was_bound) {
|
|
LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, hostname.c_str(), port);
|
|
return false;
|
|
}
|
|
|
|
// run the HTTP server in a thread
|
|
thread = std::thread([this]() { pimpl->srv->listen_after_bind(); });
|
|
srv->wait_until_ready();
|
|
|
|
listening_address = is_sock ? string_format("unix://%s", hostname.c_str())
|
|
: string_format("http://%s:%d", hostname.c_str(), port);
|
|
return true;
|
|
}
|
|
|
|
void server_http_context::stop() const {
|
|
if (pimpl->srv) {
|
|
pimpl->srv->stop();
|
|
}
|
|
}
|
|
|
|
// Copy every (name, value) pair from `headers` onto the httplib response.
static void set_headers(httplib::Response & res, const std::map<std::string, std::string> & headers) {
    for (const auto & header : headers) {
        res.set_header(header.first, header.second);
    }
}
|
|
|
|
// Collect the request parameters into a single map.
// Query parameters are inserted first, then path parameters, so a path parameter
// overrides a query parameter of the same name (and a later duplicate overrides an earlier one).
static std::map<std::string, std::string> get_params(const httplib::Request & req) {
    std::map<std::string, std::string> merged;

    for (const auto & query_param : req.params) {
        merged.insert_or_assign(query_param.first, query_param.second);
    }
    for (const auto & path_param : req.path_params) {
        merged.insert_or_assign(path_param.first, path_param.second);
    }

    return merged;
}
|
|
|
|
// Copy the request headers into a plain map; if a header name occurs more than once,
// the value seen last during iteration is the one that is kept.
static std::map<std::string, std::string> get_headers(const httplib::Request & req) {
    std::map<std::string, std::string> result;

    for (const auto & header : req.headers) {
        result.insert_or_assign(header.first, header.second);
    }

    return result;
}
|
|
|
|
// Rebuild the query string ("k1=v1&k2=v2...") from the parsed request parameters,
// encoding every key and value with httplib::encode_query_component.
static std::string build_query_string(const httplib::Request & req) {
    std::string query;

    for (const auto & param : req.params) {
        // separate from whatever has been written so far
        if (!query.empty()) {
            query.push_back('&');
        }
        query += httplib::encode_query_component(param.first);
        query += "=";
        query += httplib::encode_query_component(param.second);
    }

    return query;
}
|
|
|
|
// the request is held by unique_ptr so that its ownership can be handed over to the response callbacks
|
|
using server_http_req_ptr = std::unique_ptr<server_http_req>;
|
|
|
|
// Transfer the result produced by a handler into the httplib response.
//
// Non-streaming responses are copied in one go. For streaming responses, ownership of both
// `request` and `response` is moved into shared pointers held by the chunked content provider
// and its completion callback, so both objects stay alive until the stream has finished.
static void process_handler_response(server_http_req_ptr && request, server_http_res_ptr & response, httplib::Response & res) {
    const bool streaming = response->is_stream();

    // status and headers are handled the same way for both kinds of response
    res.status = response->status;
    set_headers(res, response->headers);

    if (!streaming) {
        res.set_content(response->data, response->content_type);
        return;
    }

    // take a copy before `response` is moved from below
    std::string content_type = response->content_type;

    // convert to shared_ptr as both the content provider and the completion callback need to use them
    std::shared_ptr<server_http_req> shared_req = std::move(request);
    std::shared_ptr<server_http_res> shared_res = std::move(response);

    const auto provide_chunk = [stream = shared_res](size_t, httplib::DataSink & sink) -> bool {
        std::string chunk;
        const bool more = stream->next(chunk);
        if (!chunk.empty()) {
            // TODO: maybe handle sink.write unsuccessful? for now, we rely on is_connection_closed()
            sink.write(chunk.data(), chunk.size());
            SRV_DBG("http: streamed chunk: %s\n", chunk.c_str());
        }
        if (!more) {
            sink.done();
            SRV_DBG("%s", "http: stream ended\n");
        }
        return more;
    };

    const auto release_resources = [held_req = shared_req, held_res = shared_res](bool) mutable {
        held_res.reset(); // trigger the destruction of the response object
        held_req.reset(); // trigger the destruction of the request object
    };

    res.set_chunked_content_provider(content_type, provide_chunk, release_resources);
}
|
|
|
|
// Register `handler` for GET requests on `path` (prefixed with the configured API prefix).
// For each request, a server_http_req is built from the httplib request, passed to the handler,
// and the handler's result is forwarded to the client via process_handler_response().
void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const {
    const std::string route = path_prefix + path;

    auto on_request = [handler](const httplib::Request & req, httplib::Response & res) {
        auto request = std::make_unique<server_http_req>(server_http_req{
            get_params(req),
            get_headers(req),
            req.path,
            build_query_string(req),
            req.body,
            req.is_connection_closed
        });

        server_http_res_ptr response = handler(*request);
        process_handler_response(std::move(request), response, res);
    };

    pimpl->srv->Get(route, std::move(on_request));
}
|
|
|
|
// Register `handler` for POST requests on `path` (prefixed with the configured API prefix).
// For each request, a server_http_req is built from the httplib request, passed to the handler,
// and the handler's result is forwarded to the client via process_handler_response().
void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
    const std::string route = path_prefix + path;

    auto on_request = [handler](const httplib::Request & req, httplib::Response & res) {
        auto request = std::make_unique<server_http_req>(server_http_req{
            get_params(req),
            get_headers(req),
            req.path,
            build_query_string(req),
            req.body,
            req.is_connection_closed
        });

        server_http_res_ptr response = handler(*request);
        process_handler_response(std::move(request), response, res);
    };

    pimpl->srv->Post(route, std::move(on_request));
}
|
|
|