server-task.h
#pragma once

#include "common.h"
#include "llama.h"

#include <string>
#include <unordered_set>
#include <list>
#include <map>

// TODO: prevent including the whole server-common.h as we only use server_tokens
#include "server-common.h"

using json = nlohmann::ordered_json;

enum server_task_type {
    SERVER_TASK_TYPE_COMPLETION,
    SERVER_TASK_TYPE_EMBEDDING,
    SERVER_TASK_TYPE_RERANK,
    SERVER_TASK_TYPE_INFILL,
    SERVER_TASK_TYPE_CANCEL,
    SERVER_TASK_TYPE_NEXT_RESPONSE,
    SERVER_TASK_TYPE_METRICS,
    SERVER_TASK_TYPE_SLOT_SAVE,
    SERVER_TASK_TYPE_SLOT_RESTORE,
    SERVER_TASK_TYPE_SLOT_ERASE,
    SERVER_TASK_TYPE_GET_LORA,
    SERVER_TASK_TYPE_SET_LORA,
};

// TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common
enum task_response_type {
    TASK_RESPONSE_TYPE_NONE, // llama.cpp native format
    TASK_RESPONSE_TYPE_OAI_CHAT,
    TASK_RESPONSE_TYPE_OAI_CMPL,
    TASK_RESPONSE_TYPE_OAI_RESP,
    TASK_RESPONSE_TYPE_OAI_EMBD,
    TASK_RESPONSE_TYPE_ANTHROPIC,
};

enum stop_type {
    STOP_TYPE_NONE,
    STOP_TYPE_EOS,
    STOP_TYPE_WORD,
    STOP_TYPE_LIMIT,
};
struct task_params {
    bool stream          = true;
    bool include_usage   = false;
    bool cache_prompt    = true;  // remember the prompt to avoid reprocessing the entire prompt
    bool return_tokens   = false;
    bool return_progress = false;

    int32_t n_keep        =  0; // number of tokens to keep from the initial prompt
    int32_t n_discard     =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
    int32_t n_predict     = -1; // new tokens to predict
    int32_t n_indent      =  0; // minimum line indentation for the generated text in number of whitespace characters
    int32_t n_cmpl        =  1; // number of completions to generate from this prompt
    int32_t n_cache_reuse =  0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)

    int64_t t_max_prompt_ms  = -1; // TODO: implement
    int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

    std::map<int, float> lora; // mapping adapter ID -> scale

    std::vector<std::string> antiprompt;
    std::vector<std::string> response_fields;

    bool timings_per_token   = false;
    bool post_sampling_probs = false;

    struct common_params_sampling    sampling;
    struct common_params_speculative speculative;

    // response formatting
    bool verbose = false;
    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;

    // per-request parameters for chat parsing
    common_chat_parser_params chat_parser_params;

    // embeddings
    int32_t embd_normalize = 2; // (-1 = none, 0 = max absolute int16, 1 = taxicab, 2 = Euclidean/L2, >2 = p-norm)

    json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) const;
    json to_json(bool only_metrics = false) const;
};
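// Usage sketch (illustrative only): one way a caller might fill a task_params
// directly instead of going through server_task::params_from_json_cmpl().
// The concrete values below are placeholders, not defaults from this codebase.
//
//     task_params params;
//     params.stream       = true;
//     params.n_predict    = 256;                          // cap the number of generated tokens
//     params.cache_prompt = true;                         // reuse cached prompt tokens when possible
//     params.res_type     = TASK_RESPONSE_TYPE_OAI_CHAT;  // format responses like the OpenAI chat API
//     params.antiprompt   = { "</s>" };                   // extra stop strings
//     params.lora[0]      = 0.5f;                         // adapter 0 at half scale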
// struct for tracking the state of a task (e.g., for streaming)
struct task_result_state {
    // tracking diffs for partial tool calls
    std::vector<common_chat_msg_diff> diffs;

    common_chat_parser_params chat_parser_params;
    common_chat_msg chat_msg;

    std::string generated_text; // append new chunks of generated text here
    std::vector<std::string> generated_tool_call_ids;

    // for OpenAI Responses and Anthropic streaming API:
    // track output item / content block state across chunks
    bool thinking_block_started = false;
    bool text_block_started     = false;

    // for OpenAI Responses streaming API
    const std::string oai_resp_id;
    const std::string oai_resp_reasoning_id;
    const std::string oai_resp_message_id;
    std::string oai_resp_fc_id; // function call ID for current args delta

    task_result_state(const common_chat_parser_params & chat_parser_params)
        : chat_parser_params(chat_parser_params)
        , oai_resp_id("resp_" + random_string())
        , oai_resp_reasoning_id("rs_" + random_string())
        , oai_resp_message_id("msg_" + random_string()) {}

    // parse partial tool calls and update the internal state
    common_chat_msg update_chat_msg(
            const std::string & text_added,
            bool is_partial,
            std::vector<common_chat_msg_diff> & diffs);
};
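// Streaming sketch (illustrative only): a caller that owns the state can feed
// each newly generated chunk through update_chat_msg() and forward the
// resulting diffs; the loop and variable names below are placeholders.
//
//     task_result_state state(params.chat_parser_params);
//     std::vector<common_chat_msg_diff> diffs;
//     for (const std::string & chunk : streamed_chunks) {
//         common_chat_msg msg = state.update_chat_msg(chunk, /*is_partial=*/true, diffs);
//         // forward diffs (e.g. partial tool-call arguments) to the client here
//     }
//     // a final call with is_partial == false once generation has stopped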
struct server_task {
    int id = -1; // to be filled by server_queue

    // TODO @ngxson : remove this field and implement a mapping task_id -> idx in the response_reader
    size_t index = 0; // used when there are multiple prompts (batch request)

    // used by SERVER_TASK_TYPE_CANCEL
    int id_target = -1;
    int id_slot   = -1;

    // used by parallel sampling (multiple completions from same prompt)
    int id_parent = -1;

    // temporary store of child tasks for scheduling
    // note: accessing elements is invalid after the task is moved to a server_slot
    std::vector<server_task> child_tasks;

    // used by SERVER_TASK_TYPE_INFERENCE
    task_params params;
    server_tokens tokens;

    // only used by the CLI; this allows tokenizing CLI inputs on the server side
    // we need this because mtmd_context and vocab are not accessible outside of server_context
    bool cli = false;
    std::string cli_prompt;
    std::vector<raw_buffer> cli_files;

    server_task_type type;

    // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
    struct slot_action {
        int slot_id;
        std::string filename;
        std::string filepath;
    };
    slot_action slot_action;

    // used by SERVER_TASK_TYPE_METRICS
    bool metrics_reset_bucket = false;

    // used by SERVER_TASK_TYPE_SET_LORA
    std::map<int, float> set_lora; // mapping adapter ID -> scale

    server_task() = default;
    server_task(server_task_type type) : type(type) {}

    int32_t n_tokens() const {
        return tokens.size();
    }

    bool need_embd() const {
        switch (type) {
            case SERVER_TASK_TYPE_EMBEDDING:
            case SERVER_TASK_TYPE_RERANK:
                return true;
            default:
                return false;
        }
    }

    bool need_logits() const {
        switch (type) {
            case SERVER_TASK_TYPE_COMPLETION:
            case SERVER_TASK_TYPE_INFILL:
                return true;
            default:
                return false;
        }
    }

    bool need_sampling() const {
        switch (type) {
            case SERVER_TASK_TYPE_COMPLETION:
            case SERVER_TASK_TYPE_INFILL:
                return true;
            default:
                return false;
        }
    }

    static task_params params_from_json_cmpl(
            const llama_vocab * vocab,
            const common_params & params_base,
            const int n_ctx_slot,
            const json & data);

    // utility function
    static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
        std::unordered_set<int> ids(tasks.size());
        for (size_t i = 0; i < tasks.size(); i++) {
            ids.insert(tasks[i].id);
            for (auto & child : tasks[i].child_tasks) {
                ids.insert(child.id);
            }
        }
        return ids;
    }

    void add_child(int id_parent, int id_child) {
        server_task copy;
        copy.id        = id_child;
        copy.id_parent = id_parent;
        copy.params    = params;
        copy.type      = type;
        copy.tokens    = tokens.clone();
        copy.id_slot   = -1; // child tasks cannot specify a slot

        // use a different sampling seed for each child
        // note: https://github.com/ggml-org/llama.cpp/pull/18700#discussion_r2675115723
        if (copy.params.sampling.seed != LLAMA_DEFAULT_SEED) {
            copy.params.sampling.seed += (uint32_t) child_tasks.size() + 1;
        }

        child_tasks.push_back(std::move(copy));
    }

    // the task will be moved into the queue, then onto a slot;
    // however, the state must be kept by the caller (e.g., the HTTP thread)
    task_result_state create_state() const {
        return task_result_state(params.chat_parser_params);
    }

    bool is_parent() const {
        return child_tasks.size() > 0;
    }

    bool is_child() const {
        return id_parent != -1;
    }
};
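// Scheduling sketch (illustrative only): building a completion task with
// n_cmpl - 1 children for parallel sampling, then keeping the streaming state
// on the caller side before the task itself is moved into the queue. The IDs
// below are placeholders; in the server they are assigned by server_queue.
//
//     server_task task(SERVER_TASK_TYPE_COMPLETION);
//     task.params = params;
//     task.tokens = prompt_tokens.clone();
//     for (int i = 1; i < params.n_cmpl; i++) {
//         task.add_child(/*id_parent=*/0, /*id_child=*/i); // placeholder IDs
//     }
//     task_result_state state = task.create_state(); // kept by the HTTP thread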
struct result_timings {
    int32_t cache_n  = -1;
    int32_t prompt_n = -1;
    double prompt_ms;
    double prompt_per_token_ms;
    double prompt_per_second;

    int32_t predicted_n = -1;
    double predicted_ms;
    double predicted_per_token_ms;
    double predicted_per_second;

    // optional speculative metrics - only included when > 0
    int32_t draft_n          = 0;
    int32_t draft_n_accepted = 0;

    json to_json() const;
};

struct result_prompt_progress {
    int32_t total     = 0;
    int32_t cache     = 0;
    int32_t processed = 0;
    int64_t time_ms   = 0;

    json to_json() const;
};

struct server_task_result {
    int id      = -1;
    int id_slot = -1;

    // TODO @ngxson : remove this field and implement a mapping task_id -> idx in the response_reader
    size_t index = 0; // to be used for batched tasks

    virtual bool is_error() {
        // only used by server_task_result_error
        return false;
    }

    virtual bool is_stop() {
        // only used by server_task_result_cmpl_*
        return true;
    }

    virtual void update(task_result_state &) {
        // only used by server_task_result_cmpl_*
    }

    virtual json to_json() = 0;

    virtual ~server_task_result() = default;
};

// using unique_ptr for polymorphism of server_task_result
using server_task_result_ptr = std::unique_ptr<server_task_result>;
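// Consumption sketch (illustrative only): handling polymorphic results through
// the base interface; receive(), send_error() and send_chunk() are stand-ins
// for however results are pulled and forwarded, not functions in this codebase.
//
//     server_task_result_ptr res = receive();
//     if (res->is_error()) {
//         send_error(res->to_json());
//     } else {
//         res->update(state);          // populates chat message / diffs for cmpl results
//         send_chunk(res->to_json());
//         if (res->is_stop()) {
//             // final result for this task: the stream ends here
//         }
//     }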
struct completion_token_output {
    llama_token tok;
    float prob;
    std::string text_to_send;

    struct prob_info {
        llama_token tok;
        std::string txt;
        float prob;
    };
    std::vector<prob_info> probs;

    json to_json(bool post_sampling_probs) const;

    static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs);
    static float logarithm(float x);
    static std::vector<unsigned char> str_to_bytes(const std::string & str);
};

struct server_task_result_cmpl_final : server_task_result {
    std::string content;
    llama_tokens tokens;

    bool stream;
    bool include_usage;
    result_timings timings;
    std::string prompt;

    bool truncated;
    int32_t n_decoded;
    int32_t n_prompt_tokens;
    int32_t n_tokens_cached;
    bool has_new_line;
    std::string stopping_word;
    stop_type stop = STOP_TYPE_NONE;

    bool post_sampling_probs;
    std::vector<completion_token_output> probs_output;
    std::vector<std::string> response_fields;

    task_params generation_params;

    // response formatting
    bool verbose = false;
    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
    common_chat_msg oaicompat_msg;                         // to be populated by update()
    std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
    bool is_updated = false;

    // for OpenAI Responses API
    std::string oai_resp_id;
    std::string oai_resp_reasoning_id;
    std::string oai_resp_message_id;

    virtual bool is_stop() override {
        return true; // in stream mode, final responses are considered stop
    }

    virtual json to_json() override;

    virtual void update(task_result_state & state) override {
        is_updated    = true;
        oaicompat_msg = state.update_chat_msg(content, false, oaicompat_msg_diffs);

        oai_resp_id           = state.oai_resp_id;
        oai_resp_reasoning_id = state.oai_resp_reasoning_id;
        oai_resp_message_id   = state.oai_resp_message_id;
    }

    json to_json_non_oaicompat();
    json to_json_oaicompat();
    json to_json_oaicompat_chat();
    json to_json_oaicompat_chat_stream();
    json to_json_oaicompat_resp();
    json to_json_oaicompat_resp_stream();
    json to_json_anthropic();
    json to_json_anthropic_stream();
};
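// Result sketch (illustrative only): turning a finished generation into a final
// result. The field values are placeholders, and update() is expected to run
// with the caller-held state before to_json() so that oaicompat_msg and the
// diffs are populated.
//
//     auto res = std::make_unique<server_task_result_cmpl_final>();
//     res->id        = task_id;
//     res->content   = generated_text;
//     res->n_decoded = n_decoded;
//     res->res_type  = TASK_RESPONSE_TYPE_OAI_CHAT;
//     res->update(state);     // fills oaicompat_msg / oaicompat_msg_diffs
//     json body = res->to_json();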
struct server_task_result_cmpl_partial : server_task_result {
    std::string content;
    llama_tokens tokens;

    int32_t n_decoded;
    int32_t n_prompt_tokens;

    bool post_sampling_probs;
    bool is_progress = false;
    completion_token_output prob_output;
    result_timings timings;
    result_prompt_progress progress;

    // response formatting
    bool verbose = false;
    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
    std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
    bool is_updated = false;

    // streaming state copied from task_result_state for this chunk
    bool thinking_block_started = false;
    bool text_block_started     = false;

    // for OpenAI Responses API
    std::string oai_resp_id;
    std::string oai_resp_reasoning_id;
    std::string oai_resp_message_id;
    std::string oai_resp_fc_id;

    // for Anthropic API: track if any reasoning content has been generated
    bool anthropic_has_reasoning = false;

    virtual bool is_stop() override {
        return false; // in stream mode, partial responses are not considered stop
    }

    virtual void update(task_result_state & state) override;

    virtual json to_json() override;
    json to_json_non_oaicompat();
    json to_json_oaicompat();
    json to_json_oaicompat_chat();
    json to_json_oaicompat_resp();
    json to_json_anthropic();
};
struct server_task_result_embd : server_task_result {
    std::vector<std::vector<float>> embedding;

    int32_t n_tokens;

    // response formatting
    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;

    virtual json to_json() override;

    json to_json_non_oaicompat();
    json to_json_oaicompat();
};

struct server_task_result_rerank : server_task_result {
    float score = -1e6;

    int32_t n_tokens;

    virtual json to_json() override;
};

struct server_task_result_error : server_task_result {
    error_type err_type = ERROR_TYPE_SERVER;
    std::string err_msg;

    // for ERROR_TYPE_EXCEED_CONTEXT_SIZE
    int32_t n_prompt_tokens = 0;
    int32_t n_ctx = 0;

    virtual bool is_error() override {
        return true;
    }

    virtual json to_json() override;
};

struct server_task_result_metrics : server_task_result {
    int n_idle_slots;
    int n_processing_slots;
    int n_tasks_deferred;
    int64_t t_start;

    // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
    uint64_t n_prompt_tokens_processed_total = 0;
    uint64_t t_prompt_processing_total       = 0;
    uint64_t n_tokens_predicted_total        = 0;
    uint64_t t_tokens_generation_total       = 0;

    uint64_t n_tokens_max = 0;

    uint64_t n_prompt_tokens_processed = 0;
    uint64_t t_prompt_processing       = 0;

    uint64_t n_tokens_predicted  = 0;
    uint64_t t_tokens_generation = 0;

    uint64_t n_decode_total     = 0;
    uint64_t n_busy_slots_total = 0;

    // while we could also use std::vector<server_slot>, this would require copying the slot object, which can be quite messy
    // therefore, we use json to temporarily store the slot.to_json() result
    json slots_data = json::array();

    virtual json to_json() override;
};

struct server_task_result_slot_save_load : server_task_result {
    std::string filename;
    bool is_save; // true = save, false = load

    size_t n_tokens;
    size_t n_bytes;
    double t_ms;

    virtual json to_json() override;
};

struct server_task_result_slot_erase : server_task_result {
    size_t n_erased;

    virtual json to_json() override;
};

struct server_task_result_get_lora : server_task_result {
    struct lora {
        common_adapter_lora_info info;
        std::string alora_invocation_string;
        llama_tokens alora_invocation_tokens;
    };
    std::vector<lora> loras;

    virtual json to_json() override;
};

struct server_task_result_apply_lora : server_task_result {
    virtual json to_json() override;
};

struct server_prompt_checkpoint {
    llama_pos pos_min;
    llama_pos pos_max;

    std::vector<uint8_t> data;

    size_t size() const {
        return data.size();
    }
};

struct server_prompt {
    server_tokens tokens;

    std::vector<uint8_t> data;

    std::list<server_prompt_checkpoint> checkpoints;

    size_t size() const {
        size_t res = data.size();
        for (const auto & checkpoint : checkpoints) {
            res += checkpoint.size();
        }
        return res;
    }

    int n_tokens() const {
        return tokens.size();
    }

    server_prompt clone() const {
        return server_prompt {
            tokens.clone(),
            data,
            checkpoints
        };
    }
};

struct server_prompt_cache {
    server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) {
        this->limit_size   = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib);
        this->limit_tokens = limit_tokens;
    }

    std::list<server_prompt> states;

    // in bytes, 0 = no limit
    size_t limit_size = 0;

    // in tokens, 0 = no limit
    size_t limit_tokens = 0;

    size_t size() const;
    size_t n_tokens() const;

    server_prompt * alloc(const server_prompt & prompt, size_t state_size);

    bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot);

    void update();
};
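// Sizing sketch (illustrative only): a cache limited to 256 MiB with no token
// limit. The constructor clamps negative MiB values to 0 (meaning "no size
// limit") and converts MiB to bytes as 1024 * 1024 * limit_size_mib.
//
//     server_prompt_cache cache(/*limit_size_mib=*/256, /*limit_tokens=*/0);
//     // 256 MiB -> cache.limit_size == 268435456 bytes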