@@ -1,8 +1,433 @@
-#include "chat.hpp"
-#include "chat-template.hpp"
+#include "chat.h"
#include "json-schema-to-grammar.h"
#include "log.h"
-#include "minja.hpp"
+#include "minja/chat-template.hpp"
+#include "minja/minja.hpp"
+
+#include <optional>
+
+typedef minja::chat_template common_chat_template;
+
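+// Templates attached to a model: the default template plus an optional dedicated "tool_use" variant.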
+struct common_chat_templates {
+    bool has_explicit_template; // Model had a built-in template, or a template override was specified.
+    std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
+    std::unique_ptr<common_chat_template> template_tool_use;
+};
+
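+// Inputs for the Jinja path, with messages and tools already normalized to OAI-compatible JSON.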
+struct templates_params {
+    json messages;
+    json tools;
+    common_chat_tool_choice tool_choice;
+    json json_schema;
+    bool parallel_tool_calls;
+    bool stream;
+    std::string grammar;
+    bool add_generation_prompt = true;
+    bool extract_reasoning = true;
+};
+
+common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
+    if (tool_choice == "auto") {
+        return COMMON_CHAT_TOOL_CHOICE_AUTO;
+    }
+    if (tool_choice == "none") {
+        return COMMON_CHAT_TOOL_CHOICE_NONE;
+    }
+    if (tool_choice == "required") {
+        return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+    }
+    throw std::runtime_error("Invalid tool_choice: " + tool_choice);
+}
+
+template <>
+std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
+    std::vector<common_chat_msg> msgs;
+
+    try {
+
+        if (!messages.is_array()) {
+            throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump());
+        }
+
+        for (const auto & message : messages) {
+            if (!message.is_object()) {
+                throw std::runtime_error("Expected 'message' to be an object, got " + message.dump());
+            }
+
+            common_chat_msg msg;
+            if (!message.contains("role")) {
+                throw std::runtime_error("Missing 'role' in message: " + message.dump());
+            }
+            msg.role = message.at("role");
+
+            if (message.contains("content")) {
+                const auto & content = message.at("content");
+                if (content.is_string()) {
+                    msg.content = content;
+                } else if (content.is_array()) {
+                    for (const auto & part : content) {
+                        if (!part.contains("type")) {
+                            throw std::runtime_error("Missing content part type: " + part.dump());
+                        }
+                        const auto & type = part.at("type");
+                        if (type != "text") {
+                            throw std::runtime_error("Unsupported content part type: " + type.dump());
+                        }
+                        common_chat_msg_content_part msg_part;
+                        msg_part.type = type;
+                        msg_part.text = part.at("text");
+                        msg.content_parts.push_back(msg_part);
+                    }
+                } else if (!content.is_null()) {
+                    throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
+                }
+            } else {
+                throw std::runtime_error("Expected 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
+            }
+            if (message.contains("reasoning_content")) {
+                msg.reasoning_content = message.at("reasoning_content");
+            }
+            if (message.contains("name")) {
+                msg.tool_name = message.at("name");
+            }
+            if (message.contains("tool_call_id")) {
+                msg.tool_call_id = message.at("tool_call_id");
+            }
+            if (message.contains("tool_calls")) {
+                for (const auto & tool_call : message.at("tool_calls")) {
+                    common_chat_tool_call tc;
+                    if (!tool_call.contains("type")) {
+                        throw std::runtime_error("Missing tool call type: " + tool_call.dump());
+                    }
+                    const auto & type = tool_call.at("type");
+                    if (type != "function") {
+                        throw std::runtime_error("Unsupported tool call type: " + tool_call.dump());
+                    }
+                    if (!tool_call.contains("function")) {
+                        throw std::runtime_error("Missing tool call function: " + tool_call.dump());
+                    }
+                    const auto & fc = tool_call.at("function");
+                    if (!fc.contains("name")) {
+                        throw std::runtime_error("Missing tool call name: " + tool_call.dump());
+                    }
+                    tc.name = fc.at("name");
+                    tc.arguments = fc.at("arguments");
+                    if (tool_call.contains("id")) {
+                        tc.id = tool_call.at("id");
+                    }
+                    msg.tool_calls.push_back(tc);
+                }
+            }
+
+            msgs.push_back(msg);
+        }
+    } catch (const std::exception & e) {
+        throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2));
+    }
+
+    return msgs;
+}
+
+template <>
+json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
+    json messages = json::array();
+    for (const auto & msg : msgs) {
+        if (!msg.content.empty() && !msg.content_parts.empty()) {
+            throw std::runtime_error("Cannot specify both content and content_parts");
+        }
+        json jmsg {
+            {"role", msg.role},
+        };
+        if (!msg.content.empty()) {
+            jmsg["content"] = msg.content;
+        } else if (!msg.content_parts.empty()) {
+            if (concat_typed_text) {
+                std::string text;
+                for (const auto & part : msg.content_parts) {
+                    if (part.type != "text") {
+                        LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
+                        continue;
+                    }
+                    if (!text.empty()) {
+                        text += '\n';
+                    }
+                    text += part.text;
+                }
+                jmsg["content"] = text;
+            } else {
+                auto & parts = jmsg["content"] = json::array();
+                for (const auto & part : msg.content_parts) {
+                    parts.push_back({
+                        {"type", part.type},
+                        {"text", part.text},
+                    });
+                }
+            }
+        } else {
+            jmsg["content"] = json(); // null
+        }
+        if (!msg.reasoning_content.empty()) {
+            jmsg["reasoning_content"] = msg.reasoning_content;
+        }
+        if (!msg.tool_name.empty()) {
+            jmsg["name"] = msg.tool_name;
+        }
+        if (!msg.tool_call_id.empty()) {
+            jmsg["tool_call_id"] = msg.tool_call_id;
+        }
+        if (!msg.tool_calls.empty()) {
+            auto & tool_calls = jmsg["tool_calls"] = json::array();
+            for (const auto & tool_call : msg.tool_calls) {
+                json tc {
+                    {"type", "function"},
+                    {"function", {
+                        {"name", tool_call.name},
+                        {"arguments", tool_call.arguments},
+                    }},
+                };
+                if (!tool_call.id.empty()) {
+                    tc["id"] = tool_call.id;
+                }
+                tool_calls.push_back(tc);
+            }
+        }
+        messages.push_back(jmsg);
+    }
+    return messages;
+}
+
+template <>
+std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
+    return common_chat_msgs_parse_oaicompat(json::parse(messages));
+}
+
+template <>
+std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
+    std::vector<common_chat_tool> result;
+
+    try {
+        if (!tools.is_null()) {
+            if (!tools.is_array()) {
+                throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump());
+            }
+            for (const auto & tool : tools) {
+                if (!tool.contains("type")) {
+                    throw std::runtime_error("Missing tool type: " + tool.dump());
+                }
+                const auto & type = tool.at("type");
+                if (!type.is_string() || type != "function") {
+                    throw std::runtime_error("Unsupported tool type: " + tool.dump());
+                }
+                if (!tool.contains("function")) {
+                    throw std::runtime_error("Missing tool function: " + tool.dump());
+                }
+
+                const auto & function = tool.at("function");
+                result.push_back({
+                    /* .name = */ function.at("name"),
+                    /* .description = */ function.at("description"),
+                    /* .parameters = */ function.at("parameters").dump(),
+                });
+            }
+        }
+    } catch (const std::exception & e) {
+        throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2));
+    }
+
+    return result;
+}
+
+template <>
+std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
+    return common_chat_tools_parse_oaicompat(json::parse(tools));
+}
+
+template <>
+json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
+    if (tools.empty()) {
+        return json();
+    }
+
+    auto result = json::array();
+    for (const auto & tool : tools) {
+        result.push_back({
+            {"type", "function"},
+            {"function", {
+                {"name", tool.name},
+                {"description", tool.description},
+                {"parameters", json::parse(tool.parameters)},
+            }},
+        });
+    }
+    return result;
+}
+
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
+    if (use_jinja) {
+        try {
+            common_chat_msg msg;
+            msg.role = "user";
+            msg.content = "test";
+
+            auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl);
+
+            common_chat_templates_inputs inputs;
+            inputs.messages = {msg};
+
+            common_chat_templates_apply(tmpls.get(), inputs);
+            return true;
+        } catch (const std::exception & e) {
+            LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
+            return false;
+        }
+    }
+    llama_chat_message chat[] = {{"user", "test"}};
+    const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
+    return res >= 0;
+}
+
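+// Render only the delta contributed by new_msg: format the conversation without it, then with it, and return the suffix.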
+std::string common_chat_format_single(
+        const struct common_chat_templates * tmpls,
+        const std::vector<common_chat_msg> & past_msg,
+        const common_chat_msg & new_msg,
+        bool add_ass,
+        bool use_jinja) {
+
+    common_chat_templates_inputs inputs;
+    inputs.use_jinja = use_jinja;
+
+    std::string fmt_past_msg;
+    if (!past_msg.empty()) {
+        inputs.messages = past_msg;
+        inputs.add_generation_prompt = false;
+        fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt;
+    }
+    std::ostringstream ss;
+    // if the past_msg ends with a newline, we must preserve it in the formatted version
+    if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
+        ss << "\n";
+    }
+    // format chat with new_msg
+    inputs.messages.push_back(new_msg);
+    inputs.add_generation_prompt = add_ass;
+    auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt;
+    // get the diff part
+    ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
+    return ss.str();
+}
+
+std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
+    common_chat_templates_inputs inputs;
+    inputs.use_jinja = use_jinja;
+    auto add_simple_msg = [&](auto role, auto content) {
+        common_chat_msg msg;
+        msg.role = role;
+        msg.content = content;
+        inputs.messages.push_back(msg);
+    };
+    add_simple_msg("system", "You are a helpful assistant");
+    add_simple_msg("user", "Hello");
+    add_simple_msg("assistant", "Hi there");
+    add_simple_msg("user", "How are you?");
+    return common_chat_templates_apply(tmpls, inputs).prompt;
+}
+
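+// Minimal ChatML fallback template, used when the model ships no template or its template fails to parse.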
+#define CHATML_TEMPLATE_SRC \
+    "{%- for message in messages -%}\n" \
+    "  {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
+    "{%- endfor -%}\n" \
+    "{%- if add_generation_prompt -%}\n" \
+    "  {{- '<|im_start|>assistant\n' -}}\n" \
+    "{%- endif -%}"
+
+void common_chat_templates_free(struct common_chat_templates * tmpls) {
+    delete tmpls;
+}
+
+bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) {
+    return tmpls->has_explicit_template;
+}
+
+const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
+    if (variant != nullptr) {
+        if (strcmp(variant, "tool_use") == 0) {
+            if (tmpls->template_tool_use) {
+                return tmpls->template_tool_use->source().c_str();
+            }
+            return nullptr;
+        } else {
+            LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
+        }
+    }
+    return tmpls->template_default->source().c_str();
+}
+
+common_chat_templates_ptr common_chat_templates_init(
+    const struct llama_model * model,
+    const std::string & chat_template_override,
+    const std::string & bos_token_override,
+    const std::string & eos_token_override)
+{
+    std::string default_template_src;
+    std::string template_tool_use_src;
+
+    bool has_explicit_template = !chat_template_override.empty();
+    if (chat_template_override.empty()) {
+        GGML_ASSERT(model != nullptr);
+        const auto * str = llama_model_chat_template(model, /* name */ nullptr);
+        if (str) {
+            default_template_src = str;
+            has_explicit_template = true;
+        }
+        str = llama_model_chat_template(model, /* name */ "tool_use");
+        if (str) {
+            template_tool_use_src = str;
+            has_explicit_template = true;
+        }
+    } else {
+        default_template_src = chat_template_override;
+    }
+    if (default_template_src.empty() || default_template_src == "chatml") {
+        if (!template_tool_use_src.empty()) {
+            default_template_src = template_tool_use_src;
+        } else {
+            default_template_src = CHATML_TEMPLATE_SRC;
+        }
+ }
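+    // Resolve BOS/EOS token strings from the model vocab unless explicit overrides were provided.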
+    std::string token_bos = bos_token_override;
+    std::string token_eos = eos_token_override;
+    if (model) {
+        const auto * vocab = llama_model_get_vocab(model);
+        const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
+            if (token == LLAMA_TOKEN_NULL) {
+                if (default_template_src.find(jinja_variable_name) != std::string::npos
+                    || template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
+                    LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name);
+                }
+                return std::string();
+            }
+            return common_token_to_piece(vocab, token, true);
+        };
+        token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
+        token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
+    }
+    common_chat_templates_ptr tmpls(new common_chat_templates());
+    tmpls->has_explicit_template = has_explicit_template;
+    try {
+        tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
+    } catch (const std::exception & e) {
+        LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s\n", __func__, e.what());
+        tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
+    }
+    if (!template_tool_use_src.empty()) {
+        try {
+            tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
+        } catch (const std::exception & e) {
+            LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
+        }
+    }
+    return tmpls;
+}

std::string common_chat_format_name(common_chat_format format) {
switch (format) {
@@ -38,22 +463,22 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons

json_error_locator() : position(0), found_error(false) {}

-        bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
+        bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT
this->position = position - 1;
this->found_error = true;
return false;
}
-        bool null() override { return true; }
-        bool boolean(bool) override { return true; }
-        bool number_integer(number_integer_t) override { return true; }
-        bool number_unsigned(number_unsigned_t) override { return true; }
-        bool number_float(number_float_t, const string_t &) override { return true; }
-        bool string(string_t &) override { return true; }
-        bool binary(binary_t &) override { return true; }
-        bool start_object(std::size_t) override { return true; }
-        bool key(string_t &) override { return true; }
+        bool null() override { return true; } // NOLINT
+        bool boolean(bool) override { return true; } // NOLINT
+        bool number_integer(number_integer_t) override { return true; } // NOLINT
+        bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT
+        bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT
+        bool string(string_t &) override { return true; } // NOLINT
+        bool binary(binary_t &) override { return true; } // NOLINT
+        bool start_object(std::size_t) override { return true; } // NOLINT
+        bool key(string_t &) override { return true; } // NOLINT
bool end_object() override { return true; }
-        bool start_array(std::size_t) override { return true; }
+        bool start_array(std::size_t) override { return true; } // NOLINT
bool end_array() override { return true; }
};
json_error_locator err_loc;
@@ -187,13 +612,20 @@ static std::string apply(
// tmpl_inputs.now = std::chrono::system_clock::now();

minja::chat_template_options tmpl_opts;
-    tmpl_opts.use_bos_token = false;
-    tmpl_opts.use_eos_token = false;
-
-    return tmpl.apply(tmpl_inputs, tmpl_opts);
+    // To avoid double BOS / EOS tokens, we're manually removing leading / trailing tokens
+    // instead of using `chat_template_options.use_bos_token = false`, since these tokens
+    // may be needed inside the template / between messages too.
+    auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
+    if (string_starts_with(result, tmpl.bos_token())) {
+        result = result.substr(tmpl.bos_token().size());
+    }
+    if (string_ends_with(result, tmpl.eos_token())) {
+        result = result.substr(0, result.size() - tmpl.eos_token().size());
+    }
+    return result;
}

-static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;

auto tool_call_schemas = json::array();
@@ -247,7 +679,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
{"required", json::array({"tool_call"})},
};
const auto schema =
-        inputs.tool_choice != "required"
+        inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED
? json {
{"anyOf", json::array({
tool_call,
@@ -303,9 +735,9 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) {
return result;
}

-static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
-    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
@@ -348,9 +780,9 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
}

-static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
-    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
@@ -455,10 +887,10 @@ static void expect_tool_parameters(const std::string & name, const json & parame
const auto & parameters_required = parameters.at("required");
for (const auto & prop : expected_properties) {
if (!parameters_properties.contains(prop)) {
-            throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop);
+            throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT
}
if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
-            throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop);
+            throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT
}
}
if (parameters_properties.size() != expected_properties.size()) {
@@ -466,18 +898,16 @@ static void expect_tool_parameters(const std::string & name, const json & parame
}
}

-static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) {
+static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
auto builtin_tools = json::array();
common_chat_params data;
-    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;

auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
-        if (name == "wolfram_alpha") {
+        if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
-            expect_tool_parameters(name, parameters, {"query"});
-        } else if (name == "web_search" || name == "brave_search") {
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
expect_tool_parameters(name, parameters, {"query"});
} else if (name == "python" || name == "code_interpreter") {
@@ -489,7 +919,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com

std::vector<std::string> kvs;
for (const auto & [key, value] : parameters.at("properties").items()) {
-        kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value));
+        kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
}

tool_rules.push_back(
@@ -560,34 +990,33 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo
auto arg_value_str = raw_args.substr(it_eq + 1);
auto arg_value = json::parse(arg_value_str);

-            return {
-                /* .role = */ "assistant",
-                /* .content = */ match.prefix().str(),
-                /* .tool_calls = */ {
-                    {
-                        /* .name = */ match[1],
-                        /* .arguments = */ (json {
-                            {arg_name, arg_value},
-                        }).dump(),
-                        /* .id = */ "",
-                    },
-                },
-            };
+            common_chat_msg msg;
+            msg.role = "assistant";
+            msg.content = match.prefix().str();
+            msg.tool_calls.push_back({
+                /* .name = */ name,
+                /* .arguments = */ (json {
+                    {arg_name, arg_value},
+                }).dump(),
+                /* .id = */ "",
+            });
+            return msg;
}
}
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
}

-static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
if (inputs.tools.is_array() && !inputs.tools.empty()) {
-        data.grammar_lazy = inputs.tool_choice != "required" && inputs.json_schema.is_null();
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
+                builder.resolve_refs(parameters);
auto args_rule = builder.add_schema(name + "-args", parameters);
tool_rules.push_back(builder.add_rule(name + "-call",
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
@@ -666,15 +1095,15 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input,
return msg;
}

-static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
-    fprintf(stderr, "%s\n", __func__);
+static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    LOG_DBG("%s\n", __func__);
common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
{"datetime", "Jan 29 2025 13:00:00 GMT"},
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
});
if (inputs.tools.is_array() && !inputs.tools.empty()) {
-        data.grammar_lazy = inputs.tool_choice != "required";
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
@@ -712,14 +1141,14 @@ static common_chat_msg common_chat_parse_firefunction_v2(const std::string & inp
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
}

-static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) {
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
if (inputs.tools.is_array() && !inputs.tools.empty()) {
-        data.grammar_lazy = inputs.tool_choice != "required";
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> first_tool_rules;
std::vector<std::string> subsequent_tool_rules;
@@ -727,6 +1156,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
+            builder.resolve_refs(parameters);
auto args_rule = builder.add_schema(name + "-args", parameters);
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
@@ -795,14 +1225,14 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
}
}

-static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
common_chat_params data;
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
std::string python_code_argument_name;
auto has_raw_python = false;

-    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
@@ -814,7 +1244,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
throw std::runtime_error("Missing type in python tool");
}
has_raw_python = true;
-                auto type = parameters.at("type");
+                const auto & type = parameters.at("type");
if (type == "object") {
auto properties = parameters.at("properties");
for (auto it = properties.begin(); it != properties.end(); ++it) {
@@ -854,17 +1284,15 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
std::smatch match;
if (std::regex_search(input, match, python_tag_regex)) {
auto code = match[1].str();
-        return {
-            /* .role = */ "assistant",
-            /* .content = */ match.prefix().str(),
-            /* .tool_calls = */ {
-                {
-                    /* .name = */ "python",
-                    /* .arguments = */ (json {{"code", code}}).dump(),
-                    /* .id = */ "",
-                },
-            }
-        };
+        common_chat_msg msg;
+        msg.role = "assistant";
+        msg.content = match.prefix().str();
+        msg.tool_calls.push_back({
+            /* .name = */ "python",
+            /* .arguments = */ (json {{"code", code}}).dump(),
+            /* .id = */ "",
+        });
+        return msg;
}
static std::regex function_regex(R"(<function=(\w+)>)");
static std::regex close_regex(R"(</function>)");
@@ -872,10 +1300,10 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
}

-static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
-    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
@@ -908,20 +1336,18 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");

+    common_chat_msg msg;
+    msg.role = "assistant";
+
auto end = input.end();
std::sregex_iterator rend;
std::sregex_iterator rit(input.begin(), end, start_pattern);
if (rit == rend) {
-        return {
-            /* .role = */ "assistant",
-            /* .content = */ input,
-            /* .tool_calls = */ {},
-        };
+        msg.content = input;
+        return msg;
}

-    common_chat_msg result;
-    result.role = "assistant";
-    result.content = rit->prefix();
+    msg.content = rit->prefix();

auto it = rit->suffix().first;
while (it != end) {
@@ -930,7 +1356,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
throw std::runtime_error("Failed to parse json tool call");
}
const auto & arguments = call.at("arguments");
-            result.tool_calls.push_back({
+            msg.tool_calls.push_back({
call.at("name"),
arguments.dump(),
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
@@ -947,17 +1373,17 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
break;
}
}
-        return result;
+        return msg;
} catch (const std::exception & e) {
-        return {
-            /* .role = */ "assistant",
-            /* .content = */ input,
-            /* .tool_calls = */ {},
-        };
+        LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
+        common_chat_msg msg;
+        msg.role = "assistant";
+        msg.content = input;
+        return msg;
}
}

-static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
@@ -973,12 +1399,35 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
return data;
}

-common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_templates_apply_jinja(
+    const struct common_chat_templates * tmpls,
+    const struct common_chat_templates_inputs & inputs)
+{
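+    // Convert the high-level inputs to JSON once, preferring the dedicated tool-use template when tools are present.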
+    templates_params params;
+    params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
+    const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
+        ? *tmpls->template_tool_use
+        : *tmpls->template_default;
const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps();
+    params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_typed_text= */ !tmpl.original_caps().requires_typed_content);
+    params.add_generation_prompt = inputs.add_generation_prompt;
+    params.extract_reasoning = inputs.extract_reasoning;
+    params.tool_choice = inputs.tool_choice;
+    params.grammar = inputs.grammar;
+    if (!inputs.json_schema.empty()) {
+        params.json_schema = json::parse(inputs.json_schema);
+    }

-    if (inputs.tools.is_array()) {
-        if (inputs.tool_choice != "none" && !inputs.grammar.empty()) {
+    if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
+        LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
+        params.parallel_tool_calls = false;
+    } else {
+        params.parallel_tool_calls = inputs.parallel_tool_calls;
+    }
+
+    if (params.tools.is_array()) {
+        if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
throw std::runtime_error("Cannot specify grammar with tools");
}
if (caps.supports_tool_calls && !caps.supports_tools) {
@@ -987,68 +1436,135 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
}

// DeepSeek R1: use handler in all cases except json schema (thinking / tools).
-    if (src.find("<|tool▁calls▁begin|>") != std::string::npos && inputs.json_schema.is_null()) {
-        return common_chat_params_init_deepseek_r1(tmpl, inputs);
+    if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) {
+        return common_chat_params_init_deepseek_r1(tmpl, params);
}

// Command R7B: use handler in all cases except json schema (thinking / tools).
-    if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) {
-        return common_chat_params_init_command_r7b(tmpl, inputs);
+    if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) {
+        return common_chat_params_init_command_r7b(tmpl, params);
}

// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
-    if ((!inputs.tools.is_array() && inputs.json_schema.is_object())) {
-        return common_chat_params_init_generic(tmpl, inputs);
+    if ((params.tools.is_array() && params.json_schema.is_object())) {
+        return common_chat_params_init_generic(tmpl, params);
}

// Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
if (src.find(">>>all") != std::string::npos) {
-        return common_chat_params_init_functionary_v3_2(tmpl, inputs);
+        return common_chat_params_init_functionary_v3_2(tmpl, params);
}

// Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
if (src.find(" functools[") != std::string::npos) {
-        return common_chat_params_init_firefunction_v2(tmpl, inputs);
+        return common_chat_params_init_firefunction_v2(tmpl, params);
}

// Plain handler (no tools)
-    if (inputs.tools.is_null() || inputs.tool_choice == "none") {
-        return common_chat_params_init_without_tools(tmpl, inputs);
+    if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
+        return common_chat_params_init_without_tools(tmpl, params);
}

// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
if (src.find("<tool_call>") != std::string::npos) {
-        return common_chat_params_init_hermes_2_pro(tmpl, inputs);
+        return common_chat_params_init_hermes_2_pro(tmpl, params);
}

// Functionary v3.1 (w/ tools)
if (src.find("<|start_header_id|>") != std::string::npos
&& src.find("<function=") != std::string::npos) {
-        return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs);
+        return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
}

// Llama 3.1, 3.2, 3.3 (w/ tools)
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
-        return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
+        return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
}

// Mistral Nemo (w/ tools)
if (src.find("[TOOL_CALLS]") != std::string::npos) {
-        return common_chat_params_init_mistral_nemo(tmpl, inputs);
+        return common_chat_params_init_mistral_nemo(tmpl, params);
}

// Generic fallback
-    return common_chat_params_init_generic(tmpl, inputs);
+    return common_chat_params_init_generic(tmpl, params);
+}
+
+// Legacy template route (ad hoc C++ implementation of known templates): forward to llama_chat_apply_template.
+static common_chat_params common_chat_templates_apply_legacy(
+    const struct common_chat_templates * tmpls,
+    const struct common_chat_templates_inputs & inputs)
+{
+    int alloc_size = 0;
+    std::vector<llama_chat_message> chat;
+    std::vector<std::string> contents;
+    for (const auto & msg : inputs.messages) {
+        auto content = msg.content;
+        for (const auto & part : msg.content_parts) {
+            if (part.type != "text") {
+                LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
+                continue;
+            }
+            if (!content.empty()) {
+                content += "\n";
+            }
+            content += part.text;
+        }
+        contents.emplace_back(std::move(content));
+    }
+    for (size_t i = 0; i < contents.size(); ++i) {
+        const auto & msg = inputs.messages[i];
+        const auto & content = contents[i];
+ chat.push_back({msg.role.c_str(), content.c_str()});
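+        // rough upper bound on the rendered size; the first llama_chat_apply_template call below reports the exact length if this is too small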
+        alloc_size += (msg.role.size() + content.size()) * 1.25;
+    }
+
+    std::vector<char> buf(alloc_size);
+
+    // run the first time to get the total output length
+    const auto & src = tmpls->template_default->source();
+    int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
+
+    // error: chat template is not supported
+    if (res < 0) {
+        // if the custom "tmpl" is not supported, we throw an error
+        // this is a bit redundant (for good measure), since we're not sure whether the user validated the custom template with llama_chat_verify_template()
+        throw std::runtime_error("this custom template is not supported");
+    }
+
+    // if it turns out that our buffer is too small, we resize it
+    if ((size_t) res > buf.size()) {
+        buf.resize(res);
+        res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
+    }
+
+    common_chat_params params;
+    params.prompt = std::string(buf.data(), res);
+    if (!inputs.json_schema.empty()) {
+        params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema));
+    } else {
+        params.grammar = inputs.grammar;
+    }
+    return params;
+}
+
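+// Public entry point: route to the Jinja renderer or to the legacy built-in formatter.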
+common_chat_params common_chat_templates_apply(
+    const struct common_chat_templates * tmpls,
+    const struct common_chat_templates_inputs & inputs)
+{
+    GGML_ASSERT(tmpls != nullptr);
+    return inputs.use_jinja
+        ? common_chat_templates_apply_jinja(tmpls, inputs)
+        : common_chat_templates_apply_legacy(tmpls, inputs);
}

static common_chat_msg common_chat_parse_content_only(const std::string & input) {
-    return {
-        /* .role = */ "assistant",
-        /* .content = */ input,
-        /* .tool_calls = */ {},
-    };
+    common_chat_msg msg;
+    msg.role = "assistant";
+    msg.content = input;
+    return msg;
}

common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {
|