#pragma once

#include "common.h"
#include "log.h"
#include "llama.h"
#include "chat.h"
#include "mtmd.h"

#define JSON_ASSERT GGML_ASSERT
#include <nlohmann/json.hpp>

#include <cinttypes>
#include <map>
#include <string>
#include <vector>

const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);

using json = nlohmann::ordered_json;

#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SLT_CNT(slot, fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)

#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
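
// Usage sketch (illustrative, not taken from this codebase): the SLT_* macros
// expect a slot object with `id` and `task` members; both families take a
// printf-style format string plus at least one argument. `params` and `slot`
// below are hypothetical:
//
//     SRV_INF("loading model '%s'\n", params.model.c_str());
//     SLT_DBG(slot, "n_past = %d\n", slot.n_past);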

using raw_buffer = std::vector<uint8_t>;

template <typename T>
static T json_value(const json & body, const std::string & key, const T & default_value) {
    // fallback null to default value
    if (body.contains(key) && !body.at(key).is_null()) {
        try {
            return body.at(key);
        } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const & err) {
            LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value: %s\n", key.c_str(), json(default_value).type_name(), err.what());
            return default_value;
        }
    } else {
        return default_value;
    }
}
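
// Illustrative sketch of the fallback behavior (the request body below is made up):
//
//     json body = json::parse(R"({"n_predict": 128, "seed": null})");
//     int n_predict = json_value(body, "n_predict", 64); // present -> 128
//     int seed      = json_value(body, "seed", -1);      // null    -> default -1
//     int top_k     = json_value(body, "top_k", 40);     // missing -> default 40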

// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
enum error_type {
    ERROR_TYPE_INVALID_REQUEST,
    ERROR_TYPE_AUTHENTICATION,
    ERROR_TYPE_SERVER,
    ERROR_TYPE_NOT_FOUND,
    ERROR_TYPE_PERMISSION,
    ERROR_TYPE_UNAVAILABLE,         // custom error
    ERROR_TYPE_NOT_SUPPORTED,       // custom error
    ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error
};

// thin wrapper around common_grammar_trigger with (de)serialization functions
struct server_grammar_trigger {
    common_grammar_trigger value;

    server_grammar_trigger() = default;
    server_grammar_trigger(const common_grammar_trigger & value) : value(value) {}
    server_grammar_trigger(const json & in) {
        value.type  = (common_grammar_trigger_type) in.at("type").get<int>();
        value.value = in.at("value").get<std::string>();
        if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
            value.token = (llama_token) in.at("token").get<int>();
        }
    }

    json to_json() const {
        json out {
            {"type",  (int) value.type},
            {"value", value.value},
        };
        if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
            out["token"] = (int) value.token;
        }
        return out;
    }
};
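
// Minimal round-trip sketch (field values are made up): a token-type trigger
// carries the extra "token" field, other trigger types serialize only type + value:
//
//     json in = {
//         {"type",  (int) COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN},
//         {"value", "</tool_call>"},
//         {"token", 42},
//     };
//     server_grammar_trigger trig(in);
//     // trig.to_json() reproduces `in`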

json format_error_response(const std::string & message, const enum error_type type);

//
// random string / id
//

std::string random_string();

std::string gen_chatcmplid();

std::string gen_tool_call_id();

//
// lora utils
//

// check whether the given lora set has only aloras activated (empty => false)
bool lora_all_alora(const std::vector<common_adapter_lora_info> & loras);

// if the two sets of loras are different, they require a cache clear unless the
// change is only from aloras to aloras.
bool lora_should_clear_cache(
        const std::vector<common_adapter_lora_info> & current,
        const std::vector<common_adapter_lora_info> & next);

std::vector<common_adapter_lora_info> parse_lora_request(
        const std::vector<common_adapter_lora_info> & lora_base,
        const json & data);

bool are_lora_equal(
        const std::vector<common_adapter_lora_info> & l1,
        const std::vector<common_adapter_lora_info> & l2);

// get the ids of all enabled loras
std::vector<size_t> lora_get_enabled_ids(const std::vector<common_adapter_lora_info> & loras);
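
// Hedged usage sketch for deciding whether a slot can keep its KV cache across
// a lora change (`slot_loras` / `req_loras` are hypothetical):
//
//     if (!are_lora_equal(slot_loras, req_loras)) {
//         if (lora_should_clear_cache(slot_loras, req_loras)) {
//             // non-alora adapters changed -> previously cached tokens are invalid
//         }
//         slot_loras = req_loras;
//     }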

//
// server_tokens
//

/**
 * server_tokens is a helper to manage the input tokens and media for the server.
 * it is made this way to simplify the logic of KV cache management.
 */
struct server_tokens {
    bool has_mtmd = false;

private: // disallow accessing these members directly, risking out-of-sync
    // map a **start** index in tokens to the image chunk
    // note: the order needs to be in sync with tokens
    std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;

    // list of tokens
    // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by a media chunk
    // otherwise, it is a normal text token
    // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
    // note(2): for M-RoPE, an image can occupy a different number of pos; do not assume a 1-to-1 mapping tokens <-> pos
    llama_tokens tokens;

    // for ex. with an input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
    //     [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
    // idx  0   1   2   3   4    5      6      7      8      9      10
    // pos  0   1   2   3   4    5      5      5      7      7      7
    // map_idx_to_media will contain: {5, img0}, {8, img1}

public:
    server_tokens() = default;
    ~server_tokens() = default;

    // Prevent copying
    // TODO: server_tokens should be copyable - remove this:
    server_tokens(const server_tokens &) = delete;
    server_tokens & operator=(const server_tokens &) = delete;

    // Allow moving (usually implicitly generated if members are movable)
    server_tokens(server_tokens &&) = default;
    server_tokens & operator=(server_tokens &&) = default;

    // Allow accessing elements using [] operator
    llama_token operator[](size_t index) { return tokens[index]; }
    const llama_token & operator[](size_t index) const { return tokens[index]; }

    server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd);
    server_tokens(const llama_tokens & tokens, bool has_mtmd);

    // for debugging
    std::string str() const;

    llama_pos pos_next() const;

    const mtmd::input_chunk_ptr & find_chunk(size_t idx) const;

    void push_back(llama_token tok);

    // will create a copy of the chunk if it contains non-text data
    void push_back(const mtmd_input_chunk * chunk);

    // appends server tokens, updates the media map. copies media chunks.
    void push_back(server_tokens & tokens);

    // for compatibility with context shift and prompt truncation
    void insert(const llama_tokens & inp_tokens);

    // for compatibility with speculative decoding, ctx shift, slot save/load
    const llama_tokens & get_text_tokens() const;

    // for compatibility with speculative decoding
    void set_token(llama_pos pos, llama_token id);

    size_t size() const { return tokens.size(); }

    bool empty() const { return tokens.empty(); }

    void clear() {
        map_idx_to_media.clear();
        tokens.clear();
    }

    void keep_first(size_t n);

    std::string detokenize(const llama_context * ctx, bool special) const;

    size_t get_common_prefix(const server_tokens & b) const;

    // make sure all text tokens are within the vocab range
    bool validate(const struct llama_context * ctx) const;

    // encode and decode the image chunk
    int32_t process_chunk(
            llama_context * ctx,
            mtmd_context * mctx,
            size_t idx,
            llama_pos pos,
            int32_t seq_id,
            size_t & n_tokens_out) const;

    server_tokens clone() const;
};
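
// Illustrative sketch of the intended prompt-reuse flow (names are hypothetical;
// the actual logic lives in the server's slot handling):
//
//     server_tokens & cached = slot.cache_tokens;        // what the KV cache currently holds
//     server_tokens & prompt = task.prompt_tokens;       // tokens of the incoming request
//     size_t n_match = cached.get_common_prefix(prompt); // media-aware prefix comparison
//     cached.keep_first(n_match);                        // drop the mismatched tail;
//                                                        // map_idx_to_media stays in sync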

//
// tokenizer and input processing utils
//

bool json_is_array_of_numbers(const json & data);

// does the array contain BOTH numbers & strings?
bool json_is_array_of_mixed_numbers_strings(const json & data);

// does the array contain any individual integers/tokens?
bool json_is_array_and_contains_numbers(const json & data);

// get the values at the given paths (key1 / key2)
json json_get_nested_values(const std::vector<std::string> & paths, const json & js);

/**
 * this handles 2 cases:
 * - only string, example: "string"
 * - mixed string and tokens, example: [12, 34, "string", 56, 78]
 */
llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special);

// return the last index of character that can form a valid string
// if the last character is potentially cut in half, return the index before the cut
// if validate_utf8(text) == text.size(), then the whole text is valid utf8
size_t validate_utf8(const std::string & text);
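
// Sketch of the intended streaming use (`send_partial` is hypothetical); the
// 3-byte UTF-8 sequence for U+20AC below is deliberately cut after 2 bytes:
//
//     std::string text = "ab\xE2\x82";  // "ab" + first 2 bytes of "€"
//     size_t n = validate_utf8(text);   // n == 2: stop before the cut character
//     send_partial(text.substr(0, n));
//     // buffer text.substr(n) until the remaining byte(s) arrive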

// process an mtmd prompt, return the server_tokens containing both text tokens and media chunks
server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector<raw_buffer> files);

/**
 * break the input "prompt" object into multiple prompts if needed, then tokenize them
 * this supports these cases:
 * - "prompt": "string"
 * - "prompt": [12, 34, 56]
 * - "prompt": [12, 34, "string", 56, 78]
 * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
 * and multiple prompts (multi-tasks):
 * - "prompt": ["string1", "string2"]
 * - "prompt": ["string1", [12, 34, 56]]
 * - "prompt": [[12, 34, 56], [78, 90, 12]]
 * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}]
 */
std::vector<server_tokens> tokenize_input_prompts(
        const llama_vocab * vocab,
        mtmd_context * mctx,
        const json & json_prompt,
        bool add_special,
        bool parse_special);
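
// Hedged call sketch: a top-level array of prompts becomes one server_tokens
// per element, i.e. one task per element (the JSON body is made up, `vocab`
// and `mctx` are assumed to be valid):
//
//     json prompt = json::parse(R"(["string1", [12, 34, 56]])");
//     auto inputs = tokenize_input_prompts(vocab, mctx, prompt,
//                                          /*add_special=*/true, /*parse_special=*/true);
//     // inputs.size() == 2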

//
// OAI utils
//

// used by /completions endpoint
json oaicompat_completion_params_parse(const json & body);

struct oaicompat_parser_options {
    bool use_jinja;
    bool prefill_assistant;
    common_reasoning_format reasoning_format;
    std::map<std::string, std::string> chat_template_kwargs;
    common_chat_templates * tmpls;
    bool allow_image;
    bool allow_audio;
    bool enable_thinking = true;
    std::string media_path;
};

// used by /chat/completions endpoint
json oaicompat_chat_params_parse(
        json & body, /* openai api json semantics */
        const oaicompat_parser_options & opt,
        std::vector<raw_buffer> & out_files);

// convert Anthropic Messages API format to OpenAI Chat Completions API format
json convert_anthropic_to_oai(const json & body);

// TODO: move it to server-task.cpp
json format_embeddings_response_oaicompat(
        const json & request,
        const std::string & model_name,
        const json & embeddings,
        bool use_base64 = false);

// TODO: move it to server-task.cpp
json format_response_rerank(
        const json & request,
        const std::string & model_name,
        const json & ranks,
        bool is_tei_format,
        std::vector<std::string> & texts,
        int top_n);

//
// other utils
//

std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx);

std::string safe_json_to_str(const json & data);

std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens);

// format an incomplete utf-8 multibyte character for output
std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token);

// format a server-sent event (SSE), return the formatted string to send
// note: if data is a json array, it will be sent as multiple events, one per item
std::string format_oai_sse(const json & data);
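
// Illustrative sketch (the payload is made up): each json value is framed as a
// standard "data: ...\n\n" SSE event; exact whitespace is up to the implementation:
//
//     std::string msg = format_oai_sse(json {{"object", "chat.completion.chunk"}});
//     // msg resembles: "data: {\"object\":\"chat.completion.chunk\"}\n\n"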

// format an Anthropic-style SSE with event types
std::string format_anthropic_sse(const json & data);

bool is_valid_utf8(const std::string & str);

//
// formatting output responses
// TODO: move these to server-task.cpp
//

llama_tokens format_prompt_infill(
        const llama_vocab * vocab,
        const json & input_prefix,
        const json & input_suffix,
        const json & input_extra,
        const int n_batch,
        const int n_predict,
        const int n_ctx,
        const bool spm_infill,
        const llama_tokens & tokens_prompt);

// format rerank task: [BOS]query[EOS][SEP]doc[EOS].
server_tokens format_prompt_rerank(
        const struct llama_model * model,
        const struct llama_vocab * vocab,
        mtmd_context * mctx,
        const std::string & query,
        const std::string & doc);