// server-queue.cpp

#include "server-task.h"
#include "server-queue.h"

#include "log.h"

#include <algorithm> // for std::remove_if
#include <chrono>

#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)

#define RES_INF(fmt, ...) LOG_INF("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define RES_WRN(fmt, ...) LOG_WRN("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define RES_ERR(fmt, ...) LOG_ERR("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define RES_DBG(fmt, ...) LOG_DBG("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)

//
// server_queue
//

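// enqueue a single task (which must already have an id) and wake up the main loop
// returns the id of the posted task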
int server_queue::post(server_task && task, bool front) {
    std::unique_lock<std::mutex> lock(mutex_tasks);
    GGML_ASSERT(task.id != -1);
    // if this is a cancel task, make sure to clean up pending tasks
    if (task.type == SERVER_TASK_TYPE_CANCEL) {
        cleanup_pending_task(task.id_target);
    }
    const int task_id = task.id;
    QUE_DBG("new task, id = %d, front = %d\n", task_id, front);
    if (front) {
        queue_tasks.push_front(std::move(task));
    } else {
        queue_tasks.push_back(std::move(task));
    }
    time_last_task = ggml_time_ms();
    condition_tasks.notify_one();
    return task_id;
}

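// enqueue a batch of tasks, assigning a new id to any task that does not have one yet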
int server_queue::post(std::vector<server_task> && tasks, bool front) {
    std::unique_lock<std::mutex> lock(mutex_tasks);
    for (auto & task : tasks) {
        if (task.id == -1) {
            task.id = id++;
        }
        // if this is a cancel task, make sure to clean up pending tasks
        if (task.type == SERVER_TASK_TYPE_CANCEL) {
            cleanup_pending_task(task.id_target);
        }
        QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
        if (front) {
            queue_tasks.push_front(std::move(task));
        } else {
            queue_tasks.push_back(std::move(task));
        }
    }
    time_last_task = ggml_time_ms();
    condition_tasks.notify_one();
    return 0;
}

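// park a task on the deferred queue; it is re-queued later via pop_deferred_task()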
void server_queue::defer(server_task && task) {
    std::unique_lock<std::mutex> lock(mutex_tasks);
    QUE_DBG("defer task, id = %d\n", task.id);
    queue_tasks_deferred.push_back(std::move(task));
    time_last_task = ggml_time_ms();
    condition_tasks.notify_one();
}

int server_queue::get_new_id() {
    std::unique_lock<std::mutex> lock(mutex_tasks);
    int new_id = id++;
    return new_id;
}

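// move the oldest deferred task (if any) back to the front of the main queue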
void server_queue::pop_deferred_task() {
    std::unique_lock<std::mutex> lock(mutex_tasks);
    if (!queue_tasks_deferred.empty()) {
        queue_tasks.emplace_front(std::move(queue_tasks_deferred.front()));
        queue_tasks_deferred.pop_front();
    }
    time_last_task = ggml_time_ms();
    condition_tasks.notify_one();
}

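// block the caller until the main loop has left the sleeping state,
// requesting the wake-up first if nobody has done so yet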
void server_queue::wait_until_no_sleep() {
    std::unique_lock<std::mutex> lock(mutex_tasks);
    if (!sleeping) {
        return;
    }
    if (!req_stop_sleeping) {
        QUE_DBG("%s", "requesting to stop sleeping\n");
        req_stop_sleeping = true;
        condition_tasks.notify_one(); // only the main thread is waiting on this
    }
    QUE_DBG("%s", "waiting until no sleep\n");
    condition_tasks.wait(lock, [&]{
        return !sleeping;
    });
}

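// request the main loop to stop; start_loop() returns at its next check of `running`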
void server_queue::terminate() {
    std::unique_lock<std::mutex> lock(mutex_tasks);
    running = false;
    condition_tasks.notify_all();
}

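// the main loop: drain queued tasks, run one slot update, then wait for new tasks
// after idle_sleep_ms of inactivity (negative disables this) the loop enters the
// sleeping state until wait_until_no_sleep() wakes it up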
void server_queue::start_loop(int64_t idle_sleep_ms) {
    running = true;
    time_last_task = ggml_time_ms();

    constexpr auto max_wait_time = std::chrono::seconds(1);

    auto should_sleep = [&]() -> bool {
        // caller must hold mutex_tasks
        if (idle_sleep_ms < 0) {
            return false;
        }
        int64_t now = ggml_time_ms();
        return (now - time_last_task) >= idle_sleep_ms;
    };

    while (true) {
        QUE_DBG("%s", "processing new tasks\n");

        while (true) {
            std::unique_lock<std::mutex> lock(mutex_tasks);
            if (!running) {
                QUE_DBG("%s", "terminate\n");
                return;
            }
            if (queue_tasks.empty()) {
                lock.unlock();
                break;
            }
            server_task task = std::move(queue_tasks.front());
            queue_tasks.pop_front();
            lock.unlock();

            QUE_DBG("processing task, id = %d\n", task.id);
            callback_new_task(std::move(task));
        }

        // all tasks in the current loop are processed, slot data is now ready
        QUE_DBG("%s", "update slots\n");

        // this will run the main inference process for all slots
        callback_update_slots();

        {
            // update_slots() may take a while to finish; make sure that time is not counted as idle
            std::unique_lock<std::mutex> lock(mutex_tasks);
            time_last_task = ggml_time_ms();
        }

        QUE_DBG("%s", "waiting for new tasks\n");
        while (true) {
            std::unique_lock<std::mutex> lock(mutex_tasks);
            if (!running || !queue_tasks.empty()) {
                break; // go back to process new tasks or terminate
            }
            // no tasks, check the sleeping condition
            if (should_sleep()) {
                QUE_INF("%s", "entering sleeping state\n");
                sleeping = true;
                callback_sleeping_state(true);
                req_stop_sleeping = false;
                // wait until we are requested to exit the sleeping state
                condition_tasks.wait(lock, [&]{
                    return (!running || req_stop_sleeping);
                });
                if (!running) { // may have changed during sleep
                    break; // terminate
                }
                QUE_INF("%s", "exiting sleeping state\n");
                req_stop_sleeping = false;
                callback_sleeping_state(false);
                sleeping = false;
                time_last_task = ggml_time_ms();
                condition_tasks.notify_all(); // notify wait_until_no_sleep()
                break; // process new tasks
            } else {
                // wait for new tasks, or time out to re-check the sleeping condition
                bool res = condition_tasks.wait_for(lock, max_wait_time, [&]{
                    return (!queue_tasks.empty() || !running);
                });
                if (res) {
                    break; // new task arrived or terminate
                }
                // otherwise, loop again to check the sleeping condition
            }
        }
    }
}

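// erase all queued and deferred tasks that target id_target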
void server_queue::cleanup_pending_task(int id_target) {
    // no lock needed because this is called exclusively by post()
    auto rm_func = [id_target](const server_task & task) {
        return task.id == id_target;
    };
    queue_tasks.erase(
        std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
        queue_tasks.end());
    queue_tasks_deferred.erase(
        std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
        queue_tasks_deferred.end());
}

//
// server_response
//

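// register a task id so that its result will be accepted by send()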
void server_response::add_waiting_task_id(int id_task) {
    std::unique_lock<std::mutex> lock(mutex_results);
    RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
    waiting_task_ids.insert(id_task);
}

void server_response::add_waiting_tasks(const std::vector<server_task> & tasks) {
    std::unique_lock<std::mutex> lock(mutex_results);
    for (const auto & task : tasks) {
        RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size());
        waiting_task_ids.insert(task.id);
    }
}

void server_response::remove_waiting_task_id(int id_task) {
    std::unique_lock<std::mutex> lock(mutex_results);
    RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
    waiting_task_ids.erase(id_task);
    // make sure to clean up all pending results
    queue_results.erase(
        std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) {
            return res->id == id_task;
        }),
        queue_results.end());
}

void server_response::remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
    std::unique_lock<std::mutex> lock(mutex_results);
    for (const auto & id_task : id_tasks) {
        RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
        waiting_task_ids.erase(id_task);
    }
}

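// blocking receive: wait until a result for one of id_tasks is available
// if the response queue is terminated while waiting, the process is aborted,
// since the HTTP caller has no way to handle a missing result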
server_task_result_ptr server_response::recv(const std::unordered_set<int> & id_tasks) {
    while (true) {
        std::unique_lock<std::mutex> lock(mutex_results);
        condition_results.wait(lock, [&]{
            if (!running) {
                RES_DBG("%s : queue result stop\n", "recv");
                std::terminate(); // we cannot return here since the caller is HTTP code
            }
            return !queue_results.empty();
        });

        for (size_t i = 0; i < queue_results.size(); i++) {
            if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
                server_task_result_ptr res = std::move(queue_results[i]);
                queue_results.erase(queue_results.begin() + i);
                return res;
            }
        }
    }

    // should never reach here
}

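// same as recv(), but returns nullptr if no matching result arrives within `timeout` seconds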
server_task_result_ptr server_response::recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
    while (true) {
        std::unique_lock<std::mutex> lock(mutex_results);

        for (int i = 0; i < (int) queue_results.size(); i++) {
            if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
                server_task_result_ptr res = std::move(queue_results[i]);
                queue_results.erase(queue_results.begin() + i);
                return res;
            }
        }

        std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
        if (!running) {
            RES_DBG("%s : queue result stop\n", __func__);
            std::terminate(); // we cannot return here since the caller is HTTP code
        }
        if (cr_res == std::cv_status::timeout) {
            return nullptr;
        }
    }

    // should never reach here
}

server_task_result_ptr server_response::recv(int id_task) {
    std::unordered_set<int> id_tasks = {id_task};
    return recv(id_tasks);
}

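// publish a result: it is queued only if its task id is still in the waiting list,
// otherwise it is silently dropped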
void server_response::send(server_task_result_ptr && result) {
    RES_DBG("sending result for task id = %d\n", result->id);

    std::unique_lock<std::mutex> lock(mutex_results);
    for (const auto & id_task : waiting_task_ids) {
        if (result->id == id_task) {
            RES_DBG("task id = %d pushed to result queue\n", result->id);

            queue_results.emplace_back(std::move(result));
            condition_results.notify_all();
            return;
        }
    }
}

void server_response::terminate() {
    // take the lock so the flag change cannot race with a waiter that has
    // already checked `running` but has not started waiting yet
    std::unique_lock<std::mutex> lock(mutex_results);
    running = false;
    condition_results.notify_all();
}

//
// server_response_reader
//

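// glue between a request handler and the two queues above: it posts tasks,
// registers their ids with server_response and streams the results back
//
// minimal usage sketch (assuming a reader already bound to the queues; the
// exact construction is defined in server-queue.h):
//
//   reader.post_task(std::move(task), /* front */ false);
//   while (reader.has_next()) {
//       server_task_result_ptr res = reader.next(should_stop);
//       if (res == nullptr) {
//           break; // terminated via should_stop()
//       }
//       // ... consume res ...
//   }
//   reader.stop(); // cancel anything still pending
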
void server_response_reader::post_task(server_task && task, bool front) {
    GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
    task.index = 0;
    id_tasks.insert(task.id);
    states.push_back(task.create_state());
    queue_results.add_waiting_task_id(task.id);
    queue_tasks.post(std::move(task), front);
}

void server_response_reader::post_tasks(std::vector<server_task> && tasks, bool front) {
    GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
    id_tasks = server_task::get_list_id(tasks);
    states.reserve(tasks.size());
    for (size_t i = 0; i < tasks.size(); i++) {
        tasks[i].index = i;
        states.push_back(tasks[i].create_state());
    }
    queue_results.add_waiting_tasks(tasks);
    queue_tasks.post(std::move(tasks), front);
}

bool server_response_reader::has_next() const {
    return !cancelled && received_count < id_tasks.size();
}

// returns nullptr if should_stop() becomes true before a result is received
// note: if an error result is received, further processing is stopped and the error is returned
server_task_result_ptr server_response_reader::next(const std::function<bool()> & should_stop) {
    while (true) {
        server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, polling_interval_seconds);
        if (result == nullptr) {
            // timeout, check the stop condition
            if (should_stop()) {
                SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n");
                return nullptr;
            }
        } else {
            if (result->is_error()) {
                stop(); // cancel remaining tasks
                SRV_DBG("%s", "received error result, stopping further processing\n");
                return result;
            }
            if (!states.empty()) {
                // update the generation state if needed
                const size_t idx = result->index;
                GGML_ASSERT(idx < states.size());
                result->update(states[idx]);
            }
            if (result->is_stop()) {
                received_count++;
            }
            return result;
        }
    }

    // should not reach here
}

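// collect the results of all posted tasks, ordered by task index
// stops early if the wait is terminated or an error result is received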
server_response_reader::batch_response server_response_reader::wait_for_all(const std::function<bool()> & should_stop) {
    batch_response batch_res;
    batch_res.results.clear();
    batch_res.results.resize(id_tasks.size());
    while (has_next()) {
        auto res = next(should_stop);
        if (res == nullptr) {
            batch_res.is_terminated = true;
            return batch_res;
        }
        if (res->is_error()) {
            batch_res.error = std::move(res);
            return batch_res;
        }
        const size_t idx = res->index;
        GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
        GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
        batch_res.results[idx] = std::move(res);
    }
    return batch_res;
}

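// stop waiting for results and post cancel tasks for everything that has not finished yet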
void server_response_reader::stop() {
    queue_results.remove_waiting_task_ids(id_tasks);
    if (has_next() && !cancelled) {
        // if the tasks are not finished yet, cancel them
        cancelled = true;
        std::vector<server_task> cancel_tasks;
        cancel_tasks.reserve(id_tasks.size());
        for (const auto & id_task : id_tasks) {
            SRV_WRN("cancel task, id_task = %d\n", id_task);
            server_task task(SERVER_TASK_TYPE_CANCEL);
            task.id_target = id_task;
            queue_results.remove_waiting_task_id(id_task);
            cancel_tasks.push_back(std::move(task));
        }
        // push to the beginning of the queue, so it has the highest priority
        queue_tasks.post(std::move(cancel_tasks), true);
    } else {
        SRV_DBG("%s", "all tasks already finished, no need to cancel\n");
    }
}