| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155 |
- #pragma once
- #include "server-task.h"
- #include <condition_variable>
- #include <deque>
- #include <mutex>
- #include <unordered_set>
- // struct for managing server tasks
- // in most cases, use server_response_reader to post new tasks and retrieve results
- struct server_queue {
- private:
- int id = 0;
- bool running;
- // queues
- std::deque<server_task> queue_tasks;
- std::deque<server_task> queue_tasks_deferred;
- std::mutex mutex_tasks;
- std::condition_variable condition_tasks;
- // callback functions
- std::function<void(server_task &&)> callback_new_task;
- std::function<void(void)> callback_update_slots;
- public:
- // Add a new task to the end of the queue
- int post(server_task && task, bool front = false);
- // multi-task version of post()
- int post(std::vector<server_task> && tasks, bool front = false);
- // Add a new task, but defer until one slot is available
- void defer(server_task && task);
- // Get the next id for creating a new task
- int get_new_id();
- // Register function to process a new task
- void on_new_task(std::function<void(server_task &&)> callback);
- // Register the function to be called when all slots data is ready to be processed
- void on_update_slots(std::function<void(void)> callback);
- // Call when the state of one slot is changed, it will move one task from deferred to main queue
- void pop_deferred_task();
- // end the start_loop routine
- void terminate();
- /**
- * Main loop consists of these steps:
- * - Wait until a new task arrives
- * - Process the task (i.e. maybe copy data into slot)
- * - Check if multitask is finished
- * - Update all slots
- */
- void start_loop();
- // for metrics
- size_t queue_tasks_deferred_size() {
- std::unique_lock<std::mutex> lock(mutex_tasks);
- return queue_tasks_deferred.size();
- }
- private:
- void cleanup_pending_task(int id_target);
- };
- // struct for managing server responses
- // in most cases, use server_response_reader to retrieve results
- struct server_response {
- private:
- bool running = true;
- // for keeping track of all tasks waiting for the result
- std::unordered_set<int> waiting_task_ids;
- // the main result queue (using ptr for polymorphism)
- std::vector<server_task_result_ptr> queue_results;
- std::mutex mutex_results;
- std::condition_variable condition_results;
- public:
- // add the id_task to the list of tasks waiting for response
- void add_waiting_task_id(int id_task);
- void add_waiting_tasks(const std::vector<server_task> & tasks);
- // when the request is finished, we can remove task associated with it
- void remove_waiting_task_id(int id_task);
- // remove multiple tasks from waiting list
- void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks);
- // This function blocks the thread until there is a response for one of the id_tasks
- server_task_result_ptr recv(const std::unordered_set<int> & id_tasks);
- // same as recv(), but have timeout in seconds
- // if timeout is reached, nullptr is returned
- server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout);
- // single-task version of recv()
- server_task_result_ptr recv(int id_task);
- // Send a new result to a waiting id_task
- void send(server_task_result_ptr && result);
- // terminate the waiting loop
- void terminate();
- };
- // utility class to make working with server_queue and server_response easier
- // it provides a generator-like API for server responses
- // support pooling connection state and aggregating multiple results
- struct server_response_reader {
- std::unordered_set<int> id_tasks;
- server_queue & queue_tasks;
- server_response & queue_results;
- size_t received_count = 0;
- bool cancelled = false;
- int polling_interval_seconds;
- // tracking generation state and partial tool calls
- // only used by streaming completions
- std::vector<task_result_state> states;
- // should_stop function will be called each polling_interval_seconds
- server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
- : queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
- ~server_response_reader() {
- stop();
- }
- void set_states(std::vector<task_result_state> && states);
- void post_tasks(std::vector<server_task> && tasks);
- bool has_next() const;
- // return nullptr if should_stop() is true before receiving a result
- // note: if one error is received, it will stop further processing and return error result
- server_task_result_ptr next(const std::function<bool()> & should_stop);
- struct batch_response {
- bool is_terminated = false; // if true, indicates that processing was stopped before all results were received
- std::vector<server_task_result_ptr> results;
- server_task_result_ptr error; // nullptr if no error
- };
- // aggregate multiple results
- batch_response wait_for_all(const std::function<bool()> & should_stop);
- void stop();
- };
|