chat.h 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. // Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
  2. #pragma once
  #include "common.h"
  #include "peg-parser.h"

  #include <chrono>
  #include <functional>
  #include <map>
  #include <memory>
  #include <string>
  #include <vector>

  #include <nlohmann/json_fwd.hpp>
  11. struct common_chat_templates;
  // One tool invocation attached to a chat message.
  struct common_chat_tool_call {
      std::string name;       // name of the tool being called
      std::string arguments;  // call arguments as text (presumably JSON-encoded -- confirm against the parsers)
      std::string id;         // identifier of this call; may be empty

      // Field-wise equality: name, arguments and id must all match.
      bool operator==(const common_chat_tool_call & other) const {
          if (name != other.name) {
              return false;
          }
          if (arguments != other.arguments) {
              return false;
          }
          return id == other.id;
      }
  };
  // One typed fragment of a message's content.
  struct common_chat_msg_content_part {
      std::string type;  // kind of the part
      std::string text;  // textual payload of the part

      // TODO @ngxson : no known chat templates support reasoning_content in content parts yet
      //                this can be useful for models with interleaved thinking (like Kimi-K2)
      //                if you see any templates explicitly support this, please ping me
      // std::string reasoning_content;

      // Field-wise equality over type and text.
      bool operator==(const common_chat_msg_content_part & other) const {
          if (type != other.type) {
              return false;
          }
          return text == other.text;
      }
  };
  31. struct common_chat_msg {
  32. std::string role;
  33. std::string content;
  34. std::vector<common_chat_msg_content_part> content_parts;
  35. std::vector<common_chat_tool_call> tool_calls;
  36. std::string reasoning_content;
  37. std::string tool_name;
  38. std::string tool_call_id;
  39. nlohmann::ordered_json to_json_oaicompat(bool concat_typed_text = false) const;
  40. bool empty() const {
  41. return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
  42. }
  43. void set_tool_call_ids(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
  44. for (auto i = 0u; i < tool_calls.size(); i++) {
  45. if (ids_cache.size() <= i) {
  46. auto id = tool_calls[i].id;
  47. if (id.empty()) {
  48. id = gen_tool_call_id();
  49. }
  50. ids_cache.push_back(id);
  51. }
  52. tool_calls[i].id = ids_cache[i];
  53. }
  54. }
  55. bool operator==(const common_chat_msg & other) const {
  56. return role == other.role
  57. && content == other.content
  58. && content_parts == other.content_parts
  59. && tool_calls == other.tool_calls
  60. && reasoning_content == other.reasoning_content
  61. && tool_name == other.tool_name
  62. && tool_call_id == other.tool_call_id;
  63. }
  64. bool operator!=(const common_chat_msg & other) const {
  65. return !(*this == other);
  66. }
  67. };
  68. struct common_chat_msg_diff {
  69. std::string reasoning_content_delta;
  70. std::string content_delta;
  71. size_t tool_call_index = std::string::npos;
  72. common_chat_tool_call tool_call_delta;
  73. static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new);
  74. bool operator==(const common_chat_msg_diff & other) const {
  75. return content_delta == other.content_delta
  76. && tool_call_index == other.tool_call_index
  77. && tool_call_delta == other.tool_call_delta;
  78. }
  79. };
  // Description of a tool that can be offered to the model.
  struct common_chat_tool {
      std::string name;         // tool name
      std::string description;  // human-readable description of the tool
      std::string parameters;   // parameter specification as text (presumably a JSON schema -- confirm with callers)
  };
  // Tool-calling policy selector. The numeric values are spelled out but are the
  // same ones the compiler would assign implicitly.
  enum common_chat_tool_choice {
      COMMON_CHAT_TOOL_CHOICE_AUTO     = 0,
      COMMON_CHAT_TOOL_CHOICE_REQUIRED = 1,
      COMMON_CHAT_TOOL_CHOICE_NONE     = 2,
  };
  // Identifier of the chat/tool-call output format.
  // Enumerators are numbered implicitly from 0 in declaration order, so their order must
  // not change; COMMON_CHAT_FORMAT_COUNT has to remain the last entry.
  enum common_chat_format {
      COMMON_CHAT_FORMAT_CONTENT_ONLY,
      COMMON_CHAT_FORMAT_GENERIC,

      // Formats named after specific model families / templates
      COMMON_CHAT_FORMAT_MISTRAL_NEMO,
      COMMON_CHAT_FORMAT_MAGISTRAL,
      COMMON_CHAT_FORMAT_LLAMA_3_X,
      COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
      COMMON_CHAT_FORMAT_DEEPSEEK_R1,
      COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
      COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
      COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
      COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
      COMMON_CHAT_FORMAT_HERMES_2_PRO,
      COMMON_CHAT_FORMAT_COMMAND_R7B,
      COMMON_CHAT_FORMAT_GRANITE,
      COMMON_CHAT_FORMAT_GPT_OSS,
      COMMON_CHAT_FORMAT_SEED_OSS,
      COMMON_CHAT_FORMAT_NEMOTRON_V2,
      COMMON_CHAT_FORMAT_APERTUS,
      COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,
      COMMON_CHAT_FORMAT_GLM_4_5,
      COMMON_CHAT_FORMAT_MINIMAX_M2,
      COMMON_CHAT_FORMAT_KIMI_K2,
      COMMON_CHAT_FORMAT_QWEN3_CODER_XML,
      COMMON_CHAT_FORMAT_APRIEL_1_5,
      COMMON_CHAT_FORMAT_XIAOMI_MIMO,
      COMMON_CHAT_FORMAT_SOLAR_OPEN,
      COMMON_CHAT_FORMAT_EXAONE_MOE,

      // These are intended to be parsed by the PEG parser
      COMMON_CHAT_FORMAT_PEG_SIMPLE,
      COMMON_CHAT_FORMAT_PEG_NATIVE,
      COMMON_CHAT_FORMAT_PEG_CONSTRUCTED,

      COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
  };
  124. struct common_chat_templates_inputs {
  125. std::vector<common_chat_msg> messages;
  126. std::string grammar;
  127. std::string json_schema;
  128. bool add_generation_prompt = true;
  129. bool use_jinja = true;
  130. // Parameters below only supported when use_jinja is true
  131. std::vector<common_chat_tool> tools;
  132. common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
  133. bool parallel_tool_calls = false;
  134. common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool enable_thinking"
  135. bool enable_thinking = true;
  136. std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
  137. std::map<std::string, std::string> chat_template_kwargs;
  138. bool add_bos = false;
  139. bool add_eos = false;
  140. };
  141. struct common_chat_params {
  142. common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
  143. std::string prompt;
  144. std::string grammar;
  145. bool grammar_lazy = false;
  146. bool thinking_forced_open = false;
  147. std::vector<common_grammar_trigger> grammar_triggers;
  148. std::vector<std::string> preserved_tokens;
  149. std::vector<std::string> additional_stops;
  150. std::string parser;
  151. };
  152. // per-message parsing syntax
  153. // should be derived from common_chat_params
  154. struct common_chat_parser_params {
  155. common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
  156. common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
  157. // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
  158. bool reasoning_in_content = false;
  159. bool thinking_forced_open = false;
  160. bool parse_tool_calls = true;
  161. common_peg_arena parser = {};
  162. common_chat_parser_params() = default;
  163. common_chat_parser_params(const common_chat_params & chat_params) {
  164. format = chat_params.format;
  165. thinking_forced_open = chat_params.thinking_forced_open;
  166. }
  167. };
  // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
  bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);

  // Releases a template set; this is the function the owning pointer's deleter calls.
  void common_chat_templates_free(struct common_chat_templates * tmpls);

  // Deleter that forwards to common_chat_templates_free(). The call operator is const
  // because the deleter carries no state.
  struct common_chat_templates_deleter {
      void operator()(common_chat_templates * tmpls) const {
          common_chat_templates_free(tmpls);
      }
  };

  // Owning handle for a template set (std::unique_ptr requires <memory>).
  using common_chat_templates_ptr = std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter>;
  173. common_chat_templates_ptr common_chat_templates_init(
  174. const struct llama_model * model,
  175. const std::string & chat_template_override,
  176. const std::string & bos_token_override = "",
  177. const std::string & eos_token_override = "");
  178. bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
  179. std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
  180. struct common_chat_params common_chat_templates_apply(
  181. const struct common_chat_templates * tmpls,
  182. const struct common_chat_templates_inputs & inputs);
  183. // Format single message, while taking into account the position of that message in chat history
  184. std::string common_chat_format_single(
  185. const struct common_chat_templates * tmpls,
  186. const std::vector<common_chat_msg> & past_msg,
  187. const common_chat_msg & new_msg,
  188. bool add_ass,
  189. bool use_jinja);
  190. // Returns an example of formatted chat
  191. std::string common_chat_format_example(
  192. const struct common_chat_templates * tmpls,
  193. bool use_jinja,
  194. const std::map<std::string, std::string> & chat_template_kwargs);
  195. const char* common_chat_format_name(common_chat_format format);
  196. common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
  197. common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
  198. // used by arg and server
  199. const char * common_reasoning_format_name(common_reasoning_format format);
  200. common_reasoning_format common_reasoning_format_from_name(const std::string & format);
  201. common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
  202. bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates);
  203. // Parses a JSON array of messages in OpenAI's chat completion API format.
  204. std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const nlohmann::ordered_json & messages);
  205. nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
  206. std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools);
  207. nlohmann::ordered_json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
  208. nlohmann::ordered_json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
  209. // get template caps, useful for reporting to server /props endpoint
  210. std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates);