server-task.h
#pragma once

#include "common.h"
#include "llama.h"

#include <string>
#include <unordered_set>
#include <list>

// TODO: prevent including the whole server-common.h as we only use server_tokens
#include "server-common.h"

using json = nlohmann::ordered_json;

enum server_task_type {
    SERVER_TASK_TYPE_COMPLETION,
    SERVER_TASK_TYPE_EMBEDDING,
    SERVER_TASK_TYPE_RERANK,
    SERVER_TASK_TYPE_INFILL,
    SERVER_TASK_TYPE_CANCEL,
    SERVER_TASK_TYPE_NEXT_RESPONSE,
    SERVER_TASK_TYPE_METRICS,
    SERVER_TASK_TYPE_SLOT_SAVE,
    SERVER_TASK_TYPE_SLOT_RESTORE,
    SERVER_TASK_TYPE_SLOT_ERASE,
    SERVER_TASK_TYPE_SET_LORA,
};

// TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common
enum oaicompat_type {
    OAICOMPAT_TYPE_NONE,
    OAICOMPAT_TYPE_CHAT,
    OAICOMPAT_TYPE_COMPLETION,
    OAICOMPAT_TYPE_EMBEDDING,
};

enum stop_type {
    STOP_TYPE_NONE,
    STOP_TYPE_EOS,
    STOP_TYPE_WORD,
    STOP_TYPE_LIMIT,
};

struct task_params {
    bool stream          = true;
    bool include_usage   = false;
    bool cache_prompt    = true;  // remember the prompt to avoid reprocessing it in full
    bool return_tokens   = false;
    bool return_progress = false;

    int32_t n_keep    =  0; // number of tokens to keep from the initial prompt
    int32_t n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
    int32_t n_predict = -1; // new tokens to predict
    int32_t n_indent  =  0; // minimum line indentation for the generated text in number of whitespace characters

    int64_t t_max_prompt_ms  = -1; // TODO: implement
    int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

    std::vector<common_adapter_lora_info> lora;

    std::vector<std::string> antiprompt;
    std::vector<std::string> response_fields;

    bool timings_per_token   = false;
    bool post_sampling_probs = false;

    struct common_params_sampling    sampling;
    struct common_params_speculative speculative;

    // OAI-compat fields
    bool verbose = false;
    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
    common_chat_syntax oaicompat_chat_syntax;

    // Embeddings
    int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)

    json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) const;
    json to_json(bool only_metrics = false) const;
};

struct server_task {
    int id    = -1; // to be filled by server_queue
    int index = -1; // used when there are multiple prompts (batch request)

    // used by SERVER_TASK_TYPE_CANCEL
    int id_target = -1;
    int id_slot   = -1;

    // used by SERVER_TASK_TYPE_INFERENCE
    task_params   params;
    server_tokens tokens;

    server_task_type type;

    // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
    struct slot_action {
        int slot_id;
        std::string filename;
        std::string filepath;
    };
    slot_action slot_action;

    // used by SERVER_TASK_TYPE_METRICS
    bool metrics_reset_bucket = false;

    // used by SERVER_TASK_TYPE_SET_LORA
    std::vector<common_adapter_lora_info> set_lora;

    server_task() = default;
    server_task(server_task_type type) : type(type) {}

    int32_t n_tokens() const {
        return tokens.size();
    }

    static task_params params_from_json_cmpl(
            const llama_context * ctx,
            const common_params & params_base,
            const json & data);

    // utility function
    static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
        std::unordered_set<int> ids(tasks.size());
        for (size_t i = 0; i < tasks.size(); i++) {
            ids.insert(tasks[i].id);
        }
        return ids;
    }
};
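
// Illustrative sketch (not part of this header's API contract): building a completion task
// from a request body and collecting the ids of a batch, assuming `ctx`, `params_base` and
// `data` are provided by the caller:
//
//   server_task task(SERVER_TASK_TYPE_COMPLETION);
//   task.params = server_task::params_from_json_cmpl(ctx, params_base, data);
//   std::vector<server_task> tasks = { std::move(task) };
//   std::unordered_set<int> ids = server_task::get_list_id(tasks);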

struct result_timings {
    int32_t cache_n  = -1;

    int32_t prompt_n = -1;
    double prompt_ms;
    double prompt_per_token_ms;
    double prompt_per_second;

    int32_t predicted_n = -1;
    double predicted_ms;
    double predicted_per_token_ms;
    double predicted_per_second;

    // Optional speculative metrics - only included when > 0
    int32_t draft_n          = 0;
    int32_t draft_n_accepted = 0;

    json to_json() const;
};
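
// Note: the *_per_token_ms and *_per_second fields are presumably derived from the
// corresponding *_n / *_ms pairs (e.g. predicted_per_token_ms = predicted_ms / predicted_n);
// this header only declares the storage, the computation lives in the implementation.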

struct result_prompt_progress {
    int32_t total     = 0;
    int32_t cache     = 0;
    int32_t processed = 0;
    int64_t time_ms   = 0;

    json to_json() const;
};

struct server_task_result {
    int id      = -1;
    int id_slot = -1;

    virtual bool is_error() {
        // only used by server_task_result_error
        return false;
    }

    virtual bool is_stop() {
        // only used by server_task_result_cmpl_*
        return true;
    }

    virtual int get_index() {
        return -1;
    }

    virtual json to_json() = 0;

    virtual ~server_task_result() = default;
};

// using unique_ptr for polymorphism of server_task_result
using server_task_result_ptr = std::unique_ptr<server_task_result>;
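
// Illustrative sketch of the intended polymorphic use (the actual dispatch lives in the
// implementation files; the caller constructs whichever concrete result type applies):
//
//   server_task_result_ptr res = std::make_unique<server_task_result_error>();
//   if (res->is_error()) {
//       json err = res->to_json();
//   }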

struct completion_token_output {
    llama_token tok;
    float       prob;
    std::string text_to_send;

    struct prob_info {
        llama_token tok;
        std::string txt;
        float       prob;
    };
    std::vector<prob_info> probs;

    json to_json(bool post_sampling_probs) const;

    static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs);
    static float logarithm(float x);
    static std::vector<unsigned char> str_to_bytes(const std::string & str);
};

struct server_task_result_cmpl_final : server_task_result {
    int index = 0;

    std::string  content;
    llama_tokens tokens;

    bool stream;
    bool include_usage;
    result_timings timings;
    std::string prompt;

    bool truncated;
    int32_t n_decoded;
    int32_t n_prompt_tokens;
    int32_t n_tokens_cached;
    bool has_new_line;
    std::string stopping_word;
    stop_type stop = STOP_TYPE_NONE;

    bool post_sampling_probs;
    std::vector<completion_token_output> probs_output;
    std::vector<std::string> response_fields;

    task_params generation_params;

    // OAI-compat fields
    bool verbose = false;
    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
    common_chat_msg oaicompat_msg;
    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;

    virtual int get_index() override {
        return index;
    }

    virtual bool is_stop() override {
        return true; // in stream mode, final responses are considered stop
    }

    virtual json to_json() override;

    json to_json_non_oaicompat();
    json to_json_oaicompat();
    json to_json_oaicompat_chat();
    json to_json_oaicompat_chat_stream();
};

struct server_task_result_cmpl_partial : server_task_result {
    int index = 0;

    std::string  content;
    llama_tokens tokens;

    int32_t n_decoded;
    int32_t n_prompt_tokens;

    bool post_sampling_probs;
    bool is_progress = false;
    completion_token_output prob_output;
    result_timings timings;
    result_prompt_progress progress;

    // OAI-compat fields
    bool verbose = false;
    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;

    virtual int get_index() override {
        return index;
    }

    virtual bool is_stop() override {
        return false; // in stream mode, partial responses are not considered stop
    }

    virtual json to_json() override;

    json to_json_non_oaicompat();
    json to_json_oaicompat();
    json to_json_oaicompat_chat();
};

struct server_task_result_embd : server_task_result {
    int index = 0;
    std::vector<std::vector<float>> embedding;

    int32_t n_tokens;

    // OAI-compat fields
    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;

    virtual int get_index() override {
        return index;
    }

    virtual json to_json() override;

    json to_json_non_oaicompat();
    json to_json_oaicompat();
};
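
// Note: `embedding` is a vector of vectors rather than a single vector, presumably so that a
// request with pooling disabled can return one embedding per token while a pooled request
// returns a single entry; this is an assumption, the exact behavior is defined in the implementation.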

struct server_task_result_rerank : server_task_result {
    int index = 0;
    float score = -1e6;

    int32_t n_tokens;

    virtual int get_index() override {
        return index;
    }

    virtual json to_json() override;
};

struct server_task_result_error : server_task_result {
    int index = 0;
    error_type err_type = ERROR_TYPE_SERVER;
    std::string err_msg;

    // for ERROR_TYPE_EXCEED_CONTEXT_SIZE
    int32_t n_prompt_tokens = 0;
    int32_t n_ctx = 0;

    virtual bool is_error() override {
        return true;
    }

    virtual json to_json() override;
};

struct server_task_result_metrics : server_task_result {
    int n_idle_slots;
    int n_processing_slots;
    int n_tasks_deferred;
    int64_t t_start;

    // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
    uint64_t n_prompt_tokens_processed_total = 0;
    uint64_t t_prompt_processing_total       = 0;
    uint64_t n_tokens_predicted_total        = 0;
    uint64_t t_tokens_generation_total       = 0;

    uint64_t n_tokens_max = 0;

    uint64_t n_prompt_tokens_processed = 0;
    uint64_t t_prompt_processing       = 0;

    uint64_t n_tokens_predicted  = 0;
    uint64_t t_tokens_generation = 0;

    uint64_t n_decode_total     = 0;
    uint64_t n_busy_slots_total = 0;

    // while we can also use std::vector<server_slot> this requires copying the slot object which can be quite messy
    // therefore, we use json to temporarily store the slot.to_json() result
    json slots_data = json::array();

    virtual json to_json() override;
};

struct server_task_result_slot_save_load : server_task_result {
    std::string filename;
    bool is_save; // true = save, false = load

    size_t n_tokens;
    size_t n_bytes;
    double t_ms;

    virtual json to_json() override;
};

struct server_task_result_slot_erase : server_task_result {
    size_t n_erased;

    virtual json to_json() override;
};

struct server_task_result_apply_lora : server_task_result {
    virtual json to_json() override;
};

struct server_prompt_checkpoint {
    llama_pos pos_min;
    llama_pos pos_max;

    std::vector<uint8_t> data;

    size_t size() const {
        return data.size();
    }
};

struct server_prompt {
    server_tokens tokens;

    std::vector<uint8_t> data;

    std::list<server_prompt_checkpoint> checkpoints;

    size_t size() const {
        size_t res = data.size();
        for (const auto & checkpoint : checkpoints) {
            res += checkpoint.size();
        }
        return res;
    }

    int n_tokens() const {
        return tokens.size();
    }
};

struct server_prompt_cache {
    server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) {
        this->limit_size   = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib);
        this->limit_tokens = limit_tokens;
    }
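
    // For example, server_prompt_cache(64, 0) caps the cache at 64 MiB of state data with no
    // token limit, while server_prompt_cache(-1, 0) (or a size of 0) disables the size limit.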

    std::list<server_prompt> states;

    // in bytes, 0 = no limit
    size_t limit_size = 0;

    // in tokens, 0 = no limit
    size_t limit_tokens = 0;

    size_t size() const;
    size_t n_tokens() const;

    server_prompt * alloc(const server_prompt & prompt, size_t state_size);

    bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot);

    void update();
};
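
// Illustrative call sequence (an assumption based only on the declarations above; identifiers
// such as `slot_prompt`, `state_size`, `tokens_new`, `ctx` and `id_slot` are placeholders and
// the authoritative flow is in the implementation):
//
//   server_prompt_cache cache(/*limit_size_mib=*/256, /*limit_tokens=*/0);
//   // when a slot is released: stash its prompt state
//   cache.alloc(slot_prompt, state_size);
//   // when a new request arrives: try to restore a cached state matching the new tokens
//   cache.load(slot_prompt, tokens_new, ctx, id_slot);
//   // afterwards: enforce the size/token limits
//   cache.update();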