server-queue.h 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  #pragma once

  #include "server-task.h"

  #include <condition_variable>
  #include <cstddef>
  #include <cstdint>
  #include <deque>
  #include <functional>
  #include <mutex>
  #include <unordered_set>
  #include <vector>
  8. // struct for managing server tasks
  9. // in most cases, use server_response_reader to post new tasks and retrieve results
  10. struct server_queue {
  11. private:
  12. int id = 0;
  13. bool running = false;
  14. bool sleeping = false;
  15. bool req_stop_sleeping = false;
  16. int64_t time_last_task = 0;
  17. // queues
  18. std::deque<server_task> queue_tasks;
  19. std::deque<server_task> queue_tasks_deferred;
  20. std::mutex mutex_tasks;
  21. std::condition_variable condition_tasks;
  22. // callback functions
  23. std::function<void(server_task &&)> callback_new_task;
  24. std::function<void(void)> callback_update_slots;
  25. std::function<void(bool)> callback_sleeping_state;
  26. public:
  27. // Add a new task to the end of the queue
  28. int post(server_task && task, bool front = false);
  29. // multi-task version of post()
  30. int post(std::vector<server_task> && tasks, bool front = false);
  31. // Add a new task, but defer until one slot is available
  32. void defer(server_task && task);
  33. // Get the next id for creating a new task
  34. int get_new_id();
  35. // Call when the state of one slot is changed, it will move one task from deferred to main queue
  36. // prioritize tasks that use the specified slot (otherwise, pop the first deferred task)
  37. void pop_deferred_task(int id_slot);
  38. // if sleeping, request exiting sleep state and wait until it is done
  39. // returns immediately if not sleeping
  40. void wait_until_no_sleep();
  41. bool is_sleeping() {
  42. std::unique_lock<std::mutex> lock(mutex_tasks);
  43. return sleeping;
  44. }
  45. // end the start_loop routine
  46. void terminate();
  47. /**
  48. * Main loop consists of these steps:
  49. * - Wait until a new task arrives
  50. * - Process the task (i.e. maybe copy data into slot)
  51. * - Check if multitask is finished
  52. * - Update all slots
  53. *
  54. * Sleeping procedure (disabled if idle_sleep_ms < 0):
  55. * - If there is no task after idle_sleep_ms, enter sleeping state
  56. * - Call callback_sleeping_state(true)
  57. * - Wait until req_stop_sleeping is set to true
  58. * - Call callback_sleeping_state(false)
  59. * - Exit sleeping state
  60. */
  61. void start_loop(int64_t idle_sleep_ms = -1);
  62. // for metrics
  63. size_t queue_tasks_deferred_size() {
  64. std::unique_lock<std::mutex> lock(mutex_tasks);
  65. return queue_tasks_deferred.size();
  66. }
  67. //
  68. // Functions below are not thread-safe, must only be used before start_loop() is called
  69. //
  70. // Register function to process a new task
  71. void on_new_task(std::function<void(server_task &&)> callback) {
  72. callback_new_task = std::move(callback);
  73. }
  74. // Register the function to be called when all slots data is ready to be processed
  75. void on_update_slots(std::function<void(void)> callback) {
  76. callback_update_slots = std::move(callback);
  77. }
  78. // Register callback for sleeping state change
  79. // note: when entering sleeping state, the callback is called AFTER sleeping is set to true
  80. // when leaving sleeping state, the callback is called BEFORE sleeping is set to false
  81. void on_sleeping_state(std::function<void(bool)> callback) {
  82. callback_sleeping_state = std::move(callback);
  83. }
  84. private:
  85. void cleanup_pending_task(int id_target);
  86. };
  87. // struct for managing server responses
  88. // in most cases, use server_response_reader to retrieve results
  89. struct server_response {
  90. private:
  91. bool running = true;
  92. // for keeping track of all tasks waiting for the result
  93. std::unordered_set<int> waiting_task_ids;
  94. // the main result queue (using ptr for polymorphism)
  95. std::vector<server_task_result_ptr> queue_results;
  96. std::mutex mutex_results;
  97. std::condition_variable condition_results;
  98. public:
  99. // add the id_task to the list of tasks waiting for response
  100. void add_waiting_task_id(int id_task);
  101. void add_waiting_task_ids(const std::unordered_set<int> & id_tasks);
  102. // when the request is finished, we can remove task associated with it
  103. void remove_waiting_task_id(int id_task);
  104. // remove multiple tasks from waiting list
  105. void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks);
  106. // This function blocks the thread until there is a response for one of the id_tasks
  107. server_task_result_ptr recv(const std::unordered_set<int> & id_tasks);
  108. // same as recv(), but have timeout in seconds
  109. // if timeout is reached, nullptr is returned
  110. server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout);
  111. // single-task version of recv()
  112. server_task_result_ptr recv(int id_task);
  113. // Send a new result to a waiting id_task
  114. void send(server_task_result_ptr && result);
  115. // terminate the waiting loop
  116. void terminate();
  117. };
  118. // utility class to make working with server_queue and server_response easier
  119. // it provides a generator-like API for server responses
  120. // support pooling connection state and aggregating multiple results
  121. struct server_response_reader {
  122. std::unordered_set<int> id_tasks;
  123. server_queue & queue_tasks;
  124. server_response & queue_results;
  125. size_t received_count = 0;
  126. bool cancelled = false;
  127. int polling_interval_seconds;
  128. // tracking generation state and partial tool calls
  129. // only used by streaming completions
  130. std::vector<task_result_state> states;
  131. // should_stop function will be called each polling_interval_seconds
  132. server_response_reader(server_queue & queue_tasks, server_response & queue_results, int polling_interval_seconds)
  133. : queue_tasks(queue_tasks), queue_results(queue_results), polling_interval_seconds(polling_interval_seconds) {}
  134. ~server_response_reader() {
  135. stop();
  136. }
  137. int get_new_id() {
  138. return queue_tasks.get_new_id();
  139. }
  140. // if front = true, the task will be posted to the front of the queue (high priority)
  141. void post_task(server_task && task, bool front = false);
  142. void post_tasks(std::vector<server_task> && tasks, bool front = false);
  143. bool has_next() const;
  144. // return nullptr if should_stop() is true before receiving a result
  145. // note: if one error is received, it will stop further processing and return error result
  146. server_task_result_ptr next(const std::function<bool()> & should_stop);
  147. struct batch_response {
  148. bool is_terminated = false; // if true, indicates that processing was stopped before all results were received
  149. std::vector<server_task_result_ptr> results;
  150. server_task_result_ptr error; // nullptr if no error
  151. };
  152. // aggregate multiple results
  153. batch_response wait_for_all(const std::function<bool()> & should_stop);
  154. void stop();
  155. };