Kaynağa Gözat

tool-call: refactor common chat / tool-call api (+ tests / fixes) (#11900)

* tool-call refactoring: moved common_chat_* to chat.h, common_chat_templates_init return a unique_ptr to opaque type

* addressed clang-tidy lints in [test-]chat.*

* rm minja deps from util & common & move it to common/minja/

* add name & tool_call_id to common_chat_msg

* add common_chat_tool

* added json <-> tools, msgs conversions to chat.h

* fix double bos/eos jinja avoidance hack (was preventing inner bos/eos tokens)

* fix deepseek r1 slow test (no longer <think> opening w/ new template)

* allow empty tools w/ auto + grammar

* fix & test server grammar & json_schema params w/ & w/o --jinja
Olivier Chafik 11 ay önce
ebeveyn
işleme
63e489c025

+ 1 - 1
Makefile

@@ -1364,7 +1364,7 @@ llama-server: \
 	examples/server/index.html.hpp \
 	examples/server/loading.html.hpp \
 	common/chat.cpp \
-	common/chat.hpp \
+	common/chat.h \
 	common/chat-template.hpp \
 	common/json.hpp \
 	common/minja.hpp \

+ 3 - 3
common/CMakeLists.txt

@@ -57,8 +57,7 @@ add_library(${TARGET} STATIC
     arg.h
     base64.hpp
     chat.cpp
-    chat.hpp
-    chat-template.hpp
+    chat.h
     common.cpp
     common.h
     console.cpp
@@ -68,7 +67,8 @@ add_library(${TARGET} STATIC
     llguidance.cpp
     log.cpp
     log.h
-    minja.hpp
+    minja/chat-template.hpp
+    minja/minja.hpp
     ngram-cache.cpp
     ngram-cache.h
     sampling.cpp

+ 1 - 0
common/arg.cpp

@@ -2,6 +2,7 @@
 
 #include "log.h"
 #include "sampling.h"
+#include "chat.h"
 
 #include <algorithm>
 #include <climits>

+ 623 - 107
common/chat.cpp

@@ -1,8 +1,433 @@
-#include "chat.hpp"
-#include "chat-template.hpp"
+#include "chat.h"
 #include "json-schema-to-grammar.h"
 #include "log.h"
-#include "minja.hpp"
+#include "minja/chat-template.hpp"
+#include "minja/minja.hpp"
+
+#include <optional>
+
+typedef minja::chat_template common_chat_template;
+
+struct common_chat_templates {
+    bool has_explicit_template; // Model had builtin template or template overridde was specified.
+    std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
+    std::unique_ptr<common_chat_template> template_tool_use;
+};
+
+struct templates_params {
+    json messages;
+    json tools;
+    common_chat_tool_choice tool_choice;
+    json json_schema;
+    bool parallel_tool_calls;
+    bool stream;
+    std::string grammar;
+    bool add_generation_prompt = true;
+    bool extract_reasoning     = true;
+};
+
+common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
+    if (tool_choice == "auto") {
+        return COMMON_CHAT_TOOL_CHOICE_AUTO;
+    }
+    if (tool_choice == "none") {
+        return COMMON_CHAT_TOOL_CHOICE_NONE;
+    }
+    if (tool_choice == "required") {
+        return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+    }
+    throw std::runtime_error("Invalid tool_choice: " + tool_choice);
+}
+
+template <>
+std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
+    std::vector<common_chat_msg> msgs;
+
+    try {
+
+        if (!messages.is_array()) {
+            throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump());
+        }
+
+        for (const auto & message : messages) {
+            if (!message.is_object()) {
+                throw std::runtime_error("Expected 'message' to be an object, got " + message.dump());
+            }
+
+            common_chat_msg msg;
+            if (!message.contains("role")) {
+                throw std::runtime_error("Missing 'role' in message: " + message.dump());
+            }
+            msg.role = message.at("role");
+
+            if (message.contains("content")) {
+                const auto & content = message.at("content");
+                if (content.is_string()) {
+                    msg.content = content;
+                } else if (content.is_array()) {
+                    for (const auto & part : content) {
+                        if (!part.contains("type")) {
+                            throw std::runtime_error("Missing content part type: " + part.dump());
+                        }
+                        const auto & type = part.at("type");
+                        if (type != "text") {
+                            throw std::runtime_error("Unsupported content part type: " + type.dump());
+                        }
+                        common_chat_msg_content_part msg_part;
+                        msg_part.type = type;
+                        msg_part.text = part.at("text");
+                        msg.content_parts.push_back(msg_part);
+                    }
+                } else if (!content.is_null()) {
+                    throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
+                }
+            } else {
+                throw std::runtime_error("Expected 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
+            }
+            if (message.contains("reasoning_content")) {
+                msg.reasoning_content = message.at("reasoning_content");
+            }
+            if (message.contains("name")) {
+                msg.tool_name = message.at("name");
+            }
+            if (message.contains("tool_call_id")) {
+                msg.tool_call_id = message.at("tool_call_id");
+            }
+            if (message.contains("tool_calls")) {
+                for (const auto & tool_call : message.at("tool_calls")) {
+                    common_chat_tool_call tc;
+                    if (!tool_call.contains("type")) {
+                        throw std::runtime_error("Missing tool call type: " + tool_call.dump());
+                    }
+                    const auto & type = tool_call.at("type");
+                    if (type != "function") {
+                        throw std::runtime_error("Unsupported tool call type: " + tool_call.dump());
+                    }
+                    if (!tool_call.contains("function")) {
+                        throw std::runtime_error("Missing tool call function: " + tool_call.dump());
+                    }
+                    const auto & fc = tool_call.at("function");
+                    if (!fc.contains("name")) {
+                        throw std::runtime_error("Missing tool call name: " + tool_call.dump());
+                    }
+                    tc.name = fc.at("name");
+                    tc.arguments = fc.at("arguments");
+                    if (tool_call.contains("id")) {
+                        tc.id = tool_call.at("id");
+                    }
+                    msg.tool_calls.push_back(tc);
+                }
+            }
+
+            msgs.push_back(msg);
+        }
+    } catch (const std::exception & e) {
+        throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2));
+    }
+
+    return msgs;
+}
+
+template <>
+json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
+    json messages = json::array();
+    for (const auto & msg : msgs) {
+        if (!msg.content.empty() && !msg.content_parts.empty()) {
+            throw std::runtime_error("Cannot specify both content and content_parts");
+        }
+        json jmsg {
+            {"role", msg.role},
+        };
+        if (!msg.content.empty()) {
+            jmsg["content"] = msg.content;
+        } else if (!msg.content_parts.empty()) {
+            if (concat_typed_text) {
+                std::string text;
+                for (const auto & part : msg.content_parts) {
+                    if (part.type != "text") {
+                        LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
+                        continue;
+                    }
+                    if (!text.empty()) {
+                        text += '\n';
+                    }
+                    text += part.text;
+                }
+                jmsg["content"] = text;
+            } else {
+                auto & parts = jmsg["content"] = json::array();
+                for (const auto & part : msg.content_parts) {
+                    parts.push_back({
+                        {"type", part.type},
+                        {"text", part.text},
+                    });
+                }
+            }
+        } else {
+            jmsg["content"] = json(); // null
+        }
+        if (!msg.reasoning_content.empty()) {
+            jmsg["reasoning_content"] = msg.reasoning_content;
+        }
+        if (!msg.tool_name.empty()) {
+            jmsg["name"] = msg.tool_name;
+        }
+        if (!msg.tool_call_id.empty()) {
+            jmsg["tool_call_id"] = msg.tool_call_id;
+        }
+        if (!msg.tool_calls.empty()) {
+            auto & tool_calls = jmsg["tool_calls"] = json::array();
+            for (const auto & tool_call : msg.tool_calls) {
+                json tc {
+                    {"type", "function"},
+                    {"function", {
+                        {"name", tool_call.name},
+                        {"arguments", tool_call.arguments},
+                    }},
+                };
+                if (!tool_call.id.empty()) {
+                    tc["id"] = tool_call.id;
+                }
+                tool_calls.push_back(tc);
+            }
+        }
+        messages.push_back(jmsg);
+    }
+    return messages;
+}
+
+template <>
+std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
+    return common_chat_msgs_parse_oaicompat(json::parse(messages));
+}
+
+template <>
+std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
+    std::vector<common_chat_tool> result;
+
+    try {
+        if (!tools.is_null()) {
+            if (!tools.is_array()) {
+                throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump());
+            }
+            for (const auto & tool : tools) {
+                if (!tool.contains("type")) {
+                    throw std::runtime_error("Missing tool type: " + tool.dump());
+                }
+                const auto & type = tool.at("type");
+                if (!type.is_string() || type != "function") {
+                    throw std::runtime_error("Unsupported tool type: " + tool.dump());
+                }
+                if (!tool.contains("function")) {
+                    throw std::runtime_error("Missing tool function: " + tool.dump());
+                }
+
+                const auto & function = tool.at("function");
+                result.push_back({
+                    /* .name = */ function.at("name"),
+                    /* .description = */ function.at("description"),
+                    /* .parameters = */ function.at("parameters").dump(),
+                });
+            }
+        }
+    } catch (const std::exception & e) {
+        throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2));
+    }
+
+    return result;
+}
+
+template <>
+std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
+    return common_chat_tools_parse_oaicompat(json::parse(tools));
+}
+
+template <>
+json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
+    if (tools.empty()) {
+        return json();
+    }
+
+    auto result = json::array();
+    for (const auto & tool : tools) {
+        result.push_back({
+            {"type", "function"},
+            {"function", {
+                {"name", tool.name},
+                {"description", tool.description},
+                {"parameters", json::parse(tool.parameters)},
+            }},
+        });
+    }
+    return result;
+}
+
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
+    if (use_jinja) {
+        try {
+            common_chat_msg msg;
+            msg.role = "user";
+            msg.content = "test";
+
+            auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl);
+
+            common_chat_templates_inputs inputs;
+            inputs.messages = {msg};
+
+            common_chat_templates_apply(tmpls.get(), inputs);
+            return true;
+        } catch (const std::exception & e) {
+            LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
+            return false;
+        }
+    }
+    llama_chat_message chat[] = {{"user", "test"}};
+    const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
+    return res >= 0;
+}
+
+std::string common_chat_format_single(
+        const struct common_chat_templates * tmpls,
+        const std::vector<common_chat_msg> & past_msg,
+        const common_chat_msg & new_msg,
+        bool add_ass,
+        bool use_jinja) {
+
+    common_chat_templates_inputs inputs;
+    inputs.use_jinja = use_jinja;
+
+    std::string fmt_past_msg;
+    if (!past_msg.empty()) {
+        inputs.messages = past_msg;
+        inputs.add_generation_prompt = false;
+        fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt;
+    }
+    std::ostringstream ss;
+    // if the past_msg ends with a newline, we must preserve it in the formatted version
+    if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
+        ss << "\n";
+    };
+    // format chat with new_msg
+    inputs.messages.push_back(new_msg);
+    inputs.add_generation_prompt = add_ass;
+    auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt;
+    // get the diff part
+    ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
+    return ss.str();
+}
+
+std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
+    common_chat_templates_inputs inputs;
+    inputs.use_jinja = use_jinja;
+    auto add_simple_msg = [&](auto role, auto content) {
+        common_chat_msg msg;
+        msg.role = role;
+        msg.content = content;
+        inputs.messages.push_back(msg);
+    };
+    add_simple_msg("system",    "You are a helpful assistant");
+    add_simple_msg("user",      "Hello");
+    add_simple_msg("assistant", "Hi there");
+    add_simple_msg("user",      "How are you?");
+    return common_chat_templates_apply(tmpls, inputs).prompt;
+}
+
+#define CHATML_TEMPLATE_SRC \
+    "{%- for message in messages -%}\n" \
+    "  {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
+    "{%- endfor -%}\n" \
+    "{%- if add_generation_prompt -%}\n" \
+    "  {{- '<|im_start|>assistant\n' -}}\n" \
+    "{%- endif -%}"
+
+void common_chat_templates_free(struct common_chat_templates * tmpls) {
+    delete tmpls;
+}
+
+bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) {
+    return tmpls->has_explicit_template;
+}
+
+const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
+    if (variant != nullptr) {
+        if (strcmp(variant, "tool_use") == 0) {
+            if (tmpls->template_tool_use) {
+                return tmpls->template_tool_use->source().c_str();
+            }
+            return nullptr;
+        } else {
+            LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
+        }
+    }
+    return tmpls->template_default->source().c_str();
+}
+
+common_chat_templates_ptr common_chat_templates_init(
+    const struct llama_model * model,
+    const std::string & chat_template_override,
+    const std::string & bos_token_override,
+    const std::string & eos_token_override)
+{
+    std::string default_template_src;
+    std::string template_tool_use_src;
+
+    bool has_explicit_template = !chat_template_override.empty();
+    if (chat_template_override.empty()) {
+        GGML_ASSERT(model != nullptr);
+        const auto * str = llama_model_chat_template(model, /* name */ nullptr);
+        if (str) {
+            default_template_src = str;
+            has_explicit_template = true;
+        }
+        str = llama_model_chat_template(model, /* name */ "tool_use");
+        if (str) {
+            template_tool_use_src = str;
+            has_explicit_template = true;
+        }
+    } else {
+        default_template_src = chat_template_override;
+    }
+    if (default_template_src.empty() || default_template_src == "chatml") {
+        if (!template_tool_use_src.empty()) {
+            default_template_src = template_tool_use_src;
+        } else {
+            default_template_src = CHATML_TEMPLATE_SRC;
+        }
+    }
+    std::string token_bos = bos_token_override;
+    std::string token_eos = eos_token_override;
+    if (model) {
+        const auto * vocab = llama_model_get_vocab(model);
+        const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
+            if (token == LLAMA_TOKEN_NULL) {
+                if (default_template_src.find(jinja_variable_name) != std::string::npos
+                    || template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
+                    LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name);
+                }
+                return std::string();
+            }
+            return common_token_to_piece(vocab, token, true);
+        };
+        token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
+        token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
+    }
+    common_chat_templates_ptr tmpls(new common_chat_templates());
+    tmpls->has_explicit_template = has_explicit_template;
+    try {
+        tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
+    } catch (const std::exception & e) {
+        LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what());
+        tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
+    }
+    if (!template_tool_use_src.empty()) {
+        try {
+            tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
+        } catch (const std::exception & e) {
+            LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
+        }
+    }
+    return tmpls;
+}
 
 std::string common_chat_format_name(common_chat_format format) {
     switch (format) {
@@ -38,22 +463,22 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
 
         json_error_locator() : position(0), found_error(false) {}
 
-        bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
+        bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT
             this->position = position - 1;
             this->found_error = true;
             return false;
         }
-        bool null() override { return true; }
-        bool boolean(bool) override { return true; }
-        bool number_integer(number_integer_t) override { return true; }
-        bool number_unsigned(number_unsigned_t) override { return true; }
-        bool number_float(number_float_t, const string_t &) override { return true; }
-        bool string(string_t &) override { return true; }
-        bool binary(binary_t &) override { return true; }
-        bool start_object(std::size_t) override { return true; }
-        bool key(string_t &) override { return true; }
+        bool null() override { return true; } // NOLINT
+        bool boolean(bool) override { return true; } // NOLINT
+        bool number_integer(number_integer_t) override { return true; } // NOLINT
+        bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT
+        bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT
+        bool string(string_t &) override { return true; } // NOLINT
+        bool binary(binary_t &) override { return true; } // NOLINT
+        bool start_object(std::size_t) override { return true; } // NOLINT
+        bool key(string_t &) override { return true; } // NOLINT
         bool end_object() override { return true; }
-        bool start_array(std::size_t) override { return true; }
+        bool start_array(std::size_t) override { return true; } // NOLINT
         bool end_array() override { return true; }
     };
     json_error_locator err_loc;
@@ -187,13 +612,20 @@ static std::string apply(
     // tmpl_inputs.now = std::chrono::system_clock::now();
 
     minja::chat_template_options tmpl_opts;
-    tmpl_opts.use_bos_token = false;
-    tmpl_opts.use_eos_token = false;
-
-    return tmpl.apply(tmpl_inputs, tmpl_opts);
+    // To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
+    // instead of using `chat_template_options.use_bos_token = false`, since these tokens
+    // may be needed inside the template / between messages too.
+    auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
+    if (string_starts_with(result, tmpl.bos_token())) {
+        result = result.substr(tmpl.bos_token().size());
+    }
+    if (string_ends_with(result, tmpl.eos_token())) {
+        result = result.substr(0, result.size() - tmpl.eos_token().size());
+    }
+    return result;
 }
 
-static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
 
     auto tool_call_schemas = json::array();
@@ -247,7 +679,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
                 {"required", json::array({"tool_call"})},
             };
     const auto schema =
-        inputs.tool_choice != "required"
+        inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED
             ? json {
                 {"anyOf", json::array({
                     tool_call,
@@ -303,9 +735,9 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) {
     return result;
 }
 
-static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
-    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         auto schemas = json::array();
         foreach_function(inputs.tools, [&](const json & tool) {
@@ -348,9 +780,9 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
     return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
 }
 
-static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
-    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         auto schemas = json::array();
         foreach_function(inputs.tools, [&](const json & tool) {
@@ -455,10 +887,10 @@ static void expect_tool_parameters(const std::string & name, const json & parame
     const auto & parameters_required = parameters.at("required");
     for (const auto & prop : expected_properties) {
         if (!parameters_properties.contains(prop)) {
-            throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop);
+            throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT
         }
         if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
-            throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop);
+            throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT
         }
     }
     if (parameters_properties.size() != expected_properties.size()) {
@@ -466,18 +898,16 @@ static void expect_tool_parameters(const std::string & name, const json & parame
     }
 }
 
-static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) {
+static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
     auto builtin_tools = json::array();
     common_chat_params data;
-    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;
 
         auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
-            if (name == "wolfram_alpha") {
+            if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
                 // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
-                expect_tool_parameters(name, parameters, {"query"});
-            } else if (name == "web_search" || name == "brave_search") {
                 // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
                 expect_tool_parameters(name, parameters, {"query"});
             } else if (name == "python" || name == "code_interpreter") {
@@ -489,7 +919,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
 
             std::vector<std::string> kvs;
             for (const auto & [key, value] : parameters.at("properties").items()) {
-                kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value));
+                kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
             }
 
             tool_rules.push_back(
@@ -560,34 +990,33 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo
             auto arg_value_str = raw_args.substr(it_eq + 1);
             auto arg_value = json::parse(arg_value_str);
 
-            return {
-                /* .role = */ "assistant",
-                /* .content = */ match.prefix().str(),
-                /* .tool_calls = */ {
-                    {
-                        /* .name = */ match[1],
-                        /* .arguments = */ (json {
-                            {arg_name, arg_value},
-                        }).dump(),
-                        /* .id = */ "",
-                    },
-                },
-            };
+            common_chat_msg msg;
+            msg.role = "assistant";
+            msg.content = match.prefix().str();
+            msg.tool_calls.push_back({
+                /* .name = */ name,
+                /* .arguments = */ (json {
+                    {arg_name, arg_value},
+                }).dump(),
+                /* .id = */ "",
+            });
+            return msg;
         }
     }
     return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
 }
 
-static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
-        data.grammar_lazy = inputs.tool_choice != "required" && inputs.json_schema.is_null();
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
             std::vector<std::string> tool_rules;
             foreach_function(inputs.tools, [&](const json & tool) {
                 const auto & function = tool.at("function");
                 std::string name = function.at("name");
                 auto parameters = function.at("parameters");
+                builder.resolve_refs(parameters);
                 auto args_rule = builder.add_schema(name + "-args", parameters);
                 tool_rules.push_back(builder.add_rule(name + "-call",
                     "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
@@ -666,15 +1095,15 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input,
     return msg;
 }
 
-static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
-    fprintf(stderr, "%s\n", __func__);
+static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    LOG_DBG("%s\n", __func__);
     common_chat_params data;
     data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
         {"datetime", "Jan 29 2025 13:00:00 GMT"},
         {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
     });
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
-        data.grammar_lazy = inputs.tool_choice != "required";
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
             auto schemas = json::array();
             foreach_function(inputs.tools, [&](const json & tool) {
@@ -712,14 +1141,14 @@ static common_chat_msg common_chat_parse_firefunction_v2(const std::string & inp
     return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
 }
 
-static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) {
     // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
     common_chat_params data;
     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
-        data.grammar_lazy = inputs.tool_choice != "required";
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
             std::vector<std::string> first_tool_rules;
             std::vector<std::string> subsequent_tool_rules;
@@ -727,6 +1156,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
                 const auto & function = tool.at("function");
                 std::string name = function.at("name");
                 auto parameters = function.at("parameters");
+                builder.resolve_refs(parameters);
                 auto args_rule = builder.add_schema(name + "-args", parameters);
                 first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
                 subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
@@ -795,14 +1225,14 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
     }
 }
 
-static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
     // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
     common_chat_params data;
     json tools = inputs.tools.is_null() ? inputs.tools : json::array();
     std::string python_code_argument_name;
     auto has_raw_python = false;
 
-    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;
         foreach_function(inputs.tools, [&](const json & tool) {
@@ -814,7 +1244,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
                     throw std::runtime_error("Missing type in python tool");
                 }
                 has_raw_python = true;
-                auto type = parameters.at("type");
+                const auto & type = parameters.at("type");
                 if (type == "object") {
                     auto properties = parameters.at("properties");
                     for (auto it = properties.begin(); it != properties.end(); ++it) {
@@ -854,17 +1284,15 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
     std::smatch match;
     if (std::regex_search(input, match, python_tag_regex)) {
         auto code = match[1].str();
-        return {
-            /* .role = */ "assistant",
-            /* .content = */ match.prefix().str(),
-            /* .tool_calls = */ {
-                {
-                    /* .name = */ "python",
-                    /* .arguments = */ (json {{"code", code}}).dump(),
-                    /* .id = */ "",
-                },
-            }
-        };
+        common_chat_msg msg;
+        msg.role = "assistant";
+        msg.content = match.prefix().str();
+        msg.tool_calls.push_back({
+            /* .name = */ "python",
+            /* .arguments = */ (json {{"code", code}}).dump(),
+            /* .id = */ "",
+        });
+        return msg;
     }
     static std::regex function_regex(R"(<function=(\w+)>)");
     static std::regex close_regex(R"(</function>)");
@@ -872,10 +1300,10 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
     return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
 }
 
-static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
-    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;
         foreach_function(inputs.tools, [&](const json & tool) {
@@ -908,20 +1336,18 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
         std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
         std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
 
+        common_chat_msg msg;
+        msg.role = "assistant";
+
         auto end = input.end();
         std::sregex_iterator rend;
         std::sregex_iterator rit(input.begin(), end, start_pattern);
         if (rit == rend) {
-            return {
-                /* .role = */ "assistant",
-                /* .content = */ input,
-                /* .tool_calls = */ {},
-            };
+            msg.content = input;
+            return msg;
         }
 
-        common_chat_msg result;
-        result.role = "assistant";
-        result.content = rit->prefix();
+        msg.content = rit->prefix();
 
         auto it = rit->suffix().first;
         while (it != end) {
@@ -930,7 +1356,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
                 throw std::runtime_error("Failed to parse json tool call");
             }
             const auto & arguments = call.at("arguments");
-            result.tool_calls.push_back({
+            msg.tool_calls.push_back({
                 call.at("name"),
                 arguments.dump(),
                 // arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
@@ -947,17 +1373,17 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
                 break;
             }
         }
-        return result;
+        return msg;
     } catch (const std::exception & e) {
-        return {
-            /* .role = */ "assistant",
-            /* .content = */ input,
-            /* .tool_calls = */ {},
-        };
+        LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
+        common_chat_msg msg;
+        msg.role = "assistant";
+        msg.content = input;
+        return msg;
     }
 }
 
-static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
@@ -973,12 +1399,35 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
     return data;
 }
 
-common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_templates_apply_jinja(
+    const struct common_chat_templates * tmpls,
+    const struct common_chat_templates_inputs & inputs)
+{
+    templates_params params;
+    params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
+    const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
+        ? *tmpls->template_tool_use
+        : *tmpls->template_default;
     const auto & src = tmpl.source();
     const auto & caps = tmpl.original_caps();
+    params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
+    params.add_generation_prompt = inputs.add_generation_prompt;
+    params.extract_reasoning = inputs.extract_reasoning;
+    params.tool_choice = inputs.tool_choice;
+    params.grammar = inputs.grammar;
+    if (!inputs.json_schema.empty()) {
+        params.json_schema = json::parse(inputs.json_schema);
+    }
 
-    if (inputs.tools.is_array()) {
-        if (inputs.tool_choice != "none" && !inputs.grammar.empty()) {
+    if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
+        LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
+        params.parallel_tool_calls = false;
+    } else {
+        params.parallel_tool_calls = inputs.parallel_tool_calls;
+    }
+
+    if (params.tools.is_array()) {
+        if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
             throw std::runtime_error("Cannot specify grammar with tools");
         }
         if (caps.supports_tool_calls && !caps.supports_tools) {
@@ -987,68 +1436,135 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
     }
 
     // DeepSeek R1: use handler in all cases except json schema (thinking / tools).
-    if (src.find("<|tool▁calls▁begin|>") != std::string::npos && inputs.json_schema.is_null()) {
-        return common_chat_params_init_deepseek_r1(tmpl, inputs);
+    if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) {
+        return common_chat_params_init_deepseek_r1(tmpl, params);
     }
 
     // Command R7B: : use handler in all cases except json schema (thinking / tools).
-    if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) {
-        return common_chat_params_init_command_r7b(tmpl, inputs);
+    if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) {
+        return common_chat_params_init_command_r7b(tmpl, params);
     }
 
     // Use generic handler when mixing tools + JSON schema.
     // TODO: support that mix in handlers below.
-    if ((!inputs.tools.is_array() && inputs.json_schema.is_object())) {
-        return common_chat_params_init_generic(tmpl, inputs);
+    if ((params.tools.is_array() && params.json_schema.is_object())) {
+        return common_chat_params_init_generic(tmpl, params);
     }
 
     // Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
     if (src.find(">>>all") != std::string::npos) {
-        return common_chat_params_init_functionary_v3_2(tmpl, inputs);
+        return common_chat_params_init_functionary_v3_2(tmpl, params);
     }
 
     // Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
     if (src.find(" functools[") != std::string::npos) {
-        return common_chat_params_init_firefunction_v2(tmpl, inputs);
+        return common_chat_params_init_firefunction_v2(tmpl, params);
     }
 
     // Plain handler (no tools)
-    if (inputs.tools.is_null() || inputs.tool_choice == "none") {
-        return common_chat_params_init_without_tools(tmpl, inputs);
+    if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
+        return common_chat_params_init_without_tools(tmpl, params);
     }
 
     // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
     if (src.find("<tool_call>") != std::string::npos) {
-        return common_chat_params_init_hermes_2_pro(tmpl, inputs);
+        return common_chat_params_init_hermes_2_pro(tmpl, params);
     }
 
     // Functionary v3.1 (w/ tools)
     if (src.find("<|start_header_id|>") != std::string::npos
         && src.find("<function=") != std::string::npos) {
-        return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs);
+        return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
     }
 
     // Llama 3.1, 3.2, 3.3 (w/ tools)
     if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
         auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
-        return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
+        return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
     }
 
     // Mistral Nemo (w/ tools)
     if (src.find("[TOOL_CALLS]") != std::string::npos) {
-        return common_chat_params_init_mistral_nemo(tmpl, inputs);
+        return common_chat_params_init_mistral_nemo(tmpl, params);
     }
 
     // Generic fallback
-    return common_chat_params_init_generic(tmpl, inputs);
+    return common_chat_params_init_generic(tmpl, params);
+}
+
+// Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template.
+static common_chat_params common_chat_templates_apply_legacy(
+    const struct common_chat_templates * tmpls,
+    const struct common_chat_templates_inputs & inputs)
+{
+    int alloc_size = 0;
+    std::vector<llama_chat_message> chat;
+    std::vector<std::string> contents;
+    for (const auto & msg : inputs.messages) {
+        auto content = msg.content;
+        for (const auto & part : msg.content_parts) {
+            if (part.type != "text") {
+                LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
+                continue;
+            }
+            if (!content.empty()) {
+                content += "\n";;
+            }
+            content += part.text;
+        }
+        contents.emplace_back(std::move(content));
+    }
+    for (size_t i = 0; i < contents.size(); ++i) {
+        const auto & msg = inputs.messages[i];
+        const auto & content = contents[i];
+        chat.push_back({msg.role.c_str(), content.c_str()});
+        alloc_size += (msg.role.size() + content.size()) * 1.25;
+    }
+
+    std::vector<char> buf(alloc_size);
+
+    // run the first time to get the total output length
+    const auto & src = tmpls->template_default->source();
+    int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
+
+    // error: chat template is not supported
+    if (res < 0) {
+        // if the custom "tmpl" is not supported, we throw an error
+        // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
+        throw std::runtime_error("this custom template is not supported");
+    }
+
+    // if it turns out that our buffer is too small, we resize it
+    if ((size_t) res > buf.size()) {
+        buf.resize(res);
+        res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
+    }
+
+    common_chat_params params;
+    params.prompt = std::string(buf.data(), res);
+    if (!inputs.json_schema.empty()) {
+        params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema));
+    } else {
+        params.grammar = inputs.grammar;
+    }
+    return params;
+}
+
+common_chat_params common_chat_templates_apply(
+    const struct common_chat_templates * tmpls,
+    const struct common_chat_templates_inputs & inputs)
+{
+    GGML_ASSERT(tmpls != nullptr);
+    return inputs.use_jinja
+        ? common_chat_templates_apply_jinja(tmpls, inputs)
+        : common_chat_templates_apply_legacy(tmpls, inputs);
 }
 
 static common_chat_msg common_chat_parse_content_only(const std::string & input) {
-    return {
-        /* .role = */ "assistant",
-        /* .content = */ input,
-        /* .tool_calls = */ {},
-    };
+    common_chat_msg msg;
+    msg.role = "assistant";
+    msg.content = input;
+    return msg;
 }
 
 common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {

+ 134 - 0
common/chat.h

@@ -0,0 +1,134 @@
+// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
+
+#pragma once
+
+#include "common.h"
+#include <string>
+#include <vector>
+
+struct common_chat_templates;
+
+struct common_chat_tool_call {
+    std::string name;
+    std::string arguments;
+    std::string id;
+};
+
+struct common_chat_msg_content_part {
+    std::string type;
+    std::string text;
+};
+
+struct common_chat_msg {
+    std::string role;
+    std::string content;
+    std::vector<common_chat_msg_content_part> content_parts = {};
+    std::vector<common_chat_tool_call> tool_calls = {};
+    std::string reasoning_content;
+    std::string tool_name;
+    std::string tool_call_id;
+};
+
+struct common_chat_tool {
+    std::string name;
+    std::string description;
+    std::string parameters;
+};
+
+enum common_chat_tool_choice {
+    COMMON_CHAT_TOOL_CHOICE_AUTO,
+    COMMON_CHAT_TOOL_CHOICE_REQUIRED,
+    COMMON_CHAT_TOOL_CHOICE_NONE,
+};
+
+enum common_chat_format {
+    COMMON_CHAT_FORMAT_CONTENT_ONLY,
+    COMMON_CHAT_FORMAT_GENERIC,
+    COMMON_CHAT_FORMAT_MISTRAL_NEMO,
+    COMMON_CHAT_FORMAT_LLAMA_3_X,
+    COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
+    COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+    COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
+    COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
+    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
+    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
+    COMMON_CHAT_FORMAT_HERMES_2_PRO,
+    COMMON_CHAT_FORMAT_COMMAND_R7B,
+    COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
+
+    COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
+};
+
+struct common_chat_templates_inputs {
+    std::vector<common_chat_msg> messages;
+    std::string grammar;
+    std::string json_schema;
+    bool add_generation_prompt = true;
+    bool use_jinja = true;
+    // Parameters below only supported when use_jinja is true
+    std::vector<common_chat_tool> tools;
+    common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
+    bool parallel_tool_calls = false;
+    bool extract_reasoning     = true;
+};
+
+struct common_chat_params {
+    common_chat_format                  format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    std::string                         prompt;
+    std::string                         grammar;
+    bool                                grammar_lazy = false;
+    std::vector<common_grammar_trigger> grammar_triggers;
+    std::vector<std::string>            preserved_tokens;
+    std::vector<std::string>            additional_stops;
+};
+
+// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
+
+void common_chat_templates_free(struct common_chat_templates * tmpls);
+
+struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } };
+
+typedef std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter> common_chat_templates_ptr;
+
+common_chat_templates_ptr common_chat_templates_init(
+                                    const struct llama_model * model,
+                                           const std::string & chat_template_override,
+                                           const std::string & bos_token_override = "",
+                                           const std::string & eos_token_override = "");
+
+bool         common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
+const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
+
+
+struct common_chat_params      common_chat_templates_apply(
+    const struct common_chat_templates * tmpls,
+    const struct common_chat_templates_inputs & inputs);
+
+// Format single message, while taking into account the position of that message in chat history
+std::string common_chat_format_single(
+        const struct common_chat_templates * tmpls,
+        const std::vector<common_chat_msg> & past_msg,
+        const common_chat_msg & new_msg,
+        bool add_ass,
+        bool use_jinja);
+
+// Returns an example of formatted chat
+std::string common_chat_format_example(
+    const struct common_chat_templates * tmpls,
+    bool use_jinja);
+
+std::string               common_chat_format_name(common_chat_format format);
+common_chat_msg           common_chat_parse(      const std::string & input, common_chat_format format);
+
+common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
+
+// Parses a JSON array of messages in OpenAI's chat completion API format.
+// T can be std::string containing JSON or nlohmann::ordered_json
+template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
+template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
+
+// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
+// T can be std::string containing JSON or nlohmann::ordered_json
+template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
+template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);

+ 0 - 55
common/chat.hpp

@@ -1,55 +0,0 @@
-// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
-
-#pragma once
-
-#include "common.h"
-#include <json.hpp>
-#include <optional>
-#include <string>
-#include <vector>
-
-using json = nlohmann::ordered_json;
-
-struct common_chat_inputs {
-    json messages;
-    json tools;
-    json tool_choice;
-    json json_schema;
-    bool parallel_tool_calls;
-    bool stream;
-    std::string grammar;
-    bool add_generation_prompt = true;
-    bool extract_reasoning     = true;
-};
-
-enum common_chat_format {
-    COMMON_CHAT_FORMAT_CONTENT_ONLY,
-    COMMON_CHAT_FORMAT_GENERIC,
-    COMMON_CHAT_FORMAT_MISTRAL_NEMO,
-    COMMON_CHAT_FORMAT_LLAMA_3_X,
-    COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
-    COMMON_CHAT_FORMAT_DEEPSEEK_R1,
-    COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
-    COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
-    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
-    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
-    COMMON_CHAT_FORMAT_HERMES_2_PRO,
-    COMMON_CHAT_FORMAT_COMMAND_R7B,
-    COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
-
-    COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
-};
-
-struct common_chat_params {
-    common_chat_format                  format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
-    json                                prompt;
-    std::string                         grammar;
-    bool                                grammar_lazy = false;
-    std::vector<common_grammar_trigger> grammar_triggers;
-    std::vector<std::string>            preserved_tokens;
-    std::vector<std::string>            additional_stops;
-};
-
-struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params);
-std::string               common_chat_format_name(common_chat_format format);
-common_chat_msg           common_chat_parse(      const std::string & input, common_chat_format format);

+ 0 - 170
common/common.cpp

@@ -12,8 +12,6 @@
 #include "json.hpp"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
-#include "chat.hpp"
-#include "chat-template.hpp"
 
 #include <algorithm>
 #include <cinttypes>
@@ -1768,174 +1766,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
     return text;
 }
 
-//
-// Chat template utils
-//
-
-bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
-    if (use_jinja) {
-        try {
-            auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
-            common_chat_inputs inputs;
-            inputs.messages = json::array({{
-                {"role", "user"},
-                {"content", "test"},
-            }});
-            common_chat_params_init(chat_template, inputs);
-            return true;
-        } catch (const std::exception & e) {
-            LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
-            return false;
-        }
-    }
-    llama_chat_message chat[] = {{"user", "test"}};
-    const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
-    return res >= 0;
-}
-
-std::string common_chat_apply_template(
-        const common_chat_template & tmpl,
-        const std::vector<common_chat_msg> & msgs,
-        bool add_ass,
-        bool use_jinja) {
-    if (use_jinja) {
-        auto messages = json::array();
-        for (const auto & msg : msgs) {
-            messages.push_back({{"role", msg.role}, {"content", msg.content}});
-        }
-        common_chat_inputs inputs;
-        inputs.messages = messages;
-        inputs.add_generation_prompt = add_ass;
-        return common_chat_params_init(tmpl, inputs).prompt;
-    }
-
-    int alloc_size = 0;
-    std::vector<llama_chat_message> chat;
-    for (const auto & msg : msgs) {
-        chat.push_back({msg.role.c_str(), msg.content.c_str()});
-        alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
-    }
-
-    std::vector<char> buf(alloc_size);
-
-    // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
-
-    // error: chat template is not supported
-    if (res < 0) {
-        // if the custom "tmpl" is not supported, we throw an error
-        // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
-        throw std::runtime_error("this custom template is not supported");
-    }
-
-    // if it turns out that our buffer is too small, we resize it
-    if ((size_t) res > buf.size()) {
-        buf.resize(res);
-        res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
-    }
-
-    std::string formatted_chat(buf.data(), res);
-    return formatted_chat;
-}
-
-std::string common_chat_format_single(
-        const common_chat_template & tmpl,
-        const std::vector<common_chat_msg> & past_msg,
-        const common_chat_msg & new_msg,
-        bool add_ass,
-        bool use_jinja) {
-    std::ostringstream ss;
-    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
-    std::vector<common_chat_msg> chat_new(past_msg);
-    // if the past_msg ends with a newline, we must preserve it in the formatted version
-    if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
-        ss << "\n";
-    };
-    // format chat with new_msg
-    chat_new.push_back(new_msg);
-    auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja);
-    // get the diff part
-    ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
-    return ss.str();
-}
-
-std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) {
-    std::vector<common_chat_msg> msgs = {
-        {"system",    "You are a helpful assistant", {}},
-        {"user",      "Hello", {}},
-        {"assistant", "Hi there", {}},
-        {"user",      "How are you?", {}},
-    };
-    return common_chat_apply_template(tmpl, msgs, true, use_jinja);
-}
-
-#define CHATML_TEMPLATE_SRC \
-    "{%- for message in messages -%}\n" \
-    "  {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
-    "{%- endfor -%}\n" \
-    "{%- if add_generation_prompt -%}\n" \
-    "  {{- '<|im_start|>assistant\n' -}}\n" \
-    "{%- endif -%}"
-
-common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
-{
-    std::string default_template_src;
-    std::string template_tool_use_src;
-
-    bool has_explicit_template = !chat_template_override.empty();
-    if (chat_template_override.empty()) {
-        auto str = llama_model_chat_template(model, /* name */ nullptr);
-        if (str) {
-            default_template_src = str;
-            has_explicit_template = true;
-        }
-        str = llama_model_chat_template(model, /* name */ "tool_use");
-        if (str) {
-            template_tool_use_src = str;
-            has_explicit_template = true;
-        }
-    } else {
-        default_template_src = chat_template_override;
-    }
-    if (default_template_src.empty() || default_template_src == "chatml") {
-        if (!template_tool_use_src.empty()) {
-            default_template_src = template_tool_use_src;
-        } else {
-            default_template_src = CHATML_TEMPLATE_SRC;
-        }
-    }
-    auto vocab = llama_model_get_vocab(model);
-    const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
-        if (token == LLAMA_TOKEN_NULL) {
-            if (default_template_src.find(jinja_variable_name) != std::string::npos
-                || template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
-                LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
-            }
-            return std::string();
-        } else {
-            return common_token_to_piece(vocab, token, true);
-        }
-    };
-    auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
-    auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
-    try {
-        return {
-            has_explicit_template,
-            std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
-            template_tool_use_src.empty()
-                ? nullptr
-                : std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos),
-        };
-    } catch (const std::exception & e) {
-        LOG_ERR("%s: failed to parse chat template: %s\n", __func__, e.what());
-        return {
-            has_explicit_template,
-            std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos),
-            nullptr,
-        };
-    }
-}
-
 //
 // KV cache utils
 //

