1
0

utils.hpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576
  1. #pragma once
#include <cassert>
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <ctime>
#include <functional>
#include <iostream>
#include <mutex>
#include <random>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

#include "json.hpp"
#include "../llava/clip.h"
  10. using json = nlohmann::json;
  11. extern bool server_verbose;
  12. extern bool server_log_json;
  13. #ifndef SERVER_VERBOSE
  14. #define SERVER_VERBOSE 1
  15. #endif
  16. #if SERVER_VERBOSE != 1
  17. #define LOG_VERBOSE(MSG, ...)
  18. #else
  19. #define LOG_VERBOSE(MSG, ...) \
  20. do \
  21. { \
  22. if (server_verbose) \
  23. { \
  24. server_log("VERB", __func__, __LINE__, MSG, __VA_ARGS__); \
  25. } \
  26. } while (0)
  27. #endif
  28. #define LOG_ERROR( MSG, ...) server_log("ERR", __func__, __LINE__, MSG, __VA_ARGS__)
  29. #define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__)
  30. #define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
  31. //
  32. // parallel
  33. //
  34. enum server_state {
  35. SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
  36. SERVER_STATE_READY, // Server is ready and model is loaded
  37. SERVER_STATE_ERROR // An error occurred, load_model failed
  38. };
  39. enum task_type {
  40. TASK_TYPE_COMPLETION,
  41. TASK_TYPE_CANCEL,
  42. TASK_TYPE_NEXT_RESPONSE,
  43. TASK_TYPE_METRICS
  44. };
  45. struct task_server {
  46. int id = -1; // to be filled by llama_server_queue
  47. int target_id;
  48. task_type type;
  49. json data;
  50. bool infill_mode = false;
  51. bool embedding_mode = false;
  52. int multitask_id = -1;
  53. };
  54. struct task_result {
  55. int id;
  56. int multitask_id = -1;
  57. bool stop;
  58. bool error;
  59. json result_json;
  60. };
  61. struct task_multi {
  62. int id;
  63. std::set<int> subtasks_remaining{};
  64. std::vector<task_result> results{};
  65. };
  66. // TODO: can become bool if we can't find use of more states
  67. enum slot_state
  68. {
  69. IDLE,
  70. PROCESSING,
  71. };
  72. enum slot_command
  73. {
  74. NONE,
  75. LOAD_PROMPT,
  76. RELEASE,
  77. };
  78. struct slot_params
  79. {
  80. bool stream = true;
  81. bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
  82. uint32_t seed = -1; // RNG seed
  83. int32_t n_keep = 0; // number of tokens to keep from initial prompt
  84. int32_t n_predict = -1; // new tokens to predict
  85. std::vector<std::string> antiprompt;
  86. json input_prefix;
  87. json input_suffix;
  88. };
  89. struct slot_image
  90. {
  91. int32_t id;
  92. bool request_encode_image = false;
  93. float * image_embedding = nullptr;
  94. int32_t image_tokens = 0;
  95. clip_image_u8 * img_data;
  96. std::string prefix_prompt; // before of this image
  97. };
  98. // completion token output with probabilities
  99. struct completion_token_output
  100. {
  101. struct token_prob
  102. {
  103. llama_token tok;
  104. float prob;
  105. };
  106. std::vector<token_prob> probs;
  107. llama_token tok;
  108. std::string text_to_send;
  109. };
  110. static inline void server_log(const char *level, const char *function, int line, const char *message, const nlohmann::ordered_json &extra)
  111. {
  112. std::stringstream ss_tid;
  113. ss_tid << std::this_thread::get_id();
  114. json log = nlohmann::ordered_json{
  115. {"tid", ss_tid.str()},
  116. {"timestamp", time(nullptr)},
  117. };
  118. if (server_log_json) {
  119. log.merge_patch(
  120. {
  121. {"level", level},
  122. {"function", function},
  123. {"line", line},
  124. {"msg", message},
  125. });
  126. if (!extra.empty()) {
  127. log.merge_patch(extra);
  128. }
  129. std::cout << log.dump(-1, ' ', false, json::error_handler_t::replace) << "\n" << std::flush;
  130. } else {
  131. char buf[1024];
  132. snprintf(buf, 1024, "%4s [%24s] %s", level, function, message);
  133. if (!extra.empty()) {
  134. log.merge_patch(extra);
  135. }
  136. std::stringstream ss;
  137. ss << buf << " |";
  138. for (const auto& el : log.items())
  139. {
  140. const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace);
  141. snprintf(buf, 1024, " %s=%s", el.key().c_str(), value.c_str());
  142. ss << buf;
  143. }
  144. const std::string str = ss.str();
  145. printf("%.*s\n", (int)str.size(), str.data());
  146. fflush(stdout);
  147. }
  148. }
  149. //
  150. // server utils
  151. //
  152. template <typename T>
  153. static T json_value(const json &body, const std::string &key, const T &default_value)
  154. {
  155. // Fallback null to default value
  156. return body.contains(key) && !body.at(key).is_null()
  157. ? body.value(key, default_value)
  158. : default_value;
  159. }
  160. // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
  161. inline bool verify_custom_template(const std::string & tmpl) {
  162. llama_chat_message chat[] = {{"user", "test"}};
  163. std::vector<char> buf(1);
  164. int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, buf.data(), buf.size());
  165. return res >= 0;
  166. }
  167. // Format given chat. If tmpl is empty, we take the template from model metadata
  168. inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages)
  169. {
  170. size_t alloc_size = 0;
  171. // vector holding all allocated string to be passed to llama_chat_apply_template
  172. std::vector<std::string> str(messages.size() * 2);
  173. std::vector<llama_chat_message> chat(messages.size());
  174. for (size_t i = 0; i < messages.size(); ++i) {
  175. auto &curr_msg = messages[i];
  176. str[i*2 + 0] = json_value(curr_msg, "role", std::string(""));
  177. str[i*2 + 1] = json_value(curr_msg, "content", std::string(""));
  178. alloc_size += str[i*2 + 1].length();
  179. chat[i].role = str[i*2 + 0].c_str();
  180. chat[i].content = str[i*2 + 1].c_str();
  181. }
  182. const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
  183. std::vector<char> buf(alloc_size * 2);
  184. // run the first time to get the total output length
  185. int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
  186. // if it turns out that our buffer is too small, we resize it
  187. if ((size_t) res > buf.size()) {
  188. buf.resize(res);
  189. res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
  190. }
  191. std::string formatted_chat(buf.data(), res);
  192. LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
  193. return formatted_chat;
  194. }
  195. //
  196. // work queue utils
  197. //
  198. struct llama_server_queue {
  199. int id = 0;
  200. std::mutex mutex_tasks;
  201. bool running;
  202. // queues
  203. std::vector<task_server> queue_tasks;
  204. std::vector<task_server> queue_tasks_deferred;
  205. std::vector<task_multi> queue_multitasks;
  206. std::condition_variable condition_tasks;
  207. // callback functions
  208. std::function<void(task_server&)> callback_new_task;
  209. std::function<void(task_multi&)> callback_finish_multitask;
  210. std::function<void(void)> callback_all_task_finished;
  211. // Add a new task to the end of the queue
  212. int post(task_server task) {
  213. std::unique_lock<std::mutex> lock(mutex_tasks);
  214. if (task.id == -1) {
  215. task.id = id++;
  216. LOG_VERBOSE("new task id", {{"new_id", task.id}});
  217. }
  218. queue_tasks.push_back(std::move(task));
  219. condition_tasks.notify_one();
  220. return task.id;
  221. }
  222. // Add a new task, but defer until one slot is available
  223. void defer(task_server task) {
  224. std::unique_lock<std::mutex> lock(mutex_tasks);
  225. queue_tasks_deferred.push_back(std::move(task));
  226. }
  227. // Get the next id for creating anew task
  228. int get_new_id() {
  229. std::unique_lock<std::mutex> lock(mutex_tasks);
  230. int new_id = id++;
  231. LOG_VERBOSE("new task id", {{"new_id", new_id}});
  232. return new_id;
  233. }
  234. // Register function to process a new task
  235. void on_new_task(std::function<void(task_server&)> callback) {
  236. callback_new_task = callback;
  237. }
  238. // Register function to process a multitask
  239. void on_finish_multitask(std::function<void(task_multi&)> callback) {
  240. callback_finish_multitask = callback;
  241. }
  242. // Register the function to be called when the batch of tasks is finished
  243. void on_all_tasks_finished(std::function<void(void)> callback) {
  244. callback_all_task_finished = callback;
  245. }
  246. // Call when the state of one slot is changed
  247. void notify_slot_changed() {
  248. // move deferred tasks back to main loop
  249. std::unique_lock<std::mutex> lock(mutex_tasks);
  250. for (auto & task : queue_tasks_deferred) {
  251. queue_tasks.push_back(std::move(task));
  252. }
  253. queue_tasks_deferred.clear();
  254. }
  255. // end the start_loop routine
  256. void terminate() {
  257. {
  258. std::unique_lock<std::mutex> lock(mutex_tasks);
  259. running = false;
  260. }
  261. condition_tasks.notify_all();
  262. }
  263. // Start the main loop.
  264. void start_loop() {
  265. running = true;
  266. while (true) {
  267. LOG_VERBOSE("new task may arrive", {});
  268. {
  269. while (true)
  270. {
  271. std::unique_lock<std::mutex> lock(mutex_tasks);
  272. if (queue_tasks.empty()) {
  273. lock.unlock();
  274. break;
  275. }
  276. task_server task = queue_tasks.front();
  277. queue_tasks.erase(queue_tasks.begin());
  278. lock.unlock();
  279. LOG_VERBOSE("callback_new_task", {{"task_id", task.id}});
  280. callback_new_task(task);
  281. }
  282. LOG_VERBOSE("callback_all_task_finished", {});
  283. // process and update all the multitasks
  284. auto queue_iterator = queue_multitasks.begin();
  285. while (queue_iterator != queue_multitasks.end())
  286. {
  287. if (queue_iterator->subtasks_remaining.empty())
  288. {
  289. // all subtasks done == multitask is done
  290. task_multi current_multitask = *queue_iterator;
  291. callback_finish_multitask(current_multitask);
  292. // remove this multitask
  293. queue_iterator = queue_multitasks.erase(queue_iterator);
  294. }
  295. else
  296. {
  297. ++queue_iterator;
  298. }
  299. }
  300. // all tasks in the current loop is finished
  301. callback_all_task_finished();
  302. }
  303. LOG_VERBOSE("wait for new task", {});
  304. // wait for new task
  305. {
  306. std::unique_lock<std::mutex> lock(mutex_tasks);
  307. if (queue_tasks.empty()) {
  308. if (!running) {
  309. LOG_VERBOSE("ending start_loop", {});
  310. return;
  311. }
  312. condition_tasks.wait(lock, [&]{
  313. return (!queue_tasks.empty() || !running);
  314. });
  315. }
  316. }
  317. }
  318. }
  319. //
  320. // functions to manage multitasks
  321. //
  322. // add a multitask by specifying the id of all subtask (subtask is a task_server)
  323. void add_multitask(int multitask_id, std::vector<int>& sub_ids)
  324. {
  325. std::lock_guard<std::mutex> lock(mutex_tasks);
  326. task_multi multi;
  327. multi.id = multitask_id;
  328. std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
  329. queue_multitasks.push_back(multi);
  330. }
  331. // updatethe remaining subtasks, while appending results to multitask
  332. void update_multitask(int multitask_id, int subtask_id, task_result& result)
  333. {
  334. std::lock_guard<std::mutex> lock(mutex_tasks);
  335. for (auto& multitask : queue_multitasks)
  336. {
  337. if (multitask.id == multitask_id)
  338. {
  339. multitask.subtasks_remaining.erase(subtask_id);
  340. multitask.results.push_back(result);
  341. }
  342. }
  343. }
  344. };
  345. struct llama_server_response {
  346. typedef std::function<void(int, int, task_result&)> callback_multitask_t;
  347. callback_multitask_t callback_update_multitask;
  348. // for keeping track of all tasks waiting for the result
  349. std::set<int> waiting_task_ids;
  350. // the main result queue
  351. std::vector<task_result> queue_results;
  352. std::mutex mutex_results;
  353. std::condition_variable condition_results;
  354. void add_waiting_task_id(int task_id) {
  355. LOG_VERBOSE("waiting for task id", {{"task_id", task_id}});
  356. std::unique_lock<std::mutex> lock(mutex_results);
  357. waiting_task_ids.insert(task_id);
  358. }
  359. void remove_waiting_task_id(int task_id) {
  360. LOG_VERBOSE("remove waiting for task id", {{"task_id", task_id}});
  361. std::unique_lock<std::mutex> lock(mutex_results);
  362. waiting_task_ids.erase(task_id);
  363. }
  364. // This function blocks the thread until there is a response for this task_id
  365. task_result recv(int task_id) {
  366. while (true)
  367. {
  368. std::unique_lock<std::mutex> lock(mutex_results);
  369. condition_results.wait(lock, [&]{
  370. return !queue_results.empty();
  371. });
  372. for (int i = 0; i < (int) queue_results.size(); i++)
  373. {
  374. if (queue_results[i].id == task_id)
  375. {
  376. assert(queue_results[i].multitask_id == -1);
  377. task_result res = queue_results[i];
  378. queue_results.erase(queue_results.begin() + i);
  379. return res;
  380. }
  381. }
  382. }
  383. // should never reach here
  384. }
  385. // Register the function to update multitask
  386. void on_multitask_update(callback_multitask_t callback) {
  387. callback_update_multitask = callback;
  388. }
  389. // Send a new result to a waiting task_id
  390. void send(task_result result) {
  391. std::unique_lock<std::mutex> lock(mutex_results);
  392. LOG_VERBOSE("send new result", {{"task_id", result.id}});
  393. for (auto& task_id : waiting_task_ids) {
  394. // LOG_TEE("waiting task id %i \n", task_id);
  395. // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
  396. if (result.multitask_id == task_id)
  397. {
  398. LOG_VERBOSE("callback_update_multitask", {{"task_id", task_id}});
  399. callback_update_multitask(task_id, result.id, result);
  400. continue;
  401. }
  402. if (result.id == task_id)
  403. {
  404. LOG_VERBOSE("queue_results.push_back", {{"task_id", task_id}});
  405. queue_results.push_back(result);
  406. condition_results.notify_all();
  407. return;
  408. }
  409. }
  410. }
  411. };
  412. //
  413. // base64 utils (TODO: move to common in the future)
  414. //
  415. static const std::string base64_chars =
  416. "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
  417. "abcdefghijklmnopqrstuvwxyz"
  418. "0123456789+/";
  419. static inline bool is_base64(uint8_t c)
  420. {
  421. return (isalnum(c) || (c == '+') || (c == '/'));
  422. }
  423. static inline std::vector<uint8_t> base64_decode(const std::string & encoded_string)
  424. {
  425. int i = 0;
  426. int j = 0;
  427. int in_ = 0;
  428. int in_len = encoded_string.size();
  429. uint8_t char_array_4[4];
  430. uint8_t char_array_3[3];
  431. std::vector<uint8_t> ret;
  432. while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_]))
  433. {
  434. char_array_4[i++] = encoded_string[in_]; in_++;
  435. if (i == 4)
  436. {
  437. for (i = 0; i <4; i++)
  438. {
  439. char_array_4[i] = base64_chars.find(char_array_4[i]);
  440. }
  441. char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
  442. char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
  443. char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
  444. for (i = 0; (i < 3); i++)
  445. {
  446. ret.push_back(char_array_3[i]);
  447. }
  448. i = 0;
  449. }
  450. }
  451. if (i)
  452. {
  453. for (j = i; j <4; j++)
  454. {
  455. char_array_4[j] = 0;
  456. }
  457. for (j = 0; j <4; j++)
  458. {
  459. char_array_4[j] = base64_chars.find(char_array_4[j]);
  460. }
  461. char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
  462. char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
  463. char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
  464. for (j = 0; (j < i - 1); j++)
  465. {
  466. ret.push_back(char_array_3[j]);
  467. }
  468. }
  469. return ret;
  470. }
  471. //
  472. // random string / id
  473. //
  474. static std::string random_string()
  475. {
  476. static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
  477. std::random_device rd;
  478. std::mt19937 generator(rd());
  479. std::string result(32, ' ');
  480. for (int i = 0; i < 32; ++i) {
  481. result[i] = str[generator() % str.size()];
  482. }
  483. return result;
  484. }
  485. static std::string gen_chatcmplid()
  486. {
  487. std::stringstream chatcmplid;
  488. chatcmplid << "chatcmpl-" << random_string();
  489. return chatcmplid.str();
  490. }