server-queue.h
#pragma once

#include "server-task.h"

#include <condition_variable>
#include <deque>
#include <functional>
#include <mutex>
#include <unordered_set>
#include <vector>

// struct for managing server tasks
// in most cases, use server_response_reader to post new tasks and retrieve results
struct server_queue {
private:
    int  id = 0;
    bool running;

    // queues
    std::deque<server_task> queue_tasks;
    std::deque<server_task> queue_tasks_deferred;

    std::mutex              mutex_tasks;
    std::condition_variable condition_tasks;

    // callback functions
    std::function<void(server_task &&)> callback_new_task;
    std::function<void(void)>           callback_update_slots;

public:
    // Add a new task to the end of the queue (or the front, if front == true)
    int post(server_task && task, bool front = false);

    // multi-task version of post()
    int post(std::vector<server_task> && tasks, bool front = false);

    // Add a new task, but defer it until one slot is available
    void defer(server_task && task);

    // Get the next id for creating a new task
    int get_new_id();

    // Register the function to process a new task
    void on_new_task(std::function<void(server_task &&)> callback);

    // Register the function to be called when all slots' data is ready to be processed
    void on_update_slots(std::function<void(void)> callback);

    // Call when the state of one slot changes; it will move one task from the deferred queue to the main queue
    void pop_deferred_task();

    // end the start_loop routine
    void terminate();

    /**
     * Main loop consists of these steps:
     * - Wait until a new task arrives
     * - Process the task (i.e. maybe copy data into a slot)
     * - Check if the multitask is finished
     * - Update all slots
     */
    void start_loop();

    // for metrics
    size_t queue_tasks_deferred_size() {
        std::unique_lock<std::mutex> lock(mutex_tasks);
        return queue_tasks_deferred.size();
    }

private:
    void cleanup_pending_task(int id_target);
};
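
// --- illustrative sketch (not part of the original header) ---
// A minimal example of how a worker thread might wire up server_queue, assuming the
// callbacks capture whatever slot/batching state the server keeps elsewhere; the
// function name and the comments inside it are placeholders, not real symbols from
// this codebase.
static inline void example_run_worker(server_queue & queue) {
    queue.on_new_task([](server_task && task) {
        // assign the task to a free slot here; when no slot is available, the task
        // could be handed back via defer(std::move(task)) instead
        (void) task;
    });
    queue.on_update_slots([]() {
        // run one processing step over all active slots
    });
    // blocks and processes tasks until terminate() is called from another thread
    queue.start_loop();
}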
// struct for managing server responses
// in most cases, use server_response_reader to retrieve results
struct server_response {
private:
    bool running = true;

    // for keeping track of all tasks waiting for the result
    std::unordered_set<int> waiting_task_ids;

    // the main result queue (using ptr for polymorphism)
    std::vector<server_task_result_ptr> queue_results;

    std::mutex              mutex_results;
    std::condition_variable condition_results;

public:
    // add the id_task to the list of tasks waiting for a response
    void add_waiting_task_id(int id_task);

    void add_waiting_tasks(const std::vector<server_task> & tasks);

    // when the request is finished, we can remove the task associated with it
    void remove_waiting_task_id(int id_task);

    // remove multiple tasks from the waiting list
    void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks);

    // This function blocks the thread until there is a response for one of the id_tasks
    server_task_result_ptr recv(const std::unordered_set<int> & id_tasks);

    // same as recv(), but with a timeout in seconds
    // if the timeout is reached, nullptr is returned
    server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout);

    // single-task version of recv()
    server_task_result_ptr recv(int id_task);

    // Send a new result to a waiting id_task
    void send(server_task_result_ptr && result);

    // terminate the waiting loop
    void terminate();
};
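
// --- illustrative sketch (not part of the original header) ---
// A minimal example of posting a single task and blocking for its result. It assumes
// the caller already obtained `task_id` from server_queue::get_new_id() and stored it
// in the task (how the id is stored depends on server-task.h, which is not shown here).
// The 600-second timeout is an arbitrary placeholder.
static inline server_task_result_ptr example_post_and_wait(
        server_queue & queue, server_response & response, server_task && task, int task_id) {
    // register interest in the result before posting the task
    response.add_waiting_task_id(task_id);
    queue.post(std::move(task));

    // block until the worker calls send() with a matching id, or the timeout is reached
    server_task_result_ptr result = response.recv_with_timeout({ task_id }, 600);
    response.remove_waiting_task_id(task_id);
    return result; // nullptr on timeout
}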
// utility class to make working with server_queue and server_response easier
// it provides a generator-like API for server responses
// supports polling the connection state and aggregating multiple results
struct server_response_reader {
    std::unordered_set<int> id_tasks;

    server_queue    & queue_tasks;
    server_response & queue_results;

    size_t received_count = 0;
    bool   cancelled      = false;
    int    polling_interval_seconds;

    // tracking generation state and partial tool calls
    // only used by streaming completions
    std::vector<task_result_state> states;

    // the should_stop function will be called every polling_interval_seconds
    server_response_reader(server_queue & queue_tasks, server_response & queue_results, int polling_interval_seconds)
        : queue_tasks(queue_tasks), queue_results(queue_results), polling_interval_seconds(polling_interval_seconds) {}

    ~server_response_reader() {
        stop();
    }

    void post_task(server_task && task);
    void post_tasks(std::vector<server_task> && tasks);

    bool has_next() const;

    // returns nullptr if should_stop() returns true before a result is received
    // note: if an error result is received, it will stop further processing and return the error
    server_task_result_ptr next(const std::function<bool()> & should_stop);

    struct batch_response {
        bool is_terminated = false; // if true, processing was stopped before all results were received

        std::vector<server_task_result_ptr> results;
        server_task_result_ptr error; // nullptr if no error
    };

    // aggregate multiple results
    batch_response wait_for_all(const std::function<bool()> & should_stop);

    void stop();
};
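
// --- illustrative sketch (not part of the original header) ---
// A minimal example of the generator-like flow: post a batch of tasks, then pull results
// one at a time with next() until has_next() reports that everything has been received.
// The 1-second polling interval and the always-false should_stop callback are
// placeholder choices; a real server would check the client connection in should_stop.
static inline void example_stream_results(server_queue & queue, server_response & response,
                                          std::vector<server_task> && tasks) {
    server_response_reader rd(queue, response, /*polling_interval_seconds =*/ 1);
    rd.post_tasks(std::move(tasks));

    while (rd.has_next()) {
        server_task_result_ptr res = rd.next([]() { return false; });
        if (res == nullptr) {
            break; // should_stop() returned true before a result arrived
        }
        // forward `res` to the client here
    }
    // for non-streaming requests, wait_for_all(should_stop) aggregates everything
    // into a single batch_response instead
}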