+ 0 - 56
common/common.h

@@ -616,62 +616,6 @@ std::string common_detokenize(
         const std::vector<llama_token> & tokens,
                                   bool   special = true);
 
-//
-// Chat template utils
-//
-
-struct common_tool_call {
-    std::string name;
-    std::string arguments;
-    std::string id;
-};
-
-// same with llama_chat_message, but uses std::string
-struct common_chat_msg {
-    std::string role;
-    std::string content;
-    std::vector<common_tool_call> tool_calls;
-    std::string reasoning_content = "";
-};
-
-// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
-bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
-
-namespace minja {
-    class chat_template;
-}
-
-typedef minja::chat_template common_chat_template;
-
-struct common_chat_templates {
-    bool has_explicit_template; // Model had builtin template or template overridde was specified.
-    std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
-    std::unique_ptr<common_chat_template> template_tool_use;
-};
-
-// CPP wrapper for llama_chat_apply_template
-// If the built-in template is not supported, we default to chatml
-// If the custom "tmpl" is not supported, we throw an error
-std::string common_chat_apply_template(
-        const common_chat_template & tmpl,
-        const std::vector<common_chat_msg> & chat,
-        bool add_ass,
-        bool use_jinja);
-
-// Format single message, while taking into account the position of that message in chat history
-std::string common_chat_format_single(
-        const common_chat_template & tmpl,
-        const std::vector<common_chat_msg> & past_msg,
-        const common_chat_msg & new_msg,
-        bool add_ass,
-        bool use_jinja);
-
-// Returns an example of formatted chat
-std::string common_chat_format_example(
-    const common_chat_template & tmpl, bool use_jinja);
-
-common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
-
 //
 // KV cache utils
 //

