#pragma once #include "server-task.h" #include #include #include #include // struct for managing server tasks // in most cases, use server_response_reader to post new tasks and retrieve results struct server_queue { private: int id = 0; bool running; // queues std::deque queue_tasks; std::deque queue_tasks_deferred; std::mutex mutex_tasks; std::condition_variable condition_tasks; // callback functions std::function callback_new_task; std::function callback_update_slots; public: // Add a new task to the end of the queue int post(server_task && task, bool front = false); // multi-task version of post() int post(std::vector && tasks, bool front = false); // Add a new task, but defer until one slot is available void defer(server_task && task); // Get the next id for creating a new task int get_new_id(); // Register function to process a new task void on_new_task(std::function callback); // Register the function to be called when all slots data is ready to be processed void on_update_slots(std::function callback); // Call when the state of one slot is changed, it will move one task from deferred to main queue void pop_deferred_task(); // end the start_loop routine void terminate(); /** * Main loop consists of these steps: * - Wait until a new task arrives * - Process the task (i.e. maybe copy data into slot) * - Check if multitask is finished * - Update all slots */ void start_loop(); // for metrics size_t queue_tasks_deferred_size() { std::unique_lock lock(mutex_tasks); return queue_tasks_deferred.size(); } private: void cleanup_pending_task(int id_target); }; // struct for managing server responses // in most cases, use server_response_reader to retrieve results struct server_response { private: bool running = true; // for keeping track of all tasks waiting for the result std::unordered_set waiting_task_ids; // the main result queue (using ptr for polymorphism) std::vector queue_results; std::mutex mutex_results; std::condition_variable condition_results; public: // add the id_task to the list of tasks waiting for response void add_waiting_task_id(int id_task); void add_waiting_tasks(const std::vector & tasks); // when the request is finished, we can remove task associated with it void remove_waiting_task_id(int id_task); // remove multiple tasks from waiting list void remove_waiting_task_ids(const std::unordered_set & id_tasks); // This function blocks the thread until there is a response for one of the id_tasks server_task_result_ptr recv(const std::unordered_set & id_tasks); // same as recv(), but have timeout in seconds // if timeout is reached, nullptr is returned server_task_result_ptr recv_with_timeout(const std::unordered_set & id_tasks, int timeout); // single-task version of recv() server_task_result_ptr recv(int id_task); // Send a new result to a waiting id_task void send(server_task_result_ptr && result); // terminate the waiting loop void terminate(); }; // utility class to make working with server_queue and server_response easier // it provides a generator-like API for server responses // support pooling connection state and aggregating multiple results struct server_response_reader { std::unordered_set id_tasks; server_queue & queue_tasks; server_response & queue_results; size_t received_count = 0; bool cancelled = false; int polling_interval_seconds; // tracking generation state and partial tool calls // only used by streaming completions std::vector states; // should_stop function will be called each polling_interval_seconds server_response_reader(std::pair server_queues, int polling_interval_seconds) : queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {} ~server_response_reader() { stop(); } void set_states(std::vector && states); void post_tasks(std::vector && tasks); bool has_next() const; // return nullptr if should_stop() is true before receiving a result // note: if one error is received, it will stop further processing and return error result server_task_result_ptr next(const std::function & should_stop); struct batch_response { bool is_terminated = false; // if true, indicates that processing was stopped before all results were received std::vector results; server_task_result_ptr error; // nullptr if no error }; // aggregate multiple results batch_response wait_for_all(const std::function & should_stop); void stop(); };