server-task.h

#pragma once

#include "common.h"
#include "llama.h"

#include <string>
#include <unordered_set>
#include <list>

// TODO: prevent including the whole server-common.h as we only use server_tokens
#include "server-common.h"

using json = nlohmann::ordered_json;

enum server_task_type {
    SERVER_TASK_TYPE_COMPLETION,
    SERVER_TASK_TYPE_EMBEDDING,
    SERVER_TASK_TYPE_RERANK,
    SERVER_TASK_TYPE_INFILL,
    SERVER_TASK_TYPE_CANCEL,
    SERVER_TASK_TYPE_NEXT_RESPONSE,
    SERVER_TASK_TYPE_METRICS,
    SERVER_TASK_TYPE_SLOT_SAVE,
    SERVER_TASK_TYPE_SLOT_RESTORE,
    SERVER_TASK_TYPE_SLOT_ERASE,
    SERVER_TASK_TYPE_SET_LORA,
};

// TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common
enum task_response_type {
    TASK_RESPONSE_TYPE_NONE, // llama.cpp native format
    TASK_RESPONSE_TYPE_OAI_CHAT,
    TASK_RESPONSE_TYPE_OAI_CMPL,
    TASK_RESPONSE_TYPE_OAI_EMBD,
    TASK_RESPONSE_TYPE_ANTHROPIC,
};

enum stop_type {
    STOP_TYPE_NONE,
    STOP_TYPE_EOS,
    STOP_TYPE_WORD,
    STOP_TYPE_LIMIT,
};

struct task_params {
    bool stream          = true;
    bool include_usage   = false;
    bool cache_prompt    = true;  // remember the prompt to avoid reprocessing the entire prompt
    bool return_tokens   = false;
    bool return_progress = false;

    int32_t n_keep        =  0; // number of tokens to keep from the initial prompt
    int32_t n_discard     =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
    int32_t n_predict     = -1; // new tokens to predict
    int32_t n_indent      =  0; // minimum line indentation for the generated text in number of whitespace characters
    int32_t n_cmpl        =  1; // number of completions to generate from this prompt
    int32_t n_cache_reuse =  0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)

    int64_t t_max_prompt_ms  = -1; // TODO: implement
    int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

    std::vector<common_adapter_lora_info> lora;
    std::vector<std::string> antiprompt;
    std::vector<std::string> response_fields;

    bool timings_per_token   = false;
    bool post_sampling_probs = false;

    struct common_params_sampling    sampling;
    struct common_params_speculative speculative;

    // response formatting
    bool verbose = false;
    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
    common_chat_syntax oaicompat_chat_syntax;

    // embeddings
    int32_t embd_normalize = 2; // (-1 = none, 0 = max absolute int16, 1 = taxicab, 2 = Euclidean/L2, >2 = p-norm)

    json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) const;
    json to_json(bool only_metrics = false) const;
};
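
// Illustrative sketch (not part of the header): how a caller might fill task_params
// and serialize it; the exact JSON layout is defined by to_json() in the implementation file.
//
//     task_params params;
//     params.stream    = false;
//     params.n_predict = 128;
//     params.res_type  = TASK_RESPONSE_TYPE_OAI_CHAT;
//
//     json full    = params.to_json();                      // full parameter dump
//     json metrics = params.to_json(/*only_metrics=*/true); // reduced dump
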
// struct for tracking the state of a task (e.g., for streaming)
struct task_result_state {
    // tracking diffs for partial tool calls
    std::vector<common_chat_msg_diff> diffs;

    common_chat_syntax oaicompat_chat_syntax;
    common_chat_msg    chat_msg;

    std::string generated_text; // append new chunks of generated text here
    std::vector<std::string> generated_tool_call_ids;

    task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
        : oaicompat_chat_syntax(oaicompat_chat_syntax) {}

    // parse partial tool calls and update the internal state
    common_chat_msg update_chat_msg(
            const std::string & text_added,
            bool is_partial,
            std::vector<common_chat_msg_diff> & diffs);
};
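
// Streaming sketch (illustrative only): the caller keeps one task_result_state per
// request and pushes each new chunk of generated text through update_chat_msg(),
// which accumulates the text and reports incremental tool-call diffs.
//
//     task_result_state state(params.oaicompat_chat_syntax);
//     std::vector<common_chat_msg_diff> diffs;
//
//     common_chat_msg msg = state.update_chat_msg("partial text", /*is_partial=*/true, diffs);
//     // ... repeat for each chunk; on the final chunk:
//     msg = state.update_chat_msg("", /*is_partial=*/false, diffs);
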
struct server_task {
    int id    = -1; // to be filled by server_queue
    int index = -1; // used when there are multiple prompts (batch request)

    // used by SERVER_TASK_TYPE_CANCEL
    int id_target = -1;
    int id_slot   = -1;

    // used by parallel sampling (multiple completions from the same prompt)
    size_t n_children = 0; // number of tasks reusing this prompt
    int id_parent = -1;

    // used by SERVER_TASK_TYPE_INFERENCE
    task_params params;
    server_tokens tokens;

    // only used by the CLI, this delegates the tokenization to the server
    json cli_input = nullptr;
    std::vector<raw_buffer> cli_files;

    server_task_type type;

    // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
    struct slot_action {
        int slot_id;
        std::string filename;
        std::string filepath;
    };
    slot_action slot_action;

    // used by SERVER_TASK_TYPE_METRICS
    bool metrics_reset_bucket = false;

    // used by SERVER_TASK_TYPE_SET_LORA
    std::vector<common_adapter_lora_info> set_lora;

    server_task() = default;
    server_task(server_task_type type) : type(type) {}

    int32_t n_tokens() const {
        return tokens.size();
    }

    static task_params params_from_json_cmpl(
            const llama_context * ctx,
            const common_params & params_base,
            const json & data);

    // utility function
    static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
        std::unordered_set<int> ids(tasks.size());
        for (size_t i = 0; i < tasks.size(); i++) {
            ids.insert(tasks[i].id);
        }
        return ids;
    }

    server_task create_child(int id_parent, int id_child, int idx) const {
        server_task copy;
        copy.id        = id_child;
        copy.index     = idx;
        copy.id_parent = id_parent;
        copy.params    = params;
        copy.type      = type;
        copy.tokens    = tokens.clone();
        return copy;
    }

    // the task will be moved into the queue, then onto slots;
    // however, the state must be kept by the caller (e.g., the HTTP thread)
    task_result_state create_state() const {
        return task_result_state(params.oaicompat_chat_syntax);
    }
};
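
// Parallel-sampling sketch (illustrative, ids are made up): when several completions
// are requested from the same prompt, the parent task keeps the prompt and each extra
// completion is created via create_child(), which copies params and clones the tokens.
//
//     server_task parent(SERVER_TASK_TYPE_COMPLETION);
//     parent.id         = 100;
//     parent.n_children = 1;
//
//     server_task child = parent.create_child(/*id_parent=*/100, /*id_child=*/101, /*idx=*/1);
//
//     // the state stays with the caller (e.g., the HTTP thread) while the task is queued
//     task_result_state state = parent.create_state();
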
struct result_timings {
    int32_t cache_n  = -1;
    int32_t prompt_n = -1;
    double  prompt_ms;
    double  prompt_per_token_ms;
    double  prompt_per_second;

    int32_t predicted_n = -1;
    double  predicted_ms;
    double  predicted_per_token_ms;
    double  predicted_per_second;

    // Optional speculative metrics - only included when > 0
    int32_t draft_n          = 0;
    int32_t draft_n_accepted = 0;

    json to_json() const;
};

struct result_prompt_progress {
    int32_t total     = 0;
    int32_t cache     = 0;
    int32_t processed = 0;
    int64_t time_ms   = 0;

    json to_json() const;
};

struct server_task_result {
    int id      = -1;
    int id_slot = -1;

    virtual bool is_error() {
        // only used by server_task_result_error
        return false;
    }
    virtual bool is_stop() {
        // only used by server_task_result_cmpl_*
        return true;
    }
    virtual int get_index() {
        return -1;
    }
    virtual void update(task_result_state &) {
        // only used by server_task_result_cmpl_*
    }
    virtual json to_json() = 0;
    virtual ~server_task_result() = default;
};

// using unique_ptr for polymorphism of server_task_result
using server_task_result_ptr = std::unique_ptr<server_task_result>;
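
// Result-handling sketch (illustrative only): results arrive as owning pointers to the
// base class; callers dispatch on the virtual is_error()/is_stop() hooks and serialize
// through to_json() without needing the concrete result type.
//
//     server_task_result_ptr result = /* received from the server's result queue */;
//     if (result->is_error()) {
//         json err = result->to_json();   // error payload
//     } else if (result->is_stop()) {
//         // final chunk of a completion (or a non-streaming response)
//     } else {
//         // partial chunk while streaming
//     }
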
struct completion_token_output {
    llama_token tok;
    float       prob;
    std::string text_to_send;

    struct prob_info {
        llama_token tok;
        std::string txt;
        float       prob;
    };
    std::vector<prob_info> probs;

    json to_json(bool post_sampling_probs) const;

    static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs);
    static float logarithm(float x);
    static std::vector<unsigned char> str_to_bytes(const std::string & str);
};

struct server_task_result_cmpl_final : server_task_result {
    int index = 0;

    std::string  content;
    llama_tokens tokens;

    bool stream;
    bool include_usage;
    result_timings timings;
    std::string prompt;

    bool truncated;
    int32_t n_decoded;
    int32_t n_prompt_tokens;
    int32_t n_tokens_cached;
    bool has_new_line;
    std::string stopping_word;
    stop_type stop = STOP_TYPE_NONE;

    bool post_sampling_probs;
    std::vector<completion_token_output> probs_output;
    std::vector<std::string> response_fields;

    task_params generation_params;

    // response formatting
    bool verbose = false;
    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
    common_chat_msg oaicompat_msg;                         // to be populated by update()
    std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()

    bool is_updated = false;

    virtual int get_index() override {
        return index;
    }
    virtual bool is_stop() override {
        return true; // in stream mode, final responses are considered stop
    }
    virtual json to_json() override;

    virtual void update(task_result_state & state) override {
        is_updated = true;
        oaicompat_msg = state.update_chat_msg(content, false, oaicompat_msg_diffs);
    }

    json to_json_non_oaicompat();
    json to_json_oaicompat();
    json to_json_oaicompat_chat();
    json to_json_oaicompat_chat_stream();
    json to_json_anthropic();
    json to_json_anthropic_stream();
};
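
// Final-result sketch (illustrative only): before serialization, the caller applies its
// task_result_state so the parsed chat message and tool-call diffs are populated.
//
//     // `res` is a server_task_result_cmpl_final filled in by the processing slot
//     res.update(state);        // sets is_updated, oaicompat_msg and oaicompat_msg_diffs
//     json out = res.to_json(); // output format selected by res.res_type
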
struct server_task_result_cmpl_partial : server_task_result {
    int index = 0;

    std::string  content;
    llama_tokens tokens;

    int32_t n_decoded;
    int32_t n_prompt_tokens;

    bool post_sampling_probs;
    bool is_progress = false;
    completion_token_output prob_output;
    result_timings timings;
    result_prompt_progress progress;

    // response formatting
    bool verbose = false;
    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
    std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()

    bool is_updated = false;

    virtual int get_index() override {
        return index;
    }
    virtual bool is_stop() override {
        return false; // in stream mode, partial responses are not considered stop
    }
    virtual json to_json() override;

    virtual void update(task_result_state & state) override {
        is_updated = true;
        state.update_chat_msg(content, true, oaicompat_msg_diffs);
    }

    json to_json_non_oaicompat();
    json to_json_oaicompat();
    json to_json_oaicompat_chat();
    json to_json_anthropic();
};

struct server_task_result_embd : server_task_result {
    int index = 0;
    std::vector<std::vector<float>> embedding;

    int32_t n_tokens;

    // response formatting
    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;

    virtual int get_index() override {
        return index;
    }
    virtual json to_json() override;

    json to_json_non_oaicompat();
    json to_json_oaicompat();
};

struct server_task_result_rerank : server_task_result {
    int index = 0;
    float score = -1e6;

    int32_t n_tokens;

    virtual int get_index() override {
        return index;
    }
    virtual json to_json() override;
};

struct server_task_result_error : server_task_result {
    int index = 0;
    error_type err_type = ERROR_TYPE_SERVER;
    std::string err_msg;

    // for ERROR_TYPE_EXCEED_CONTEXT_SIZE
    int32_t n_prompt_tokens = 0;
    int32_t n_ctx = 0;

    virtual bool is_error() override {
        return true;
    }
    virtual json to_json() override;
};

struct server_task_result_metrics : server_task_result {
    int n_idle_slots;
    int n_processing_slots;
    int n_tasks_deferred;
    int64_t t_start;

    // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
    uint64_t n_prompt_tokens_processed_total = 0;
    uint64_t t_prompt_processing_total       = 0;
    uint64_t n_tokens_predicted_total        = 0;
    uint64_t t_tokens_generation_total       = 0;

    uint64_t n_tokens_max = 0;

    uint64_t n_prompt_tokens_processed = 0;
    uint64_t t_prompt_processing       = 0;

    uint64_t n_tokens_predicted  = 0;
    uint64_t t_tokens_generation = 0;

    uint64_t n_decode_total     = 0;
    uint64_t n_busy_slots_total = 0;

    // while we could also use std::vector<server_slot>, this requires copying the slot object, which can be quite messy
    // therefore, we use json to temporarily store the slot.to_json() result
    json slots_data = json::array();

    virtual json to_json() override;
};

struct server_task_result_slot_save_load : server_task_result {
    std::string filename;
    bool is_save; // true = save, false = load

    size_t n_tokens;
    size_t n_bytes;
    double t_ms;

    virtual json to_json() override;
};

struct server_task_result_slot_erase : server_task_result {
    size_t n_erased;

    virtual json to_json() override;
};

struct server_task_result_apply_lora : server_task_result {
    virtual json to_json() override;
};

struct server_prompt_checkpoint {
    llama_pos pos_min;
    llama_pos pos_max;

    std::vector<uint8_t> data;

    size_t size() const {
        return data.size();
    }
};

struct server_prompt {
    server_tokens tokens;

    std::vector<uint8_t> data;

    std::list<server_prompt_checkpoint> checkpoints;

    size_t size() const {
        size_t res = data.size();
        for (const auto & checkpoint : checkpoints) {
            res += checkpoint.size();
        }
        return res;
    }

    int n_tokens() const {
        return tokens.size();
    }

    server_prompt clone() const {
        return server_prompt {
            tokens.clone(),
            data,
            checkpoints
        };
    }
};
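
// Size-accounting sketch (illustrative only): a cached prompt's footprint is its
// serialized state plus all of its checkpoints; clone() deep-clones the token list
// while copying the raw data and checkpoints.
//
//     server_prompt copy = prompt.clone();
//     size_t bytes  = copy.size();      // data.size() + sum of checkpoint sizes
//     int    tokens = copy.n_tokens();
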
struct server_prompt_cache {
    server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) {
        this->limit_size   = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib);
        this->limit_tokens = limit_tokens;
    }

    std::list<server_prompt> states;

    // in bytes, 0 = no limit
    size_t limit_size = 0;

    // in tokens, 0 = no limit
    size_t limit_tokens = 0;

    size_t size() const;
    size_t n_tokens() const;

    server_prompt * alloc(const server_prompt & prompt, size_t state_size);

    bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot);

    void update();
};
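
// Construction sketch (illustrative only, values are made up): the constructor takes the
// size limit in MiB and converts it to bytes; a negative MiB value is clamped to 0, and 0
// means "no limit" for both the size and the token budget.
//
//     server_prompt_cache cache(/*limit_size_mib=*/256, /*limit_tokens=*/0);
//     // cache.limit_size   == 256ull*1024*1024 bytes
//     // cache.limit_tokens == 0 (unlimited)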