+ 0 - 0
common/chat-template.hpp → common/minja/chat-template.hpp


+ 0 - 0
common/minja.hpp → common/minja/minja.hpp


+ 16 - 11
examples/main/main.cpp

@@ -4,7 +4,7 @@
 #include "log.h"
 #include "sampling.h"
 #include "llama.h"
-#include "chat-template.hpp"
+#include "chat.h"
 
 #include <cstdio>
 #include <cstring>
@@ -158,7 +158,7 @@ int main(int argc, char ** argv) {
     }
 
     const llama_vocab * vocab = llama_model_get_vocab(model);
-    auto chat_templates = common_chat_templates_from_model(model, params.chat_template);
+    auto chat_templates = common_chat_templates_init(model, params.chat_template);
 
     LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
 
@@ -201,7 +201,7 @@ int main(int argc, char ** argv) {
     }
 
     // auto enable conversation mode if chat template is available
-    const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.template_default;
+    const bool has_chat_template = common_chat_templates_was_explicit(chat_templates.get());
     if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) {
         if (has_chat_template) {
             LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
@@ -219,7 +219,7 @@ int main(int argc, char ** argv) {
     // print chat template example in conversation mode
     if (params.conversation_mode) {
         if (params.enable_chat_template) {
-            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str());
+            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.get(), params.use_jinja).c_str());
         } else {
             LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
         }
@@ -264,9 +264,11 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd_inp;
 
     auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
-        common_chat_msg new_msg{role, content, {}};
-        auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja);
-        chat_msgs.push_back({role, content, {}});
+        common_chat_msg new_msg;
+        new_msg.role = role;
+        new_msg.content = content;
+        auto formatted = common_chat_format_single(chat_templates.get(), chat_msgs, new_msg, role == "user", g_params->use_jinja);
+        chat_msgs.push_back(new_msg);
         LOG_DBG("formatted: '%s'\n", formatted.c_str());
         return formatted;
     };
