| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199 |
- // Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
- #pragma once
- #include "common.h"
- #include <functional>
- #include <chrono>
- #include <string>
- #include <vector>
- struct common_chat_templates;
- struct common_chat_tool_call {
- std::string name;
- std::string arguments;
- std::string id;
- bool operator==(const common_chat_tool_call & other) const {
- return name == other.name && arguments == other.arguments && id == other.id;
- }
- };
- struct common_chat_msg_content_part {
- std::string type;
- std::string text;
- bool operator==(const common_chat_msg_content_part & other) const {
- return type == other.type && text == other.text;
- }
- };
- struct common_chat_msg {
- std::string role;
- std::string content;
- std::vector<common_chat_msg_content_part> content_parts = {};
- std::vector<common_chat_tool_call> tool_calls = {};
- std::string reasoning_content;
- std::string tool_name;
- std::string tool_call_id;
- template <class T> T to_json_oaicompat() const;
- bool empty() const {
- return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
- }
- void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
- for (auto i = 0u; i < tool_calls.size(); i++) {
- if (ids_cache.size() <= i) {
- auto id = tool_calls[i].id;
- if (id.empty()) {
- id = gen_tool_call_id();
- }
- ids_cache.push_back(id);
- }
- tool_calls[i].id = ids_cache[i];
- }
- }
- bool operator==(const common_chat_msg & other) const {
- return role == other.role
- && content == other.content
- && content_parts == other.content_parts
- && tool_calls == other.tool_calls
- && reasoning_content == other.reasoning_content
- && tool_name == other.tool_name
- && tool_call_id == other.tool_call_id;
- }
- bool operator!=(const common_chat_msg & other) const {
- return !(*this == other);
- }
- };
- struct common_chat_msg_diff {
- // std::string reasoning_content_delta;
- std::string content_delta;
- size_t tool_call_index = std::string::npos;
- common_chat_tool_call tool_call_delta;
- static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg);
- bool operator==(const common_chat_msg_diff & other) const {
- return content_delta == other.content_delta
- && tool_call_index == other.tool_call_index
- && tool_call_delta == other.tool_call_delta;
- }
- };
- struct common_chat_tool {
- std::string name;
- std::string description;
- std::string parameters;
- };
- enum common_chat_tool_choice {
- COMMON_CHAT_TOOL_CHOICE_AUTO,
- COMMON_CHAT_TOOL_CHOICE_REQUIRED,
- COMMON_CHAT_TOOL_CHOICE_NONE,
- };
- enum common_chat_format {
- COMMON_CHAT_FORMAT_CONTENT_ONLY,
- COMMON_CHAT_FORMAT_GENERIC,
- COMMON_CHAT_FORMAT_MISTRAL_NEMO,
- COMMON_CHAT_FORMAT_LLAMA_3_X,
- COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
- COMMON_CHAT_FORMAT_DEEPSEEK_R1,
- COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
- COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
- COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
- COMMON_CHAT_FORMAT_HERMES_2_PRO,
- COMMON_CHAT_FORMAT_COMMAND_R7B,
- COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
- };
- struct common_chat_templates_inputs {
- std::vector<common_chat_msg> messages;
- std::string grammar;
- std::string json_schema;
- bool add_generation_prompt = true;
- bool use_jinja = true;
- // Parameters below only supported when use_jinja is true
- std::vector<common_chat_tool> tools;
- common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
- bool parallel_tool_calls = false;
- common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
- std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
- };
- struct common_chat_params {
- common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
- std::string prompt;
- std::string grammar;
- bool grammar_lazy = false;
- bool thinking_forced_open = false;
- std::vector<common_grammar_trigger> grammar_triggers;
- std::vector<std::string> preserved_tokens;
- std::vector<std::string> additional_stops;
- };
- struct common_chat_syntax {
- common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
- common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
- // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
- bool reasoning_in_content = false;
- bool thinking_forced_open = false;
- };
- // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
- bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
- void common_chat_templates_free(struct common_chat_templates * tmpls);
- struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } };
- typedef std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter> common_chat_templates_ptr;
- common_chat_templates_ptr common_chat_templates_init(
- const struct llama_model * model,
- const std::string & chat_template_override,
- const std::string & bos_token_override = "",
- const std::string & eos_token_override = "");
- bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
- const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
- struct common_chat_params common_chat_templates_apply(
- const struct common_chat_templates * tmpls,
- const struct common_chat_templates_inputs & inputs);
- // Format single message, while taking into account the position of that message in chat history
- std::string common_chat_format_single(
- const struct common_chat_templates * tmpls,
- const std::vector<common_chat_msg> & past_msg,
- const common_chat_msg & new_msg,
- bool add_ass,
- bool use_jinja);
- // Returns an example of formatted chat
- std::string common_chat_format_example(
- const struct common_chat_templates * tmpls,
- bool use_jinja);
- std::string common_chat_format_name(common_chat_format format);
- common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
- common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
- // Parses a JSON array of messages in OpenAI's chat completion API format.
- // T can be std::string containing JSON or nlohmann::ordered_json
- template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
- template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
- // Parses a JSON array of tools in OpenAI's chat completion tool call API format.
- // T can be std::string containing JSON or nlohmann::ordered_json
- template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
- template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
- template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|