1
0

chat.hpp 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. // Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
  2. #pragma once
  3. #include "common.h"
  4. #include <json.hpp>
  5. #include <optional>
  6. #include <string>
  7. #include <vector>
  8. using json = nlohmann::ordered_json;
  9. struct common_chat_inputs {
  10. json messages;
  11. json tools;
  12. json tool_choice;
  13. json json_schema;
  14. bool parallel_tool_calls;
  15. bool stream;
  16. std::string grammar;
  17. bool add_generation_prompt = true;
  18. };
  19. enum common_chat_format {
  20. COMMON_CHAT_FORMAT_CONTENT_ONLY,
  21. COMMON_CHAT_FORMAT_GENERIC,
  22. COMMON_CHAT_FORMAT_MISTRAL_NEMO,
  23. COMMON_CHAT_FORMAT_LLAMA_3_X,
  24. COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
  25. COMMON_CHAT_FORMAT_DEEPSEEK_R1,
  26. COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
  27. COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
  28. COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
  29. COMMON_CHAT_FORMAT_HERMES_2_PRO,
  30. COMMON_CHAT_FORMAT_COMMAND_R7B,
  31. COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
  32. };
  33. struct common_chat_params {
  34. common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
  35. json prompt;
  36. std::string grammar;
  37. bool grammar_lazy = false;
  38. std::vector<common_grammar_trigger> grammar_triggers;
  39. std::vector<std::string> preserved_tokens;
  40. std::vector<std::string> additional_stops;
  41. };
  42. struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params);
  43. std::string common_chat_format_name(common_chat_format format);
  44. common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);