chat.hpp 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
// Chat support (including tool-call grammar constraining and output parsing) with generic and custom template handlers.
  2. #pragma once
  3. #include "common.h"
  4. #include <json.hpp>
  5. #include <optional>
  6. #include <string>
  7. #include <vector>
  8. using json = nlohmann::ordered_json;
  9. struct common_chat_inputs {
  10. json messages;
  11. json tools;
  12. json tool_choice;
  13. json json_schema;
  14. bool parallel_tool_calls;
  15. bool stream;
  16. std::string grammar;
  17. bool add_generation_prompt = true;
  18. };
  19. enum common_chat_format {
  20. COMMON_CHAT_FORMAT_CONTENT_ONLY,
  21. COMMON_CHAT_FORMAT_GENERIC,
  22. COMMON_CHAT_FORMAT_MISTRAL_NEMO,
  23. COMMON_CHAT_FORMAT_LLAMA_3_X,
  24. COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
  25. COMMON_CHAT_FORMAT_DEEPSEEK_R1,
  26. COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
  27. COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
  28. COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
  29. COMMON_CHAT_FORMAT_HERMES_2_PRO,
  30. COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
  31. };
  32. struct common_chat_params {
  33. common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
  34. json prompt;
  35. std::string grammar;
  36. bool grammar_lazy = false;
  37. std::vector<common_grammar_trigger> grammar_triggers;
  38. std::vector<std::string> additional_stops;
  39. };
  40. struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params);
  41. std::string common_chat_format_name(common_chat_format format);
  42. common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);