@@ -755,11 +757,14 @@ int main(int argc, char ** argv) {
 
                 // check for reverse prompt using special tokens
                 llama_token last_token = common_sampler_last(smpl);
-                if (std::find(antiprompt_token.begin(), antiprompt_token.end(), last_token) != antiprompt_token.end()) {
-                    if (params.interactive) {
-                        is_interacting = true;
+                for (auto token : antiprompt_token) {
+                    if (token == last_token) {
+                        if (params.interactive) {
+                            is_interacting = true;
+                        }
+                        is_antiprompt = true;
+                        break;
                     }
-                    is_antiprompt = true;
                 }
 
                 if (is_antiprompt) {

+ 24 - 46
examples/run/run.cpp

@@ -24,7 +24,7 @@
 #include <string>
 #include <vector>
 
-#include "chat-template.hpp"
+#include "chat.h"
 #include "common.h"
 #include "json.hpp"
 #include "linenoise.cpp/linenoise.h"
@@ -557,7 +557,7 @@ class LlamaData {
     llama_model_ptr                 model;
     llama_sampler_ptr               sampler;
     llama_context_ptr               context;
-    std::vector<llama_chat_message> messages;
+    std::vector<llama_chat_message> messages; // TODO: switch to common_chat_msg
     std::list<std::string>          msg_strs;
     std::vector<char>               fmtted;
 
@@ -834,44 +834,23 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 }
 
 // Function to apply the chat template and resize `formatted` if needed
-static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
-    if (use_jinja) {
-        json messages = json::array();
-        for (const auto & msg : llama_data.messages) {
-            messages.push_back({
-                {"role", msg.role},
-                {"content", msg.content},
-            });
-        }
-        try {
-            minja::chat_template_inputs tmpl_inputs;
-            tmpl_inputs.messages = messages;
-            tmpl_inputs.add_generation_prompt = append;
-
-            minja::chat_template_options tmpl_opts;
-            tmpl_opts.use_bos_token = false;
-            tmpl_opts.use_eos_token = false;
-
-            auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
-            llama_data.fmtted.resize(result.size() + 1);
-            memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
-            return result.size();
-        } catch (const std::exception & e) {
-            printe("failed to render the chat template: %s\n", e.what());
-            return -1;
-        }
-    }
-    int result = llama_chat_apply_template(
-        tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
-        append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
-    if (append && result > static_cast<int>(llama_data.fmtted.size())) {
-        llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
-                                           llama_data.messages.size(), append, llama_data.fmtted.data(),
-                                           llama_data.fmtted.size());
-    }
-
-    return result;
+static int apply_chat_template(const struct common_chat_templates * tmpls, LlamaData & llama_data, const bool append, bool use_jinja) {
+    common_chat_templates_inputs inputs;
+    for (const auto & msg : llama_data.messages) {
+        common_chat_msg cmsg;
+        cmsg.role    = msg.role;
+        cmsg.content = msg.content;
+        inputs.messages.push_back(cmsg);
+    }
+    inputs.add_generation_prompt = append;
+    inputs.use_jinja = use_jinja;
+
+    auto chat_params = common_chat_templates_apply(tmpls, inputs);
+    // TODO: use other params for tool calls.
+    auto result = chat_params.prompt;
+    llama_data.fmtted.resize(result.size() + 1);
+    memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
+    return result.size();
 }
 
 // Function to tokenize the prompt
@@ -1015,8 +994,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
 }
 
 // Helper function to apply the chat template and handle errors
-static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
-    const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
+static int apply_chat_template_with_error_handling(const common_chat_templates * tmpls, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
+    const int new_len = apply_chat_template(tmpls, llama_data, append, use_jinja);
     if (new_len < 0) {
         printe("failed to apply the chat template\n");
         return -1;
@@ -1078,8 +1057,7 @@ static int get_user_input(std::string & user_input, const std::string & user) {
 static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
-    auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), "");
-    GGML_ASSERT(chat_templates.template_default);
+    auto chat_templates = common_chat_templates_init(llama_data.model.get(), "");
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
@@ -1090,7 +1068,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_
 
         add_message("user", user.empty() ? user_input : user, llama_data);
         int new_len;
-        if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) {
+        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
             return 1;
         }
 
@@ -1105,7 +1083,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_
         }
 
         add_message("assistant", response, llama_data);
-        if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) {
+        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
             return 1;
         }
     }

+ 17 - 46
examples/server/server.cpp

@@ -329,9 +329,6 @@ struct server_task {
         }
 
         // process "json_schema" and "grammar"
-        if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
-            throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
-        }
         if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
                 auto schema                  = json_value(data, "json_schema", json::object());
