#include "server-task.h" #include "server-queue.h" #include "log.h" #include #define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define RES_INF(fmt, ...) LOG_INF("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define RES_WRN(fmt, ...) LOG_WRN("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define RES_ERR(fmt, ...) LOG_ERR("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define RES_DBG(fmt, ...) LOG_DBG("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) // // server_queue // int server_queue::post(server_task && task, bool front) { std::unique_lock lock(mutex_tasks); GGML_ASSERT(task.id != -1); // if this is cancel task make sure to clean up pending tasks if (task.type == SERVER_TASK_TYPE_CANCEL) { cleanup_pending_task(task.id_target); } const int task_id = task.id; QUE_DBG("new task, id = %d, front = %d\n", task_id, front); if (front) { queue_tasks.push_front(std::move(task)); } else { queue_tasks.push_back(std::move(task)); } condition_tasks.notify_one(); return task_id; } int server_queue::post(std::vector && tasks, bool front) { std::unique_lock lock(mutex_tasks); for (auto & task : tasks) { if (task.id == -1) { task.id = id++; } // if this is cancel task make sure to clean up pending tasks if (task.type == SERVER_TASK_TYPE_CANCEL) { cleanup_pending_task(task.id_target); } QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front); if (front) { queue_tasks.push_front(std::move(task)); } else { queue_tasks.push_back(std::move(task)); } } condition_tasks.notify_one(); return 0; } void server_queue::defer(server_task && task) { std::unique_lock lock(mutex_tasks); QUE_DBG("defer task, id = %d\n", task.id); queue_tasks_deferred.push_back(std::move(task)); condition_tasks.notify_one(); } int server_queue::get_new_id() { std::unique_lock lock(mutex_tasks); int new_id = id++; return new_id; } void server_queue::on_new_task(std::function callback) { callback_new_task = std::move(callback); } void server_queue::on_update_slots(std::function callback) { callback_update_slots = std::move(callback); } void server_queue::pop_deferred_task() { std::unique_lock lock(mutex_tasks); if (!queue_tasks_deferred.empty()) { queue_tasks.emplace_front(std::move(queue_tasks_deferred.front())); queue_tasks_deferred.pop_front(); } condition_tasks.notify_one(); } void server_queue::terminate() { std::unique_lock lock(mutex_tasks); running = false; condition_tasks.notify_all(); } void server_queue::start_loop() { running = true; while (true) { QUE_DBG("%s", "processing new tasks\n"); while (true) { std::unique_lock lock(mutex_tasks); if (!running) { QUE_DBG("%s", "terminate\n"); return; } if (queue_tasks.empty()) { lock.unlock(); break; } server_task task = std::move(queue_tasks.front()); queue_tasks.pop_front(); lock.unlock(); QUE_DBG("processing task, id = %d\n", task.id); callback_new_task(std::move(task)); } // all tasks in the current loop is processed, slots data is now ready QUE_DBG("%s", "update slots\n"); callback_update_slots(); QUE_DBG("%s", "waiting for new tasks\n"); { std::unique_lock lock(mutex_tasks); if (!running) { QUE_DBG("%s", "terminate\n"); return; } if (queue_tasks.empty()) { condition_tasks.wait(lock, [&]{ return (!queue_tasks.empty() || !running); }); } } } } void server_queue::cleanup_pending_task(int id_target) { // no need lock because this is called exclusively by post() auto rm_func = [id_target](const server_task & task) { return task.id == id_target; }; queue_tasks.erase( std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), queue_tasks.end()); queue_tasks_deferred.erase( std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func), queue_tasks_deferred.end()); } // // server_response // void server_response::add_waiting_task_id(int id_task) { RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size()); std::unique_lock lock(mutex_results); waiting_task_ids.insert(id_task); } void server_response::add_waiting_tasks(const std::vector & tasks) { std::unique_lock lock(mutex_results); for (const auto & task : tasks) { RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size()); waiting_task_ids.insert(task.id); } } void server_response::remove_waiting_task_id(int id_task) { RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); std::unique_lock lock(mutex_results); waiting_task_ids.erase(id_task); // make sure to clean up all pending results queue_results.erase( std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) { return res->id == id_task; }), queue_results.end()); } void server_response::remove_waiting_task_ids(const std::unordered_set & id_tasks) { std::unique_lock lock(mutex_results); for (const auto & id_task : id_tasks) { RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); waiting_task_ids.erase(id_task); } } server_task_result_ptr server_response::recv(const std::unordered_set & id_tasks) { while (true) { std::unique_lock lock(mutex_results); condition_results.wait(lock, [&]{ if (!running) { RES_DBG("%s : queue result stop\n", "recv"); std::terminate(); // we cannot return here since the caller is HTTP code } return !queue_results.empty(); }); for (size_t i = 0; i < queue_results.size(); i++) { if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { server_task_result_ptr res = std::move(queue_results[i]); queue_results.erase(queue_results.begin() + i); return res; } } } // should never reach here } server_task_result_ptr server_response::recv_with_timeout(const std::unordered_set & id_tasks, int timeout) { while (true) { std::unique_lock lock(mutex_results); for (int i = 0; i < (int) queue_results.size(); i++) { if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { server_task_result_ptr res = std::move(queue_results[i]); queue_results.erase(queue_results.begin() + i); return res; } } std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout)); if (!running) { RES_DBG("%s : queue result stop\n", __func__); std::terminate(); // we cannot return here since the caller is HTTP code } if (cr_res == std::cv_status::timeout) { return nullptr; } } // should never reach here } server_task_result_ptr server_response::recv(int id_task) { std::unordered_set id_tasks = {id_task}; return recv(id_tasks); } void server_response::send(server_task_result_ptr && result) { RES_DBG("sending result for task id = %d\n", result->id); std::unique_lock lock(mutex_results); for (const auto & id_task : waiting_task_ids) { if (result->id == id_task) { RES_DBG("task id = %d pushed to result queue\n", result->id); queue_results.emplace_back(std::move(result)); condition_results.notify_all(); return; } } } void server_response::terminate() { running = false; condition_results.notify_all(); } // // server_response_reader // void server_response_reader::post_tasks(std::vector && tasks) { id_tasks = server_task::get_list_id(tasks); queue_results.add_waiting_tasks(tasks); queue_tasks.post(std::move(tasks)); } bool server_response_reader::has_next() const { return !cancelled && received_count < id_tasks.size(); } // return nullptr if should_stop() is true before receiving a result // note: if one error is received, it will stop further processing and return error result server_task_result_ptr server_response_reader::next(const std::function & should_stop) { while (true) { server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, polling_interval_seconds); if (result == nullptr) { // timeout, check stop condition if (should_stop()) { SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n"); return nullptr; } } else { if (result->is_error()) { stop(); // cancel remaining tasks SRV_DBG("%s", "received error result, stopping further processing\n"); return result; } if (result->is_stop()) { received_count++; } return result; } } // should not reach here } server_response_reader::batch_response server_response_reader::wait_for_all(const std::function & should_stop) { batch_response batch_res; batch_res.results.resize(id_tasks.size()); while (has_next()) { auto res = next(should_stop); if (res == nullptr) { batch_res.is_terminated = true; return batch_res; } if (res->is_error()) { batch_res.error = std::move(res); return batch_res; } const size_t idx = res->get_index(); GGML_ASSERT(idx < batch_res.results.size() && "index out of range"); GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received"); batch_res.results[idx] = std::move(res); } return batch_res; } void server_response_reader::stop() { queue_results.remove_waiting_task_ids(id_tasks); if (has_next() && !cancelled) { // if tasks is not finished yet, cancel them cancelled = true; std::vector cancel_tasks; cancel_tasks.reserve(id_tasks.size()); for (const auto & id_task : id_tasks) { SRV_WRN("cancel task, id_task = %d\n", id_task); server_task task(SERVER_TASK_TYPE_CANCEL); task.id_target = id_task; queue_results.remove_waiting_task_id(id_task); cancel_tasks.push_back(std::move(task)); } // push to beginning of the queue, so it has highest priority queue_tasks.post(std::move(cancel_tasks), true); } else { SRV_DBG("%s", "all tasks already finished, no need to cancel\n"); } }