diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index be3226ada3..95b2efdbf1 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -10,6 +10,8 @@ #include #include #include +#include +#include #ifdef _WIN32 #include @@ -60,7 +62,10 @@ static std::filesystem::path get_server_exec_path() { #else char path[FILENAME_MAX]; ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX); - return std::filesystem::path(std::string(path, (count > 0) ? count: 0)); + if (count <= 0) { + throw std::runtime_error("failed to resolve /proc/self/exe"); + } + return std::filesystem::path(std::string(path, count)); #endif } @@ -203,22 +208,27 @@ std::vector server_models::get_all_meta() { } void server_models::load(const std::string & name) { - auto meta = get_meta(name); - if (!meta.has_value()) { + std::lock_guard lk(mutex); + if (mapping.find(name) == mapping.end()) { throw std::runtime_error("model name=" + name + " is not found"); } - std::lock_guard lk(mutex); - if (meta->status != SERVER_MODEL_STATUS_FAILED && meta->status != SERVER_MODEL_STATUS_UNLOADED) { + auto meta = mapping[name].meta; + if (meta.status != SERVER_MODEL_STATUS_FAILED && meta.status != SERVER_MODEL_STATUS_UNLOADED) { SRV_INF("model %s is not ready\n", name.c_str()); return; } + // prepare new instance info instance_t inst; - inst.meta = meta.value(); + inst.meta = meta; inst.meta.port = get_free_port(); inst.meta.status = SERVER_MODEL_STATUS_LOADING; + if (inst.meta.port <= 0) { + throw std::runtime_error("failed to get a port number"); + } + inst.subproc = std::make_shared(); { std::string exec_path = get_server_exec_path().string(); @@ -354,17 +364,25 @@ void server_models::wait_until_loaded(const std::string & name) { }); } -void server_models::ensure_model_loaded(const std::string & name) { +bool server_models::ensure_model_loaded(const std::string & name) { auto meta = get_meta(name); if (!meta.has_value()) { throw std::runtime_error("model name=" + name + " is not found"); } if (meta->is_active()) { - return; // already loaded + return false; // already loaded } SRV_INF("model name=%s is not loaded, loading...\n", name.c_str()); load(name); wait_until_loaded(name); + { + // check final status + meta = get_meta(name); + if (!meta.has_value() || meta->status == SERVER_MODEL_STATUS_FAILED) { + throw std::runtime_error("model name=" + name + " failed to load"); + } + } + return true; } server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name) { @@ -372,7 +390,9 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co if (!meta.has_value()) { throw std::runtime_error("model name=" + name + " is not found"); } - ensure_model_loaded(name); // TODO: handle failure case + if (ensure_model_loaded(name)) { + meta = get_meta(name); // refresh meta + } SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port); auto proxy = std::make_unique( method, @@ -439,11 +459,11 @@ struct pipe_t { std::atomic writer_closed{false}; std::atomic reader_closed{false}; void close_write() { - writer_closed.store(true); + writer_closed.store(true, std::memory_order_relaxed); cv.notify_all(); } void close_read() { - reader_closed.store(true); + reader_closed.store(true, std::memory_order_relaxed); cv.notify_all(); } bool read(T & output, const std::function & should_stop) { diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 3cd070f89a..f8ae757fa4 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -13,7 +13,7 @@ /** * state diagram: - * + * * UNLOADED ──► LOADING ──► LOADED * ▲ │ * │ │ @@ -105,7 +105,8 @@ public: void wait_until_loaded(const std::string & name); // load the model if not loaded, otherwise do nothing - void ensure_model_loaded(const std::string & name); + // return false if model is already loaded; return true otherwise (meta may need to be refreshed) + bool ensure_model_loaded(const std::string & name); // proxy an HTTP request to the model instance server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name);