// chat.h
  // Chat support (including tool-call grammar constraining and output parsing) with generic and custom template handlers.
  #pragma once

  #include "common.h"

  #include <chrono>
  #include <memory>
  #include <string>
  #include <vector>
  // Forward declaration: the declarations visible in this header only handle the type through pointers.
  struct common_chat_templates;
  // A single tool call carried by a chat message (see common_chat_msg::tool_calls).
  struct common_chat_tool_call {
      std::string name;       // name of the tool being called
      std::string arguments;  // call arguments as text (presumably JSON-encoded -- confirm against the parser)
      std::string id;         // identifier of this call
  };
  // One typed fragment of a message's content (see common_chat_msg::content_parts).
  struct common_chat_msg_content_part {
      std::string type;  // kind of the part
      std::string text;  // textual payload of the part
  };
  17. struct common_chat_msg {
  18. std::string role;
  19. std::string content;
  20. std::vector<common_chat_msg_content_part> content_parts = {};
  21. std::vector<common_chat_tool_call> tool_calls = {};
  22. std::string reasoning_content;
  23. std::string tool_name;
  24. std::string tool_call_id;
  25. };
  // Description of a tool that can be offered to the model.
  struct common_chat_tool {
      std::string name;
      std::string description;
      std::string parameters;  // parameter specification as text (presumably a JSON schema -- confirm)
  };
  // Tool-choice policy; common_chat_tool_choice_parse_oaicompat() produces one of these from a string.
  // Kept as a plain (unscoped) enum so existing callers and enumerator values are unaffected.
  enum common_chat_tool_choice {
      COMMON_CHAT_TOOL_CHOICE_AUTO,
      COMMON_CHAT_TOOL_CHOICE_REQUIRED,
      COMMON_CHAT_TOOL_CHOICE_NONE,
  };
  // Identifies the chat/tool-call format in use.
  // Enumerator order is significant: COMMON_CHAT_FORMAT_COUNT must stay last so it equals the number of formats.
  enum common_chat_format {
      COMMON_CHAT_FORMAT_CONTENT_ONLY,
      COMMON_CHAT_FORMAT_GENERIC,
      COMMON_CHAT_FORMAT_MISTRAL_NEMO,
      COMMON_CHAT_FORMAT_LLAMA_3_X,
      COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
      COMMON_CHAT_FORMAT_DEEPSEEK_R1,
      COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
      COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
      COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
      COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
      COMMON_CHAT_FORMAT_HERMES_2_PRO,
      COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING,
      COMMON_CHAT_FORMAT_COMMAND_R7B,
      COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,

      COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
  };
  53. struct common_chat_templates_inputs {
  54. std::vector<common_chat_msg> messages;
  55. std::string grammar;
  56. std::string json_schema;
  57. bool add_generation_prompt = true;
  58. bool use_jinja = true;
  59. // Parameters below only supported when use_jinja is true
  60. std::vector<common_chat_tool> tools;
  61. common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
  62. bool parallel_tool_calls = false;
  63. bool extract_reasoning = true;
  64. std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
  65. };
  66. struct common_chat_params {
  67. common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
  68. std::string prompt;
  69. std::string grammar;
  70. bool grammar_lazy = false;
  71. std::vector<common_grammar_trigger> grammar_triggers;
  72. std::vector<std::string> preserved_tokens;
  73. std::vector<std::string> additional_stops;
  74. };
  // Checks whether the template supplied via "--chat-template" is supported.
  // Returns true if it is valid.
  bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
  // Frees a template set; this is what common_chat_templates_deleter calls.
  void common_chat_templates_free(struct common_chat_templates * tmpls);

  // Deleter that lets std::unique_ptr own a common_chat_templates by forwarding to common_chat_templates_free().
  struct common_chat_templates_deleter {
      void operator()(common_chat_templates * tmpls) const { common_chat_templates_free(tmpls); }
  };

  // Owning pointer to a template set (requires <memory>).
  using common_chat_templates_ptr = std::unique_ptr<common_chat_templates, common_chat_templates_deleter>;
  80. common_chat_templates_ptr common_chat_templates_init(
  81. const struct llama_model * model,
  82. const std::string & chat_template_override,
  83. const std::string & bos_token_override = "",
  84. const std::string & eos_token_override = "");
  // NOTE(review): presumably reports whether the template was explicitly provided rather than defaulted -- confirm.
  bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
  // Returns a C string for the given template set; `variant` is optional and defaults to nullptr.
  // NOTE(review): presumably the template source text, with `variant` selecting a named variant -- confirm.
  const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
  // Applies the template set to `inputs` and returns the resulting common_chat_params.
  struct common_chat_params common_chat_templates_apply(
      const struct common_chat_templates * tmpls,
      const struct common_chat_templates_inputs & inputs);
  90. // Format single message, while taking into account the position of that message in chat history
  91. std::string common_chat_format_single(
  92. const struct common_chat_templates * tmpls,
  93. const std::vector<common_chat_msg> & past_msg,
  94. const common_chat_msg & new_msg,
  95. bool add_ass,
  96. bool use_jinja);
  // Returns an example of a formatted chat.
  std::string common_chat_format_example(
      const struct common_chat_templates * tmpls,
      bool use_jinja);
  101. std::string common_chat_format_name(common_chat_format format);
  102. common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);
  103. common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
  104. // Parses a JSON array of messages in OpenAI's chat completion API format.
  105. // T can be std::string containing JSON or nlohmann::ordered_json
  106. template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
  107. template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
  108. // Parses a JSON array of tools in OpenAI's chat completion tool call API format.
  109. // T can be std::string containing JSON or nlohmann::ordered_json
  110. template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
  111. template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);