@@ -1807,7 +1804,7 @@ struct server_context {
     // Necessary similarity of prompt for slot selection
     float slot_prompt_similarity = 0.0f;
 
-    common_chat_templates chat_templates;
+    common_chat_templates_ptr chat_templates;
 
     ~server_context() {
         // Clear any sampling context
@@ -1891,45 +1888,17 @@ struct server_context {
             llama_init_dft.context.reset();
         }
 
-        if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) {
+        chat_templates = common_chat_templates_init(model, params_base.chat_template);
+        try {
+            common_chat_format_example(chat_templates.get(), params.use_jinja);
+        } catch (const std::exception & e) {
             SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
-            chat_templates = common_chat_templates_from_model(model, "chatml");
-        } else {
-            chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
+            chat_templates = common_chat_templates_init(model, "chatml");
         }
-        GGML_ASSERT(chat_templates.template_default.get() != nullptr);
 
         return true;
     }
 
-    bool validate_builtin_chat_template(bool use_jinja) const {
-        llama_chat_message chat[] = {{"user", "test"}};
-
-        if (use_jinja) {
-            auto templates = common_chat_templates_from_model(model, "");
-            common_chat_inputs inputs;
-            inputs.messages = json::array({{
-                {"role", "user"},
-                {"content", "test"},
-            }});
-            GGML_ASSERT(templates.template_default);
-            try {
-                common_chat_params_init(*templates.template_default, inputs);
-                if (templates.template_tool_use) {
-                    common_chat_params_init(*templates.template_tool_use, inputs);
-                }
-                return true;
-            } catch (const std::exception & e) {
-                SRV_ERR("failed to apply template: %s\n", e.what());
-                return false;
-            }
-        } else {
-            const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
-            const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
-            return chat_res > 0;
-        }
-    }
-
     void init() {
         const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
 
@@ -3822,13 +3791,15 @@ int main(int argc, char ** argv) {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots",                 ctx_server.params_base.n_parallel },
             { "model_path",                  ctx_server.params_base.model },
-            { "chat_template",               ctx_server.chat_templates.template_default->source() },
-            { "bos_token",                   ctx_server.chat_templates.template_default->bos_token() },
-            { "eos_token",                   ctx_server.chat_templates.template_default->eos_token() },
+            { "chat_template",               common_chat_templates_source(ctx_server.chat_templates.get()) },
+            { "bos_token",                   common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
+            { "eos_token",                   common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
             { "build_info",                  build_info },
         };
-        if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) {
-            data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source();
+        if (ctx_server.params_base.use_jinja) {
+            if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) {
+                data["chat_template_tool_use"] = tool_use_src;
+            }
         }
 
         res_ok(res, data);
@@ -4063,7 +4034,7 @@ int main(int argc, char ** argv) {
         }
 
         auto body = json::parse(req.body);
-        json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
+        json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
 
         return handle_completions_impl(
             SERVER_TASK_TYPE_COMPLETION,
@@ -4076,7 +4047,7 @@ int main(int argc, char ** argv) {
     // same with handle_chat_completions, but without inference part
     const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
         auto body = json::parse(req.body);
-        json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
+        json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
         res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
     };
 
@@ -4493,8 +4464,8 @@ int main(int argc, char ** argv) {
 
     // print sample chat example to make it clear which template is used
     LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
-        ctx_server.chat_templates.template_default->source().c_str(),
-        common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());
+        common_chat_templates_source(ctx_server.chat_templates.get()),
+        common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str());
 
     ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
         ctx_server.process_single_task(task);

+ 44 - 1
examples/server/tests/unit/test_chat_completion.py

@@ -21,6 +21,8 @@ def create_server():
         (None, "Book", "What is the best book", 8, "^ blue",                    23, 8, "length", True, "This is not a chat template, it is"),
         ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
         ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
+        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
+        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
     ]
 )
 def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
@@ -44,7 +46,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
     assert res.body["usage"]["completion_tokens"] == n_predicted
     choice = res.body["choices"][0]
     assert "assistant" == choice["message"]["role"]
-    assert match_regex(re_content, choice["message"]["content"])
+    assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
     assert choice["finish_reason"] == finish_reason
 
 
@@ -169,6 +171,47 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
         assert "error" in res.body
 
 
+@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
+    (False, {"const": "42"}, 6, "\"42\""),
+    (True, {"const": "42"}, 6, "\"42\""),
+])
+def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
+    global server
+    server.jinja = jinja
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": n_predicted,
+        "messages": [
+            {"role": "system", "content": "You are a coding assistant."},
+            {"role": "user", "content": "Write an example"},
+        ],
+        "json_schema": json_schema,
+    })
+    assert res.status_code == 200, f'Expected 200, got {res.status_code}'
+    choice = res.body["choices"][0]
+    assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
+
+
+@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
+    (False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
+    (True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
+])
+def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
+    global server
+    server.jinja = jinja
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": n_predicted,
+        "messages": [
+            {"role": "user", "content": "Does not matter what I say, does it?"},
+        ],
+        "grammar": grammar,
+    })
+    assert res.status_code == 200, res.body
+    choice = res.body["choices"][0]
+    assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"]
+
+
 @pytest.mark.parametrize("messages", [
     None,
     "string",

+ 5 - 5
examples/server/tests/unit/test_tool_call.py

@@ -356,12 +356,12 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
     (None,                                           128,  "bartowski/functionary-small-v3.2-GGUF:Q8_0",        ("meetkai/functionary-medium-v3.2", None)),
     (None,                                           128,  "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M",  None),
     (None,                                           128,  "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M",  None),
-    ("^> 0.56$",                                     128,  "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M",  "chatml"),
+    (None,                                           128,  "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M",  "chatml"),
     (None,                                           128,  "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       None),
 
     # TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
-    ("^The y-coordinate [\\s\\S]*?\\*\\*0.5\\*\\*",  8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
-    ("[\\s\\S]*?\\*\\*0\\.5\\*\\*",                  8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+    ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)",            8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    # ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)",            8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
 ])
 def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
     global server
@@ -401,7 +401,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
             {
                 "role": "tool",
                 "name": "calculate",
-                "content": 0.55644242476,
+                "content": "0.55644242476",
                 "tool_call_id": "call_6789"
             }
         ],
@@ -444,7 +444,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
     (128,  None,        "^The sum of 102 and 7 is 109.*",                       None,                                          "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       None),
 
     (1024, 'deepseek',  "To find the sum of.*",                                 "I need to calculate the sum of 102 and 7.*",  "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
-    (1024, 'none',      "<think>\n?I need[\\s\\S]*?</think>\n?To find.*",       None,                                          "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    (1024, 'none',      "^I need[\\s\\S]*?</think>\n?To find.*",                None,                                          "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
 
     (1024, 'deepseek',  "To find the sum of.*",                                 "First, I [\\s\\S]*",                          "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
 ])

+ 41 - 85
examples/server/utils.hpp

@@ -12,9 +12,7 @@
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
-#include "minja.hpp"
-#include "chat.hpp"
-#include "chat-template.hpp"
+#include "chat.h"
 
 #include <random>
 #include <sstream>
@@ -347,41 +345,6 @@ static llama_tokens format_infill(
     return embd_inp;
 }
 
-// Format given chat. If tmpl is empty, we take the template from model metadata
-inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
-    std::vector<common_chat_msg> chat;
-
-    for (size_t i = 0; i < messages.size(); ++i) {
-        const auto & curr_msg = messages[i];
-
-        std::string role = json_value(curr_msg, "role", std::string(""));
-
-        std::string content;
-        if (curr_msg.contains("content")) {
-            if (curr_msg["content"].is_string()) {
-                content = curr_msg["content"].get<std::string>();
-            } else if (curr_msg["content"].is_array()) {
-                for (const auto & part : curr_msg["content"]) {
-                    if (part.contains("text")) {
-                        content += "\n" + part["text"].get<std::string>();
-                    }
-                }
-            } else {
-                throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
-            }
-        } else {
-            throw std::runtime_error("Missing 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
-        }
-
-        chat.push_back({role, content, /* tool_calls= */ {}});
-    }
-
-    const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
-    LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
-
-    return formatted_chat;
-}
-
 //
 // base64 utils (TODO: move to common in the future)
 //
@@ -579,12 +542,9 @@ static json oaicompat_completion_params_parse(
     const json & body, /* openai api json semantics */
     bool use_jinja,
     common_reasoning_format reasoning_format,
-    const common_chat_templates & chat_templates)
+    const struct common_chat_templates * tmpls)
 {
     json llama_params;
-    const auto & tmpl = body.contains("tools") && chat_templates.template_tool_use
-        ? *chat_templates.template_tool_use
-        : *chat_templates.template_default;
 
     auto tools = json_value(body, "tools", json());
     auto stream = json_value(body, "stream", false);
@@ -610,62 +570,58 @@ static json oaicompat_completion_params_parse(
         llama_params["stop"] = json_value(body, "stop", json::array());
     }
 
+    auto json_schema = json_value(body, "json_schema", json());
+    auto grammar = json_value(body, "grammar", std::string());
+    if (!json_schema.is_null() && !grammar.empty()) {
+        throw std::runtime_error("Cannot use both json_schema and grammar");
+    }
+
     // Handle "response_format" field
     if (body.contains("response_format")) {
         json response_format      = json_value(body, "response_format", json::object());
         std::string response_type = json_value(response_format, "type", std::string());
         if (response_type == "json_object") {
-            llama_params["json_schema"] = json_value(response_format, "schema", json::object());
+            json_schema = json_value(response_format, "schema", json::object());
         } else if (response_type == "json_schema") {
             json json_schema = json_value(response_format, "json_schema", json::object());
-            llama_params["json_schema"] = json_value(json_schema, "schema", json::object());
+            json_schema = json_value(json_schema, "schema", json::object());
         } else if (!response_type.empty() && response_type != "text") {
             throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
         }
     }
 
+    common_chat_templates_inputs inputs;
+    inputs.messages              = common_chat_msgs_parse_oaicompat(body.at("messages"));
+    inputs.tools                 = common_chat_tools_parse_oaicompat(tools);
+    inputs.tool_choice           = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
+    inputs.json_schema           = json_schema.is_null() ? "" : json_schema.dump();
+    inputs.grammar               = grammar;
+    inputs.add_generation_prompt = true;
+    inputs.use_jinja             = use_jinja;
+    inputs.parallel_tool_calls   = json_value(body, "parallel_tool_calls", false);
+    inputs.extract_reasoning     = reasoning_format != COMMON_REASONING_FORMAT_NONE;
+    if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
+        throw std::runtime_error("Cannot use custom grammar constraints with tools.");
+    }
+
     // Apply chat template to the list of messages
-    if (use_jinja) {
-        auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
-        if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
-            throw std::runtime_error("Invalid tool_choice: " + tool_choice);
-        }
-        if (tool_choice != "none" && llama_params.contains("grammar")) {
-            throw std::runtime_error("Cannot use custom grammar constraints with tools.");
-        }
-        common_chat_inputs inputs;
-        inputs.extract_reasoning   = reasoning_format != COMMON_REASONING_FORMAT_NONE;
-        inputs.messages            = body.at("messages");
-        inputs.tools               = tools;
-        inputs.tool_choice         = tool_choice;
-        inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
-        if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
-            LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
-            inputs.parallel_tool_calls = false;
-        }
-        inputs.stream = stream;
-        // TODO: support mixing schema w/ tools beyond generic format.
-        inputs.json_schema = json_value(llama_params, "json_schema", json());
-        auto chat_params = common_chat_params_init(tmpl, inputs);
-
-        llama_params["chat_format"] = static_cast<int>(chat_params.format);
-        llama_params["prompt"] = chat_params.prompt;
-        llama_params["grammar"] = chat_params.grammar;
-        llama_params["grammar_lazy"] = chat_params.grammar_lazy;
-        auto grammar_triggers = json::array();
-        for (const auto & trigger : chat_params.grammar_triggers) {
-            grammar_triggers.push_back({
-                {"word", trigger.word},
-                {"at_start", trigger.at_start},
-            });
-        }
-        llama_params["grammar_triggers"] = grammar_triggers;
-        llama_params["preserved_tokens"] = chat_params.preserved_tokens;
-        for (const auto & stop : chat_params.additional_stops) {
-            llama_params["stop"].push_back(stop);
-        }
-    } else {
-        llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
+    auto chat_params = common_chat_templates_apply(tmpls, inputs);
+
+    llama_params["chat_format"]      = static_cast<int>(chat_params.format);
+    llama_params["prompt"]           = chat_params.prompt;
+    llama_params["grammar"]          = chat_params.grammar;
+    llama_params["grammar_lazy"]     = chat_params.grammar_lazy;
+    auto grammar_triggers = json::array();
+    for (const auto & trigger : chat_params.grammar_triggers) {
+        grammar_triggers.push_back({
+            {"word", trigger.word},
+            {"at_start", trigger.at_start},
+        });
+    }
+    llama_params["grammar_triggers"] = grammar_triggers;
+    llama_params["preserved_tokens"] = chat_params.preserved_tokens;
+    for (const auto & stop : chat_params.additional_stops) {
+        llama_params["stop"].push_back(stop);
     }
 
     // Handle "n" field

+ 32 - 22
tests/test-chat-template.cpp

@@ -1,13 +1,14 @@
 #include <string>
 #include <vector>
 #include <sstream>
+#include <regex>
 
 #undef NDEBUG
 #include <cassert>
 
 #include "llama.h"
 #include "common.h"
-#include "chat-template.hpp"
+#include "chat.h"
 
 static std::string normalize_newlines(const std::string & s) {
 #ifdef _WIN32
@@ -18,6 +19,13 @@ static std::string normalize_newlines(const std::string & s) {
 #endif
 }
 
+static common_chat_msg simple_msg(const std::string & role, const std::string & content) {
+    common_chat_msg msg;
+    msg.role = role;
+    msg.content = content;
+    return msg;
+}
+
 int main(void) {
     std::vector<llama_chat_message> conversation {
         {"system", "You are a helpful assistant"},
@@ -50,7 +58,7 @@ int main(void) {
             /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
             /* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST]   I am an assistant   </s>[INST] Another question [/INST]",
             /* .expected_output_jinja= */ "",
-            /* .bos_token= */ "",
+            /* .bos_token= */ "<s>",
             /* .eos_token= */ "</s>",
         },
         {
@@ -72,8 +80,8 @@ int main(void) {
         {
             /* .name= */ "mlabonne/AlphaMonarch-7B",
             /* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
-            /* .expected_output= */          "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n   I am an assistant   </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
-            /* .expected_output_jinja= */ "<s>system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n   I am an assistant   </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
+            /* .expected_output= */ "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n   I am an assistant   </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
+            /* .expected_output_jinja= */ "",
             /* .bos_token= */ "<s>",
             /* .eos_token= */ "</s>",
         },
@@ -87,7 +95,7 @@ int main(void) {
             /* .name= */ "OrionStarAI/Orion-14B-Chat",
             /* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
             /* .expected_output= */       "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s>   I am an assistant   </s>Human: Another question\n\nAssistant: </s>",
-            /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s>   I am an assistant   </s>Human: Another question\n\nAssistant: </s>",
+            /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s>   I am an assistant   </s>Human: Another question\n\nAssistant: ",
             /* .bos_token= */ "",
             /* .eos_token= */ "</s>",
         },
@@ -304,12 +312,9 @@ int main(void) {
         }
     }
 
-    json messages = json::array();
+    std::vector<common_chat_msg> messages;
     for (const auto & msg : conversation) {
-        messages.push_back({
-            {"role", msg.role},
-            {"content", msg.content},
-        });
+        messages.push_back(simple_msg(msg.role, msg.content));
     }
     for (const auto & test_case : test_cases) {
         if (!test_case.supported_with_jinja) {
@@ -317,8 +322,13 @@ int main(void) {
         }
         printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str());
         try {
-            minja::chat_template tmpl(test_case.template_str, test_case.bos_token, test_case.eos_token);
-            auto output = normalize_newlines(tmpl.apply(messages, json(), add_generation_prompt));
+            auto tmpls = common_chat_templates_init(/* model= */ nullptr, test_case.template_str.c_str(), test_case.bos_token, test_case.eos_token);
+            common_chat_templates_inputs inputs;
+            inputs.use_jinja = true;
+            inputs.messages = messages;
+            inputs.add_generation_prompt = add_generation_prompt;
+            auto output = common_chat_templates_apply(tmpls.get(), inputs).prompt;
+            output = normalize_newlines(output);
             auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja);
             if (output != expected_output) {
                 printf("Expected:\n%s\n", expected_output.c_str());
@@ -336,11 +346,11 @@ int main(void) {
     // test llama_chat_format_single for system message
     printf("\n\n=== llama_chat_format_single (system message) ===\n\n");
     std::vector<common_chat_msg> chat2;
-    common_chat_msg sys_msg{"system", "You are a helpful assistant", {}};
+    auto sys_msg = simple_msg("system", "You are a helpful assistant");
 
     auto fmt_sys = [&](std::string tmpl_str) {
-        minja::chat_template tmpl(tmpl_str, "", "");
-        auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
+        auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str);
+        auto output = common_chat_format_single(tmpls.get(), chat2, sys_msg, false, /* use_jinja= */ false);
         printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
         return output;
@@ -360,14 +370,14 @@ int main(void) {
 
     // test llama_chat_format_single for user message
     printf("\n\n=== llama_chat_format_single (user message) ===\n\n");
-    chat2.push_back({"system", "You are a helpful assistant", {}});
-    chat2.push_back({"user", "Hello", {}});
-    chat2.push_back({"assistant", "I am assistant", {}});
-    common_chat_msg new_msg{"user", "How are you", {}};
+    chat2.push_back(simple_msg("system", "You are a helpful assistant"));
+    chat2.push_back(simple_msg("user", "Hello"));
+    chat2.push_back(simple_msg("assistant", "I am assistant"));
+    auto new_msg = simple_msg("user", "How are you");
 
-    auto fmt_single = [&](std::string tmpl_str) {
-        minja::chat_template tmpl(tmpl_str, "", "");
-        auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false);
+    auto fmt_single = [&](const std::string & tmpl_str) {
+        auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str.c_str());
+        auto output = common_chat_format_single(tmpls.get(), chat2, new_msg, true, /* use_jinja= */ false);
         printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
         return output;

Dosya farkı çok büyük olduğundan ihmal edildi
+ 408 - 357
tests/test-chat.cpp


Bu fark içinde çok fazla dosya değişikliği olduğu için bazı dosyalar gösterilmiyor