// server-queue.cpp
#include "server-task.h"
#include "server-queue.h"
#include "log.h"

#include <chrono>

// Logging helpers: tag each message with a subsystem prefix ("que" = task
// queue, "res" = result queue) and the calling function name padded/truncated
// to 12 characters, so interleaved log lines from both queues stay aligned.
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define RES_INF(fmt, ...) LOG_INF("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define RES_WRN(fmt, ...) LOG_WRN("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define RES_ERR(fmt, ...) LOG_ERR("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define RES_DBG(fmt, ...) LOG_DBG("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
  13. //
  14. // server_queue
  15. //
  16. int server_queue::post(server_task && task, bool front) {
  17. std::unique_lock<std::mutex> lock(mutex_tasks);
  18. GGML_ASSERT(task.id != -1);
  19. // if this is cancel task make sure to clean up pending tasks
  20. if (task.type == SERVER_TASK_TYPE_CANCEL) {
  21. cleanup_pending_task(task.id_target);
  22. }
  23. const int task_id = task.id;
  24. QUE_DBG("new task, id = %d, front = %d\n", task_id, front);
  25. if (front) {
  26. queue_tasks.push_front(std::move(task));
  27. } else {
  28. queue_tasks.push_back(std::move(task));
  29. }
  30. condition_tasks.notify_one();
  31. return task_id;
  32. }
  33. int server_queue::post(std::vector<server_task> && tasks, bool front) {
  34. std::unique_lock<std::mutex> lock(mutex_tasks);
  35. for (auto & task : tasks) {
  36. if (task.id == -1) {
  37. task.id = id++;
  38. }
  39. // if this is cancel task make sure to clean up pending tasks
  40. if (task.type == SERVER_TASK_TYPE_CANCEL) {
  41. cleanup_pending_task(task.id_target);
  42. }
  43. QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
  44. if (front) {
  45. queue_tasks.push_front(std::move(task));
  46. } else {
  47. queue_tasks.push_back(std::move(task));
  48. }
  49. }
  50. condition_tasks.notify_one();
  51. return 0;
  52. }
  53. void server_queue::defer(server_task && task) {
  54. std::unique_lock<std::mutex> lock(mutex_tasks);
  55. QUE_DBG("defer task, id = %d\n", task.id);
  56. queue_tasks_deferred.push_back(std::move(task));
  57. condition_tasks.notify_one();
  58. }
  59. int server_queue::get_new_id() {
  60. std::unique_lock<std::mutex> lock(mutex_tasks);
  61. int new_id = id++;
  62. return new_id;
  63. }
  64. void server_queue::on_new_task(std::function<void(server_task &&)> callback) {
  65. callback_new_task = std::move(callback);
  66. }
  67. void server_queue::on_update_slots(std::function<void(void)> callback) {
  68. callback_update_slots = std::move(callback);
  69. }
  70. void server_queue::pop_deferred_task() {
  71. std::unique_lock<std::mutex> lock(mutex_tasks);
  72. if (!queue_tasks_deferred.empty()) {
  73. queue_tasks.emplace_front(std::move(queue_tasks_deferred.front()));
  74. queue_tasks_deferred.pop_front();
  75. }
  76. condition_tasks.notify_one();
  77. }
  78. void server_queue::terminate() {
  79. std::unique_lock<std::mutex> lock(mutex_tasks);
  80. running = false;
  81. condition_tasks.notify_all();
  82. }
// Main worker loop. Each iteration: (1) drain the task queue, handing every
// task to callback_new_task; (2) run callback_update_slots once; (3) sleep on
// the condition variable until new work arrives or terminate() is called.
// Returns only when terminate() has cleared `running`.
void server_queue::start_loop() {
    // NOTE(review): `running` is written here without holding mutex_tasks —
    // assumes no other thread touches the queue before the loop starts; confirm.
    running = true;
    while (true) {
        QUE_DBG("%s", "processing new tasks\n");
        while (true) {
            std::unique_lock<std::mutex> lock(mutex_tasks);
            if (!running) {
                QUE_DBG("%s", "terminate\n");
                return;
            }
            if (queue_tasks.empty()) {
                lock.unlock();
                break;
            }
            server_task task = std::move(queue_tasks.front());
            queue_tasks.pop_front();
            // release the lock before running the callback, so the callback
            // itself may call post()/defer() without self-deadlocking
            lock.unlock();
            QUE_DBG("processing task, id = %d\n", task.id);
            callback_new_task(std::move(task));
        }
        // all tasks in the current loop is processed, slots data is now ready
        QUE_DBG("%s", "update slots\n");
        callback_update_slots();
        QUE_DBG("%s", "waiting for new tasks\n");
        {
            std::unique_lock<std::mutex> lock(mutex_tasks);
            if (!running) {
                QUE_DBG("%s", "terminate\n");
                return;
            }
            if (queue_tasks.empty()) {
                // predicate guards against spurious wakeups; terminate()
                // breaks the wait by clearing `running` and notifying
                condition_tasks.wait(lock, [&]{
                    return (!queue_tasks.empty() || !running);
                });
            }
        }
    }
}
// Remove every not-yet-started task whose id equals `id_target` from both the
// main and the deferred queue. Called when a CANCEL task is posted, so the
// target never runs if it is still pending.
void server_queue::cleanup_pending_task(int id_target) {
    // no need lock because this is called exclusively by post()
    auto rm_func = [id_target](const server_task & task) {
        return task.id == id_target;
    };
    // classic erase-remove idiom on both deques
    queue_tasks.erase(
        std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
        queue_tasks.end());
    queue_tasks_deferred.erase(
        std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
        queue_tasks_deferred.end());
}
  133. //
  134. // server_response
  135. //
  136. void server_response::add_waiting_task_id(int id_task) {
  137. RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
  138. std::unique_lock<std::mutex> lock(mutex_results);
  139. waiting_task_ids.insert(id_task);
  140. }
  141. void server_response::add_waiting_tasks(const std::vector<server_task> & tasks) {
  142. std::unique_lock<std::mutex> lock(mutex_results);
  143. for (const auto & task : tasks) {
  144. RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size());
  145. waiting_task_ids.insert(task.id);
  146. }
  147. }
  148. void server_response::remove_waiting_task_id(int id_task) {
  149. RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
  150. std::unique_lock<std::mutex> lock(mutex_results);
  151. waiting_task_ids.erase(id_task);
  152. // make sure to clean up all pending results
  153. queue_results.erase(
  154. std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) {
  155. return res->id == id_task;
  156. }),
  157. queue_results.end());
  158. }
  159. void server_response::remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
  160. std::unique_lock<std::mutex> lock(mutex_results);
  161. for (const auto & id_task : id_tasks) {
  162. RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
  163. waiting_task_ids.erase(id_task);
  164. }
  165. }
  166. server_task_result_ptr server_response::recv(const std::unordered_set<int> & id_tasks) {
  167. while (true) {
  168. std::unique_lock<std::mutex> lock(mutex_results);
  169. condition_results.wait(lock, [&]{
  170. if (!running) {
  171. RES_DBG("%s : queue result stop\n", "recv");
  172. std::terminate(); // we cannot return here since the caller is HTTP code
  173. }
  174. return !queue_results.empty();
  175. });
  176. for (size_t i = 0; i < queue_results.size(); i++) {
  177. if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
  178. server_task_result_ptr res = std::move(queue_results[i]);
  179. queue_results.erase(queue_results.begin() + i);
  180. return res;
  181. }
  182. }
  183. }
  184. // should never reach here
  185. }
  186. server_task_result_ptr server_response::recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
  187. while (true) {
  188. std::unique_lock<std::mutex> lock(mutex_results);
  189. for (int i = 0; i < (int) queue_results.size(); i++) {
  190. if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
  191. server_task_result_ptr res = std::move(queue_results[i]);
  192. queue_results.erase(queue_results.begin() + i);
  193. return res;
  194. }
  195. }
  196. std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
  197. if (!running) {
  198. RES_DBG("%s : queue result stop\n", __func__);
  199. std::terminate(); // we cannot return here since the caller is HTTP code
  200. }
  201. if (cr_res == std::cv_status::timeout) {
  202. return nullptr;
  203. }
  204. }
  205. // should never reach here
  206. }
  207. server_task_result_ptr server_response::recv(int id_task) {
  208. std::unordered_set<int> id_tasks = {id_task};
  209. return recv(id_tasks);
  210. }
  211. void server_response::send(server_task_result_ptr && result) {
  212. RES_DBG("sending result for task id = %d\n", result->id);
  213. std::unique_lock<std::mutex> lock(mutex_results);
  214. for (const auto & id_task : waiting_task_ids) {
  215. if (result->id == id_task) {
  216. RES_DBG("task id = %d pushed to result queue\n", result->id);
  217. queue_results.emplace_back(std::move(result));
  218. condition_results.notify_all();
  219. return;
  220. }
  221. }
  222. }
  223. void server_response::terminate() {
  224. running = false;
  225. condition_results.notify_all();
  226. }
  227. //
  228. // server_response_reader
  229. //
  230. void server_response_reader::set_states(std::vector<task_result_state> && states) {
  231. this->states = std::move(states);
  232. }
  233. void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
  234. id_tasks = server_task::get_list_id(tasks);
  235. queue_results.add_waiting_tasks(tasks);
  236. queue_tasks.post(std::move(tasks));
  237. }
  238. bool server_response_reader::has_next() const {
  239. return !cancelled && received_count < id_tasks.size();
  240. }
  241. // return nullptr if should_stop() is true before receiving a result
  242. // note: if one error is received, it will stop further processing and return error result
  243. server_task_result_ptr server_response_reader::next(const std::function<bool()> & should_stop) {
  244. while (true) {
  245. server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, polling_interval_seconds);
  246. if (result == nullptr) {
  247. // timeout, check stop condition
  248. if (should_stop()) {
  249. SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n");
  250. return nullptr;
  251. }
  252. } else {
  253. if (result->is_error()) {
  254. stop(); // cancel remaining tasks
  255. SRV_DBG("%s", "received error result, stopping further processing\n");
  256. return result;
  257. }
  258. if (!states.empty()) {
  259. // update the generation state if needed
  260. size_t idx = result->get_index();
  261. GGML_ASSERT(idx < states.size());
  262. result->update(states[idx]);
  263. }
  264. if (result->is_stop()) {
  265. received_count++;
  266. }
  267. return result;
  268. }
  269. }
  270. // should not reach here
  271. }
  272. server_response_reader::batch_response server_response_reader::wait_for_all(const std::function<bool()> & should_stop) {
  273. batch_response batch_res;
  274. batch_res.results.resize(id_tasks.size());
  275. while (has_next()) {
  276. auto res = next(should_stop);
  277. if (res == nullptr) {
  278. batch_res.is_terminated = true;
  279. return batch_res;
  280. }
  281. if (res->is_error()) {
  282. batch_res.error = std::move(res);
  283. return batch_res;
  284. }
  285. const size_t idx = res->get_index();
  286. GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
  287. GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
  288. batch_res.results[idx] = std::move(res);
  289. }
  290. return batch_res;
  291. }
  292. void server_response_reader::stop() {
  293. queue_results.remove_waiting_task_ids(id_tasks);
  294. if (has_next() && !cancelled) {
  295. // if tasks is not finished yet, cancel them
  296. cancelled = true;
  297. std::vector<server_task> cancel_tasks;
  298. cancel_tasks.reserve(id_tasks.size());
  299. for (const auto & id_task : id_tasks) {
  300. SRV_WRN("cancel task, id_task = %d\n", id_task);
  301. server_task task(SERVER_TASK_TYPE_CANCEL);
  302. task.id_target = id_task;
  303. queue_results.remove_waiting_task_id(id_task);
  304. cancel_tasks.push_back(std::move(task));
  305. }
  306. // push to beginning of the queue, so it has highest priority
  307. queue_tasks.post(std::move(cancel_tasks), true);
  308. } else {
  309. SRV_DBG("%s", "all tasks already finished, no need to cancel\n");
  310. }
  311. }