diff --git a/common/arg.cpp b/common/arg.cpp
index f2aec895ba..b3663aec9f 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2668,6 +2668,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.endpoint_slots = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_SLOTS"));
+    add_opt(common_arg({ "--endpoint-exit" },
+                       string_format("enable POST /exit endpoint to shutdown the server (default: %s)",
+                                     params.endpoint_exit ? "enabled" : "disabled"),
+                       [](common_params & params) { params.endpoint_exit = true; })
+                .set_examples({ LLAMA_EXAMPLE_SERVER })
+                .set_env("LLAMA_ARG_ENDPOINT_EXIT"));
     add_opt(common_arg(
         {"--slot-save-path"}, "PATH",
         "path to save slot kv cache (default: disabled)",
diff --git a/common/common.h b/common/common.h
index d70744840f..781b32d1c2 100644
--- a/common/common.h
+++ b/common/common.h
@@ -489,6 +489,7 @@ struct common_params {
     bool endpoint_slots = true;
     bool endpoint_props = false; // only control POST requests, not GET
     bool endpoint_metrics = false;
+    bool endpoint_exit = false; // set by --endpoint-exit; gates the POST /exit handler
 
     // router server configs
     std::string models_dir = ""; // directory containing models for the router server
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 90898b5ec4..9e59f3ac1a 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -3590,6 +3590,68 @@ void server_routes::init_routes() {
         res->ok(result->to_json());
         return res;
     };
+
+    // POST /exit: ask the server to shut down. The request is only honored when
+    // params.endpoint_exit is set and the JSON body carries {"confirm": "shutdown"}.
+    this->post_exit = [this](const server_http_req & req) {
+        auto res = std::make_unique<server_res_generator>(ctx_server);
+
+        if (!params.endpoint_exit) {
+            SRV_WRN("%s: exit endpoint called but exit endpoint is not enabled\n", __func__);
+            res->error(format_error_response("Exit endpoint is disabled.", ERROR_TYPE_NOT_SUPPORTED));
+            return res;
+        }
+
+        // Check for confirmation token in request body
+        try {
+            const json body = json::parse(req.body);
+            const std::string confirm = json_value(body, "confirm", std::string());
+
+            if (confirm != "shutdown") {
+                res->error(format_error_response("Missing or invalid confirmation. Send {\"confirm\": \"shutdown\"}",
+                                                 ERROR_TYPE_INVALID_REQUEST));
+                return res;
+            }
+        } catch (const std::exception &) {
+            // any exception raised while parsing or reading the body is reported as an invalid request
+            res->error(format_error_response("Invalid request body. Expected JSON with {\"confirm\": \"shutdown\"}",
+                                             ERROR_TYPE_INVALID_REQUEST));
+            return res;
+        }
+
+        // Without a callback nothing here can stop the server, so report that to the client
+        // instead of answering "shutdown initiated" and then doing nothing.
+        if (!this->on_shutdown) {
+            SRV_ERR("%s: exit endpoint called but no shutdown callback is registered\n", __func__);
+            res->error(format_error_response("Exit endpoint has no shutdown handler configured.", ERROR_TYPE_NOT_SUPPORTED));
+            return res;
+        }
+
+        SRV_INF("%s: exit endpoint called with valid confirmation token, initiating server shutdown...\n",
+                __func__);
+
+        res->ok({
+            { "message", "Server shutdown initiated" },
+            { "status", "terminating" }
+        });
+
+        // Invoke a copy of the callback from a detached thread so this handler can return first.
+        // The 100 ms delay is a best-effort grace period for the response, not a guarantee.
+        auto shutdown_cb = this->on_shutdown;
+        std::thread([shutdown_cb]() {
+            std::this_thread::sleep_for(std::chrono::milliseconds(100));
+            SRV_INF("%s: executing on_shutdown callback...\n", __func__);
+            try {
+                shutdown_cb();
+            } catch (const std::exception & e) {
+                SRV_ERR("%s: on_shutdown callback threw: %s\n", __func__, e.what());
+            } catch (...) {
+                SRV_ERR("%s: on_shutdown callback threw unknown exception\n", __func__);
+            }
+        }).detach();
+
+        return res;
+    };
 }
 
 std::unique_ptr<server_res_generator> server_routes::handle_slots_save(const server_http_req & req, int id_slot) {
diff --git a/tools/server/server-context.h b/tools/server/server-context.h
index 230b25952e..4dbacee5c0 100644
--- a/tools/server/server-context.h
+++ b/tools/server/server-context.h
@@ -51,8 +51,8 @@ struct server_context {
 struct server_res_generator;
 
 struct server_routes {
-    server_routes(const common_params & params, server_context & ctx_server, std::function<bool()> is_ready = []() { return true; })
-            : params(params), ctx_server(*ctx_server.impl), is_ready(is_ready) {
+    server_routes(const common_params & params, server_context & ctx_server, std::function<bool()> is_ready = []() { return true; }, std::function<void()> on_shutdown = nullptr)
+            : params(params), ctx_server(*ctx_server.impl), is_ready(is_ready), on_shutdown(on_shutdown) {
         init_routes();
     }
 
@@ -80,6 +80,8 @@ struct server_routes {
     server_http_context::handler_t post_rerank;
     server_http_context::handler_t get_lora_adapters;
     server_http_context::handler_t post_lora_adapters;
+    server_http_context::handler_t post_exit;
+
 private:
     // TODO: move these outside of server_routes?
     std::unique_ptr<server_res_generator> handle_slots_save(const server_http_req & req, int id_slot);
@@ -90,4 +92,5 @@ private:
     const common_params & params;
     server_context_impl & ctx_server;
     std::function<bool()> is_ready;
+    const std::function<void()> on_shutdown; // called by post_exit to stop the server; may be empty
 };
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 8538427f73..d03d3f4eca 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -113,14 +113,29 @@ int main(int argc, char ** argv, char ** envp) {
         return 1;
     }
 
+    bool is_router_server = params.model.path.empty();
+
+    // prepare the shutdown callback for the current mode. Capturing by reference is safe here:
+    // ctx_http and ctx_server are locals of main() declared above, so they outlive routes.
+    std::function<void()> shutdown_cb;
+    if (is_router_server) {
+        shutdown_cb = [&ctx_http]() {
+            ctx_http.stop();
+        };
+    } else {
+        // non-router mode: terminate ctx_server (declared earlier, outlives routes)
+        shutdown_cb = [&ctx_server]() {
+            ctx_server.terminate();
+        };
+    }
+
     //
     // Router
     //
 
     // register API routes
-    server_routes routes(params, ctx_server, [&ctx_http]() { return ctx_http.is_ready.load(); });
+    server_routes routes(params, ctx_server, [&ctx_http]() { return ctx_http.is_ready.load(); }, shutdown_cb);
 
-    bool is_router_server = params.model.path.empty();
     std::optional<server_models_routes> models_routes{};
     if (is_router_server) {
         // setup server instances manager
@@ -191,6 +206,9 @@ int main(int argc, char ** argv, char ** envp) {
     ctx_http.get ("/slots", ex_wrapper(routes.get_slots));
     ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots));
 
+    // Exit endpoint
+    ctx_http.post("/exit", ex_wrapper(routes.post_exit));
+
     //
     // Start the server
     //