Просмотр исходного кода

Add Jinja template support (#11016)

* Copy minja from https://github.com/google/minja/commit/58f0ca6dd74bcbfbd4e71229736640322b31c7f9

* Add --jinja and --chat-template-file flags

* Add missing <optional> include

* Avoid print in get_hf_chat_template.py

* No designated initializers yet

* Try and work around msvc++ non-macro max resolution quirk

* Update test_chat_completion.py

* Wire LLM_KV_TOKENIZER_CHAT_TEMPLATE_N in llama_model_chat_template

* Refactor test-chat-template

* Test templates w/ minja

* Fix deprecation

* Add --jinja to llama-run

* Update common_chat_format_example to use minja template wrapper

* Test chat_template in e2e test

* Update utils.py

* Update test_chat_completion.py

* Update run.cpp

* Update arg.cpp

* Refactor common_chat_* functions to accept minja template + use_jinja option

* Attempt to fix linkage of LLAMA_CHATML_TEMPLATE

* Revert LLAMA_CHATML_TEMPLATE refactor

* Normalize newlines in test-chat-templates for windows tests

* Forward decl minja::chat_template to avoid eager json dep

* Flush stdout in chat template before potential crash

* Fix copy elision warning

* Rm unused optional include

* Add missing optional include to server.cpp

* Disable jinja test that has a cryptic windows failure

* minja: fix vigogne (https://github.com/google/minja/pull/22)

* Apply suggestions from code review

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Finish suggested renamings

* Move chat_templates inside server_context + remove mutex

* Update --chat-template-file w/ recent change to --chat-template

* Refactor chat template validation

* Guard against missing eos/bos tokens (null token otherwise throws in llama_vocab::impl::token_get_attr)

* Warn against missing eos / bos tokens when jinja template references them

* rename: common_chat_template[s]

* reinstate assert on chat_templates.template_default

* Update minja to https://github.com/google/minja/commit/b8437df626ac6cd0ce3b333b3c74ed1129c19f25

* Update minja to https://github.com/google/minja/pull/25

* Update minja from https://github.com/google/minja/pull/27

* rm unused optional header

---------

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Olivier Chafik 1 год назад
Родитель
Сommit
6171c9d258

+ 2 - 0
Makefile

@@ -1361,7 +1361,9 @@ llama-server: \
 	examples/server/httplib.h \
 	examples/server/index.html.hpp \
 	examples/server/loading.html.hpp \
+	common/chat-template.hpp \
 	common/json.hpp \
+	common/minja.hpp \
 	$(OBJ_ALL)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)

+ 2 - 0
common/CMakeLists.txt

@@ -56,6 +56,7 @@ add_library(${TARGET} STATIC
     arg.cpp
     arg.h
     base64.hpp
+    chat-template.hpp
     common.cpp
     common.h
     console.cpp
@@ -64,6 +65,7 @@ add_library(${TARGET} STATIC
     json.hpp
     log.cpp
     log.h
+    minja.hpp
     ngram-cache.cpp
     ngram-cache.h
     sampling.cpp

+ 35 - 7
common/arg.cpp

@@ -325,6 +325,14 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
     }
 
+    if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
+        throw std::runtime_error(string_format(
+            "error: the supplied chat template is not supported: %s%s\n",
+            params.chat_template.c_str(),
+            params.use_jinja ? "" : "\nnote: llama.cpp was started without --jinja, we only support commonly used templates"
+        ));
+    }
+
     return true;
 }
 
@@ -1947,24 +1955,44 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--jinja"},
+        "use jinja template for chat (default: disabled)",
+        [](common_params & params) {
+            params.use_jinja = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA"));
     add_opt(common_arg(
         {"--chat-template"}, "JINJA_TEMPLATE",
         string_format(
             "set custom jinja chat template (default: template taken from model's metadata)\n"
             "if suffix/prefix are specified, template will be disabled\n"
+            "only commonly used templates are accepted (unless --jinja is set before this flag):\n"
             "list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
         ),
         [](common_params & params, const std::string & value) {
-            if (!common_chat_verify_template(value)) {
-                throw std::runtime_error(string_format(
-                    "error: the supplied chat template is not supported: %s\n"
-                    "note: llama.cpp does not use jinja parser, we only support commonly used templates\n",
-                    value.c_str()
-                ));
-            }
             params.chat_template = value;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
+    add_opt(common_arg(
+        {"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
+        string_format(
+            "set custom jinja chat template file (default: template taken from model's metadata)\n"
+            "if suffix/prefix are specified, template will be disabled\n"
+            "only commonly used templates are accepted (unless --jinja is set before this flag):\n"
+            "list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
+        ),
+        [](common_params & params, const std::string & value) {
+            std::ifstream file(value);
+            if (!file) {
+                throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
+            }
+            std::copy(
+                std::istreambuf_iterator<char>(file),
+                std::istreambuf_iterator<char>(),
+                std::back_inserter(params.chat_template));
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
     add_opt(common_arg(
         {"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
         string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),

+ 249 - 0
common/chat-template.hpp

@@ -0,0 +1,249 @@
+/*
+    Copyright 2024 Google LLC
+
+    Use of this source code is governed by an MIT-style
+    license that can be found in the LICENSE file or at
+    https://opensource.org/licenses/MIT.
+*/
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#include "minja.hpp"
+#include <json.hpp>
+#include <string>
+#include <vector>
+
+using json = nlohmann::ordered_json;
+
+namespace minja {
+
+class chat_template {
+  public:
+
+  private:
+    bool supports_tools_ = true;
+    // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
+    // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
+    bool requires_object_arguments_ = false;
+    bool supports_system_role_ = true;
+    bool supports_parallel_tool_calls_ = false;
+    std::string source_;
+    std::string bos_token_;
+    std::string eos_token_;
+    std::shared_ptr<minja::TemplateNode> template_root_;
+
+    std::string try_render(
+        const nlohmann::ordered_json & messages,
+        const nlohmann::ordered_json & tools,
+        bool add_generation_prompt,
+        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
+    {
+        try {
+            auto prompt = apply(messages, tools, add_generation_prompt, extra_context);
+            // fprintf(stderr, "Prompt: %s\n", prompt.c_str());
+            return prompt;
+        } catch (const std::exception & e) {
+            // fprintf(stderr, "Error: %s\n", e.what());
+            return "";
+        }
+    }
+
+  public:
+    chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
+        : source_(source), bos_token_(bos_token), eos_token_(eos_token)
+    {
+        template_root_ = minja::Parser::parse(source_, {
+            /* .trim_blocks = */ true,
+            /* .lstrip_blocks = */ true,
+            /* .keep_trailing_newline = */ false,
+        });
+        supports_tools_ = source.find("tools") != std::string::npos;
+
+        auto renders_string_arguments =
+            try_render({
+                {
+                    {"role", "user"},
+                    {"content", "Hey"}
+                },
+                {
+                    {"role", "assistant"},
+                    {"tool_calls", json::array({
+                        {
+                            {"id", "call_1___"},
+                            {"type", "function"},
+                            {"function", {
+                                {"arguments", "{\"code\": \"print('Hello, World!')\"}"},
+                                {"name", "ipython"},
+                            }},
+                        },
+                    })},
+                }
+            }, {}, false).find("{\"code\": \"print") != std::string::npos;
+        if (!renders_string_arguments) {
+            auto renders_object_arguments =
+                try_render({
+                    {
+                        {"role", "user"},
+                        {"content", "Hey"}
+                    },
+                    {
+                        {"role", "assistant"},
+                        {"tool_calls", json::array({
+                            {
+                                {"id", "call_1___"},
+                                {"type", "function"},
+                                {"function", {
+                                    {"arguments", {
+                                        {"code", "print('Hello, World!')"},
+                                    }},
+                                    {"name", "ipython"},
+                                }},
+                            },
+                        })},
+                    }
+                }, {}, false).find("{\"code\": \"print") != std::string::npos;
+            requires_object_arguments_ = renders_object_arguments;
+        }
+        supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos;
+
+        supports_system_role_ = try_render({
+            {{"role", "system"}, {"content", "<System Needle>"}},
+            {{"role", "user"},   {"content", "Hey"}}
+        }, {}, false).find("<System Needle>") != std::string::npos;
+    }
+
+    const std::string & source() const { return source_; }
+    const std::string & bos_token() const { return bos_token_; }
+    const std::string & eos_token() const { return eos_token_; }
+    bool supports_tools() const { return supports_tools_; }
+    bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; }
+
+    std::string apply(
+        const nlohmann::ordered_json & messages,
+        const nlohmann::ordered_json & tools,
+        bool add_generation_prompt,
+        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
+    {
+        json actual_messages;
+
+        // First, "fix" messages so they have a chance to be rendered correctly by the template
+
+        if (requires_object_arguments_ || !supports_system_role_ || !supports_tools_) {
+            actual_messages = json::array();
+
+            std::string pending_system;
+            auto flush_sys = [&]() {
+                if (!pending_system.empty()) {
+                    actual_messages.push_back({
+                        {"role", "user"},
+                        {"content", pending_system},
+                    });
+                    pending_system.clear();
+                }
+            };
+            for (const auto & message_ : messages) {
+                auto message = message_;
+                if (!message.contains("role") || !message.contains("content")) {
+                    throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
+                }
+                std::string role = message.at("role");
+
+                if (message.contains("tool_calls")) {
+                    if (requires_object_arguments_ || !supports_tools_) {
+                        for (auto & tool_call : message.at("tool_calls")) {
+                            if (tool_call["type"] == "function") {
+                                auto & function = tool_call.at("function");
+                                std::string arguments = function.at("arguments");
+                                function["arguments"] = json::parse(arguments);
+                            }
+                        }
+                    }
+                    if (!supports_tools_) {
+                        auto content = message.at("content");
+                        auto tool_calls = json::array();
+                        for (const auto & tool_call : message.at("tool_calls")) {
+                            if (tool_call.at("type") != "function") {
+                                continue;
+                            }
+                            const auto & function = tool_call.at("function");
+                            auto tc = json {
+                                {"name", function.at("name")},
+                                {"arguments", function.at("arguments")},
+                            };
+                            if (tool_call.contains("id")) {
+                                tc["id"] = tool_call["id"];
+                            }
+                            tool_calls.push_back(tc);
+                        }
+                        auto obj = json {
+                            {"tool_calls", tool_calls},
+                        };
+                        if (!content.is_null() && content != "") {
+                            obj["content"] = content;
+                        }
+                        message["content"] = obj.dump(2);
+                        message.erase("tool_calls");
+                    }
+                }
+                if (!supports_tools_ && role == "tool") {
+                    message["role"] = "user";
+                    auto obj = json {
+                        {"tool_response", {
+                            {"tool", message.at("name")},
+                            {"content", message.at("content")},
+                        }},
+                    };
+                    if (message.contains("tool_call_id")) {
+                        obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
+                    }
+                    message["content"] = obj.dump(2);
+                    message.erase("name");
+                }
+
+                if (!message["content"].is_null() && !supports_system_role_) {
+                    std::string content = message.at("content");
+                    if (role == "system") {
+                        if (!pending_system.empty()) pending_system += "\n";
+                        pending_system += content;
+                        continue;
+                    } else {
+                        if (role == "user") {
+                            if (!pending_system.empty()) {
+                                message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
+                                pending_system.clear();
+                            }
+                        } else {
+                            flush_sys();
+                        }
+                    }
+                }
+                actual_messages.push_back(message);
+            }
+            flush_sys();
+        } else {
+            actual_messages = messages;
+        }
+
+        auto context = minja::Context::make(json({
+            {"messages", actual_messages},
+            {"add_generation_prompt", add_generation_prompt},
+            {"bos_token", bos_token_},
+            {"eos_token", eos_token_},
+        }));
+
+        if (!tools.is_null()) {
+            auto tools_val = minja::Value(tools);
+            context->set("tools", tools_val);
+        }
+        if (!extra_context.is_null()) {
+            for (auto & kv : extra_context.items()) {
+                minja::Value val(kv.value());
+                context->set(kv.key(), val);
+            }
+        }
+
+        return template_root_->render(context);
+    }
+};
+
+}  // namespace minja

+ 94 - 32
common/common.cpp

@@ -12,6 +12,7 @@
 #include "json.hpp"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
+#include "chat-template.hpp"
 
 #include <algorithm>
 #include <cinttypes>
@@ -1728,67 +1729,75 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
 // Chat template utils
 //
 
-std::string common_get_builtin_chat_template(const struct llama_model * model) {
-    const char * ptr_tmpl = llama_model_chat_template(model);
-    return ptr_tmpl == nullptr ? "" : ptr_tmpl;
-}
-
-bool common_chat_verify_template(const std::string & tmpl) {
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
+    if (use_jinja) {
+        try {
+            auto chat_template = minja::chat_template(tmpl, "<s>", "</s>");
+            chat_template.apply({{
+                {"role", "user"},
+                {"content", "test"},
+            }}, json(), true);
+            return true;
+        } catch (const std::exception & e) {
+            LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
+            return false;
+        }
+    }
     llama_chat_message chat[] = {{"user", "test"}};
     const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
     return res >= 0;
 }
 
-std::string common_chat_apply_template(const struct llama_model * model,
-        const std::string & tmpl,
+std::string common_chat_apply_template(
+        const common_chat_template & tmpl,
         const std::vector<common_chat_msg> & msgs,
-        bool add_ass) {
+        bool add_ass,
+        bool use_jinja) {
+    if (use_jinja) {
+        auto messages = json::array();
+        for (const auto & msg : msgs) {
+            messages.push_back({{"role", msg.role}, {"content", msg.content}});
+        }
+        return tmpl.apply(messages, /* tools= */ json(), add_ass);
+    }
+
     int alloc_size = 0;
-    bool fallback = false; // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
     for (const auto & msg : msgs) {
         chat.push_back({msg.role.c_str(), msg.content.c_str()});
         alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
     }
 
-    const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model) : tmpl.c_str();
     std::vector<char> buf(alloc_size);
 
     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
     // error: chat template is not supported
     if (res < 0) {
-        if (ptr_tmpl != nullptr) {
-            // if the custom "tmpl" is not supported, we throw an error
-            // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
-            throw std::runtime_error("this custom template is not supported");
-        }
-
-        // If the built-in template is not supported, we default to chatml
-        res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
-        fallback = true;
+        // if the custom "tmpl" is not supported, we throw an error
+        // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
+        throw std::runtime_error("this custom template is not supported");
     }
 
     // if it turns out that our buffer is too small, we resize it
     if ((size_t) res > buf.size()) {
         buf.resize(res);
-        res = llama_chat_apply_template(
-            fallback ? "chatml" : ptr_tmpl,
-            chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+        res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }
 
     std::string formatted_chat(buf.data(), res);
     return formatted_chat;
 }
 
-std::string common_chat_format_single(const struct llama_model * model,
-        const std::string & tmpl,
+std::string common_chat_format_single(
+        const common_chat_template & tmpl,
         const std::vector<common_chat_msg> & past_msg,
         const common_chat_msg & new_msg,
-        bool add_ass) {
+        bool add_ass,
+        bool use_jinja) {
     std::ostringstream ss;
-    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false);
+    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
     std::vector<common_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
@@ -1796,21 +1805,74 @@ std::string common_chat_format_single(const struct llama_model * model,
     };
     // format chat with new_msg
     chat_new.push_back(new_msg);
-    auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass);
+    auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja);
     // get the diff part
     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
     return ss.str();
 }
 
-std::string common_chat_format_example(const struct llama_model * model,
-        const std::string & tmpl) {
+std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) {
     std::vector<common_chat_msg> msgs = {
         {"system",    "You are a helpful assistant"},
         {"user",      "Hello"},
         {"assistant", "Hi there"},
         {"user",      "How are you?"},
     };
-    return common_chat_apply_template(model, tmpl, msgs, true);
+    return common_chat_apply_template(tmpl, msgs, true, use_jinja);
+}
+
+common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
+{
+    auto vocab = llama_model_get_vocab(model);
+    std::string default_template_src = chat_template_override;
+    std::string template_tool_use_src = chat_template_override;
+    bool has_explicit_template = !chat_template_override.empty();
+    if (chat_template_override.empty()) {
+        auto str = llama_model_chat_template(model, /* name */ nullptr);
+        if (str) {
+            default_template_src = str;
+            has_explicit_template = true;
+        }
+        str = llama_model_chat_template(model, /* name */ "tool_use");
+        if (str) {
+            template_tool_use_src = str;
+            has_explicit_template = true;
+        }
+    }
+    if (default_template_src.empty() || default_template_src == "chatml") {
+        if (!template_tool_use_src.empty()) {
+            default_template_src = template_tool_use_src;
+        } else {
+            default_template_src = R"(
+                {%- for message in messages -%}
+                    {{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
+                {%- endfor -%}
+                {%- if add_generation_prompt -%}
+                    {{- "<|im_start|>assistant\n" -}}
+                {%- endif -%}
+            )";
+        }
+    }
+    const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
+        if (token == LLAMA_TOKEN_NULL) {
+            if (default_template_src.find(jinja_variable_name) != std::string::npos
+                || template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
+                LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
+            }
+            return std::string();
+        } else {
+            return common_token_to_piece(vocab, token, true);
+        }
+    };
+    auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
+    auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
+    return {
+        has_explicit_template,
+        std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
+        template_tool_use_src.empty()
+            ? nullptr
+            : std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos)
+    };
 }
 
 //

+ 26 - 12
common/common.h

@@ -334,6 +334,7 @@ struct common_params {
     std::string hostname      = "127.0.0.1";
     std::string public_path   = "";                                                                         // NOLINT
     std::string chat_template = "";                                                                         // NOLINT
+    bool use_jinja = false;                                                                                 // NOLINT
     bool enable_chat_template = true;
 
     std::vector<std::string> api_keys;
@@ -603,30 +604,43 @@ struct common_chat_msg {
     std::string content;
 };
 
-// Get the built-in chat template for the model. Return empty string if not present.
-std::string common_get_builtin_chat_template(const struct llama_model * model);
-
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
-bool common_chat_verify_template(const std::string & tmpl);
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
+
+namespace minja {
+    class chat_template;
+}
+
+typedef minja::chat_template common_chat_template;
+
+struct common_chat_templates {
+    bool has_explicit_template; // Model had builtin template or template overridde was specified.
+    std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
+    std::unique_ptr<common_chat_template> template_tool_use;
+};
 
 // CPP wrapper for llama_chat_apply_template
 // If the built-in template is not supported, we default to chatml
 // If the custom "tmpl" is not supported, we throw an error
-std::string common_chat_apply_template(const struct llama_model * model,
-        const std::string & tmpl,
+std::string common_chat_apply_template(
+        const common_chat_template & tmpl,
         const std::vector<common_chat_msg> & chat,
-        bool add_ass);
+        bool add_ass,
+        bool use_jinja);
 
 // Format single message, while taking into account the position of that message in chat history
-std::string common_chat_format_single(const struct llama_model * model,
-        const std::string & tmpl,
+std::string common_chat_format_single(
+        const common_chat_template & tmpl,
         const std::vector<common_chat_msg> & past_msg,
         const common_chat_msg & new_msg,
-        bool add_ass);
+        bool add_ass,
+        bool use_jinja);
 
 // Returns an example of formatted chat
-std::string common_chat_format_example(const struct llama_model * model,
-        const std::string & tmpl);
+std::string common_chat_format_example(
+    const common_chat_template & tmpl, bool use_jinja);
+
+common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
 
 //
 // KV cache utils

+ 2788 - 0
common/minja.hpp

@@ -0,0 +1,2788 @@
+/*
+    Copyright 2024 Google LLC
+
+    Use of this source code is governed by an MIT-style
+    license that can be found in the LICENSE file or at
+    https://opensource.org/licenses/MIT.
+*/
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#include <iostream>
+#include <string>
+#include <vector>
+#include <regex>
+#include <memory>
+#include <stdexcept>
+#include <sstream>
+#include <unordered_set>
+#include <json.hpp>
+
+using json = nlohmann::ordered_json;
+
+namespace minja {
+
+class Context;
+
+struct Options {
+    bool trim_blocks;  // removes the first newline after a block
+    bool lstrip_blocks;  // removes leading whitespace on the line of the block
+    bool keep_trailing_newline;  // don't remove last newline
+};
+
+struct ArgumentsValue;
+
+inline std::string normalize_newlines(const std::string & s) {
+#ifdef _WIN32
+  static const std::regex nl_regex("\r\n");
+  return std::regex_replace(s, nl_regex, "\n");
+#else
+  return s;
+#endif
+}
+
+/* Values that behave roughly like in Python. */
+class Value : public std::enable_shared_from_this<Value> {
+public:
+  using CallableType = std::function<Value(const std::shared_ptr<Context> &, ArgumentsValue &)>;
+  using FilterType = std::function<Value(const std::shared_ptr<Context> &, ArgumentsValue &)>;
+
+private:
+  using ObjectType = nlohmann::ordered_map<json, Value>;  // Only contains primitive keys
+  using ArrayType = std::vector<Value>;
+
+  std::shared_ptr<ArrayType> array_;
+  std::shared_ptr<ObjectType> object_;
+  std::shared_ptr<CallableType> callable_;
+  json primitive_;
+
+  Value(const std::shared_ptr<ArrayType> & array) : array_(array) {}
+  Value(const std::shared_ptr<ObjectType> & object) : object_(object) {}
+  Value(const std::shared_ptr<CallableType> & callable) : object_(std::make_shared<ObjectType>()), callable_(callable) {}
+
+  /* Python-style string repr */
+  static void dump_string(const json & primitive, std::ostringstream & out, char string_quote = '\'') {
+    if (!primitive.is_string()) throw std::runtime_error("Value is not a string: " + primitive.dump());
+    auto s = primitive.dump();
+    if (string_quote == '"' || s.find('\'') != std::string::npos) {
+      out << s;
+      return;
+    }
+    // Reuse json dump, just changing string quotes
+    out << string_quote;
+    for (size_t i = 1, n = s.size() - 1; i < n; ++i) {
+      if (s[i] == '\\' && s[i + 1] == '"') {
+        out << '"';
+        i++;
+      } else if (s[i] == string_quote) {
+        out << '\\' << string_quote;
+      } else {
+        out << s[i];
+      }
+    }
+    out << string_quote;
+  }
+  void dump(std::ostringstream & out, int indent = -1, int level = 0, bool to_json = false) const {
+    auto print_indent = [&](int level) {
+      if (indent > 0) {
+          out << "\n";
+          for (int i = 0, n = level * indent; i < n; ++i) out << ' ';
+      }
+    };
+    auto print_sub_sep = [&]() {
+      out << ',';
+      if (indent < 0) out << ' ';
+      else print_indent(level + 1);
+    };
+
+    auto string_quote = to_json ? '"' : '\'';
+
+    if (is_null()) out << "null";
+    else if (array_) {
+      out << "[";
+      print_indent(level + 1);
+      for (size_t i = 0; i < array_->size(); ++i) {
+        if (i) print_sub_sep();
+        (*array_)[i].dump(out, indent, level + 1, to_json);
+      }
+      print_indent(level);
+      out << "]";
+    } else if (object_) {
+      out << "{";
+      print_indent(level + 1);
+      for (auto begin = object_->begin(), it = begin; it != object_->end(); ++it) {
+        if (it != begin) print_sub_sep();
+        if (it->first.is_string()) {
+          dump_string(it->first, out, string_quote);
+        } else {
+          out << string_quote << it->first.dump() << string_quote;
+        }
+        out << ": ";
+        it->second.dump(out, indent, level + 1, to_json);
+      }
+      print_indent(level);
+      out << "}";
+    } else if (callable_) {
+      throw std::runtime_error("Cannot dump callable to JSON");
+    } else if (is_boolean() && !to_json) {
+      out << (this->to_bool() ? "True" : "False");
+    } else if (is_string() && !to_json) {
+      dump_string(primitive_, out, string_quote);
+    } else {
+      out << primitive_.dump();
+    }
+  }
+
+public:
+  Value() {}
+  Value(const bool& v) : primitive_(v) {}
+  Value(const int64_t & v) : primitive_(v) {}
+  Value(const double& v) : primitive_(v) {}
+  Value(const std::nullptr_t &) {}
+  Value(const std::string & v) : primitive_(v) {}
+  Value(const char * v) : primitive_(std::string(v)) {}
+
+  Value(const json & v) {
+    if (v.is_object()) {
+      auto object = std::make_shared<ObjectType>();
+      for (auto it = v.begin(); it != v.end(); ++it) {
+        (*object)[it.key()] = it.value();
+      }
+      object_ = std::move(object);
+    } else if (v.is_array()) {
+      auto array = std::make_shared<ArrayType>();
+      for (const auto& item : v) {
+        array->push_back(Value(item));
+      }
+      array_ = array;
+    } else {
+      primitive_ = v;
+    }
+  }
+
+  std::vector<Value> keys() {
+    if (!object_) throw std::runtime_error("Value is not an object: " + dump());
+    std::vector<Value> res;
+    for (const auto& item : *object_) {
+      res.push_back(item.first);
+    }
+    return res;
+  }
+
+  size_t size() const {
+    if (is_object()) return object_->size();
+    if (is_array()) return array_->size();
+    if (is_string()) return primitive_.get<std::string>().length();
+    throw std::runtime_error("Value is not an array or object: " + dump());
+  }
+
+  static Value array(const std::vector<Value> values = {}) {
+    auto array = std::make_shared<ArrayType>();
+    for (const auto& item : values) {
+      array->push_back(item);
+    }
+    return Value(array);
+  }
+  static Value object(const std::shared_ptr<ObjectType> object = std::make_shared<ObjectType>()) {
+    return Value(object);
+  }
+  static Value callable(const CallableType & callable) {
+    return Value(std::make_shared<CallableType>(callable));
+  }
+
+  void insert(size_t index, const Value& v) {
+    if (!array_)
+      throw std::runtime_error("Value is not an array: " + dump());
+    array_->insert(array_->begin() + index, v);
+  }
+  void push_back(const Value& v) {
+    if (!array_)
+      throw std::runtime_error("Value is not an array: " + dump());
+    array_->push_back(v);
+  }
+  Value pop(const Value& index) {
+    if (is_array()) {
+      if (array_->empty())
+        throw std::runtime_error("pop from empty list");
+      if (index.is_null()) {
+        auto ret = array_->back();
+        array_->pop_back();
+        return ret;
+      } else if (!index.is_number_integer()) {
+        throw std::runtime_error("pop index must be an integer: " + index.dump());
+      } else {
+        auto i = index.get<int>();
+        if (i < 0 || i >= static_cast<int>(array_->size()))
+          throw std::runtime_error("pop index out of range: " + index.dump());
+        auto it = array_->begin() + (i < 0 ? array_->size() + i : i);
+        auto ret = *it;
+        array_->erase(it);
+        return ret;
+      }
+    } else if (is_object()) {
+      if (!index.is_hashable())
+        throw std::runtime_error("Unashable type: " + index.dump());
+      auto it = object_->find(index.primitive_);
+      if (it == object_->end())
+        throw std::runtime_error("Key not found: " + index.dump());
+      auto ret = it->second;
+      object_->erase(it);
+      return ret;
+    } else {
+      throw std::runtime_error("Value is not an array or object: " + dump());
+    }
+  }
+  Value get(const Value& key) {
+    if (array_) {
+      if (!key.is_number_integer()) {
+        return Value();
+      }
+      auto index = key.get<int>();
+      return array_->at(index < 0 ? array_->size() + index : index);
+    } else if (object_) {
+      if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
+      auto it = object_->find(key.primitive_);
+      if (it == object_->end()) return Value();
+      return it->second;
+    }
+    return Value();
+  }
+  void set(const Value& key, const Value& value) {
+    if (!object_) throw std::runtime_error("Value is not an object: " + dump());
+    if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
+    (*object_)[key.primitive_] = value;
+  }
+  Value call(const std::shared_ptr<Context> & context, ArgumentsValue & args) const {
+    if (!callable_) throw std::runtime_error("Value is not callable: " + dump());
+    return (*callable_)(context, args);
+  }
+
+  bool is_object() const { return !!object_; }
+  bool is_array() const { return !!array_; }
+  bool is_callable() const { return !!callable_; }
+  bool is_null() const { return !object_ && !array_ && primitive_.is_null() && !callable_; }
+  bool is_boolean() const { return primitive_.is_boolean(); }
+  bool is_number_integer() const { return primitive_.is_number_integer(); }
+  bool is_number_float() const { return primitive_.is_number_float(); }
+  bool is_number() const { return primitive_.is_number(); }
+  bool is_string() const { return primitive_.is_string(); }
+  bool is_iterable() const { return is_array() || is_object() || is_string(); }
+
+  bool is_primitive() const { return !array_ && !object_ && !callable_; }
+  bool is_hashable() const { return is_primitive(); }
+
+  bool empty() const {
+    if (is_null())
+      throw std::runtime_error("Undefined value or reference");
+    if (is_string()) return primitive_.empty();
+    if (is_array()) return array_->empty();
+    if (is_object()) return object_->empty();
+    return false;
+  }
+
+  void for_each(const std::function<void(Value &)> & callback) const {
+    if (is_null())
+      throw std::runtime_error("Undefined value or reference");
+    if (array_) {
+      for (auto& item : *array_) {
+        callback(item);
+      }
+    } else if (object_) {
+      for (auto & item : *object_) {
+        Value key(item.first);
+        callback(key);
+      }
+    } else if (is_string()) {
+      for (char c : primitive_.get<std::string>()) {
+        auto val = Value(std::string(1, c));
+        callback(val);
+      }
+    } else {
+      throw std::runtime_error("Value is not iterable: " + dump());
+    }
+  }
+
+  bool to_bool() const {
+    if (is_null()) return false;
+    if (is_boolean()) return get<bool>();
+    if (is_number()) return get<double>() != 0;
+    if (is_string()) return !get<std::string>().empty();
+    if (is_array()) return !empty();
+    return true;
+  }
+
+  int64_t to_int() const {
+    if (is_null()) return 0;
+    if (is_boolean()) return get<bool>() ? 1 : 0;
+    if (is_number()) return static_cast<int64_t>(get<double>());
+    if (is_string()) {
+      try {
+        return std::stol(get<std::string>());
+      } catch (const std::exception &) {
+        return 0;
+      }
+    }
+    return 0;
+  }
+
+  bool operator<(const Value & other) const {
+    if (is_null())
+      throw std::runtime_error("Undefined value or reference");
+    if (is_number() && other.is_number()) return get<double>() < other.get<double>();
+    if (is_string() && other.is_string()) return get<std::string>() < other.get<std::string>();
+    throw std::runtime_error("Cannot compare values: " + dump() + " < " + other.dump());
+  }
+  bool operator>=(const Value & other) const { return !(*this < other); }
+
+  bool operator>(const Value & other) const {
+    if (is_null())
+      throw std::runtime_error("Undefined value or reference");
+    if (is_number() && other.is_number()) return get<double>() > other.get<double>();
+    if (is_string() && other.is_string()) return get<std::string>() > other.get<std::string>();
+    throw std::runtime_error("Cannot compare values: " + dump() + " > " + other.dump());
+  }
+  bool operator<=(const Value & other) const { return !(*this > other); }
+
+  bool operator==(const Value & other) const {
+    if (callable_ || other.callable_) {
+      if (callable_.get() != other.callable_.get()) return false;
+    }
+    if (array_) {
+      if (!other.array_) return false;
+      if (array_->size() != other.array_->size()) return false;
+      for (size_t i = 0; i < array_->size(); ++i) {
+        if (!(*array_)[i].to_bool() || !(*other.array_)[i].to_bool() || (*array_)[i] != (*other.array_)[i]) return false;
+      }
+      return true;
+    } else if (object_) {
+      if (!other.object_) return false;
+      if (object_->size() != other.object_->size()) return false;
+      for (const auto& item : *object_) {
+        if (!item.second.to_bool() || !other.object_->count(item.first) || item.second != other.object_->at(item.first)) return false;
+      }
+      return true;
+    } else {
+      return primitive_ == other.primitive_;
+    }
+  }
+  bool operator!=(const Value & other) const { return !(*this == other); }
+
+  bool contains(const char * key) const { return contains(std::string(key)); }
+  bool contains(const std::string & key) const {
+    if (array_) {
+      return false;
+    } else if (object_) {
+      return object_->find(key) != object_->end();
+    } else {
+      throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
+    }
+  }
+  bool contains(const Value & value) const {
+    if (is_null())
+      throw std::runtime_error("Undefined value or reference");
+    if (array_) {
+      for (const auto& item : *array_) {
+        if (item.to_bool() && item == value) return true;
+      }
+      return false;
+    } else if (object_) {
+      if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump());
+      return object_->find(value.primitive_) != object_->end();
+    } else {
+      throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
+    }
+  }
+  void erase(size_t index) {
+    if (!array_) throw std::runtime_error("Value is not an array: " + dump());
+    array_->erase(array_->begin() + index);
+  }
+  void erase(const std::string & key) {
+    if (!object_) throw std::runtime_error("Value is not an object: " + dump());
+    object_->erase(key);
+  }
+  const Value& at(const Value & index) const {
+    return const_cast<Value*>(this)->at(index);
+  }
+  Value& at(const Value & index) {
+    if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
+    if (is_array()) return array_->at(index.get<int>());
+    if (is_object()) return object_->at(index.primitive_);
+    throw std::runtime_error("Value is not an array or object: " + dump());
+  }
+  const Value& at(size_t index) const {
+    return const_cast<Value*>(this)->at(index);
+  }
+  Value& at(size_t index) {
+    if (is_null())
+      throw std::runtime_error("Undefined value or reference");
+    if (is_array()) return array_->at(index);
+    if (is_object()) return object_->at(index);
+    throw std::runtime_error("Value is not an array or object: " + dump());
+  }
+
+  template <typename T>
+  T get(const std::string & key, T default_value) const {
+    if (!contains(key)) return default_value;
+    return at(key).get<T>();
+  }
+
+  template <typename T>
+  T get() const {
+    if (is_primitive()) return primitive_.get<T>();
+    throw std::runtime_error("get<T> not defined for this value type: " + dump());
+  }
+
+  std::string dump(int indent=-1, bool to_json=false) const {
+    std::ostringstream out;
+    dump(out, indent, 0, to_json);
+    return out.str();
+  }
+
+  Value operator-() const {
+      if (is_number_integer())
+        return -get<int64_t>();
+      else
+        return -get<double>();
+  }
+  std::string to_str() const {
+    if (is_string()) return get<std::string>();
+    if (is_number_integer()) return std::to_string(get<int64_t>());
+    if (is_number_float()) return std::to_string(get<double>());
+    if (is_boolean()) return get<bool>() ? "True" : "False";
+    if (is_null()) return "None";
+    return dump();
+  }
+  Value operator+(const Value& rhs) const {
+      if (is_string() || rhs.is_string()) {
+        return to_str() + rhs.to_str();
+      } else if (is_number_integer() && rhs.is_number_integer()) {
+        return get<int64_t>() + rhs.get<int64_t>();
+      } else if (is_array() && rhs.is_array()) {
+        auto res = Value::array();
+        for (const auto& item : *array_) res.push_back(item);
+        for (const auto& item : *rhs.array_) res.push_back(item);
+        return res;
+      } else {
+        return get<double>() + rhs.get<double>();
+      }
+  }
+  Value operator-(const Value& rhs) const {
+      if (is_number_integer() && rhs.is_number_integer())
+        return get<int64_t>() - rhs.get<int64_t>();
+      else
+        return get<double>() - rhs.get<double>();
+  }
+  Value operator*(const Value& rhs) const {
+      if (is_string() && rhs.is_number_integer()) {
+        std::ostringstream out;
+        for (int64_t i = 0, n = rhs.get<int64_t>(); i < n; ++i) {
+          out << to_str();
+        }
+        return out.str();
+      }
+      else if (is_number_integer() && rhs.is_number_integer())
+        return get<int64_t>() * rhs.get<int64_t>();
+      else
+        return get<double>() * rhs.get<double>();
+  }
+  Value operator/(const Value& rhs) const {
+      if (is_number_integer() && rhs.is_number_integer())
+        return get<int64_t>() / rhs.get<int64_t>();
+      else
+        return get<double>() / rhs.get<double>();
+  }
+  Value operator%(const Value& rhs) const {
+    return get<int64_t>() % rhs.get<int64_t>();
+  }
+};
+
+struct ArgumentsValue {
+  std::vector<Value> args;
+  std::vector<std::pair<std::string, Value>> kwargs;
+
+  bool has_named(const std::string & name) {
+    for (const auto & p : kwargs) {
+      if (p.first == name) return true;
+    }
+    return false;
+  }
+
+  Value get_named(const std::string & name) {
+    for (const auto & [key, value] : kwargs) {
+      if (key == name) return value;
+    }
+    return Value();
+  }
+
+  bool empty() {
+    return args.empty() && kwargs.empty();
+  }
+
+  void expectArgs(const std::string & method_name, const std::pair<size_t, size_t> & pos_count, const std::pair<size_t, size_t> & kw_count) {
+    if (args.size() < pos_count.first || args.size() > pos_count.second || kwargs.size() < kw_count.first || kwargs.size() > kw_count.second) {
+      std::ostringstream out;
+      out << method_name << " must have between " << pos_count.first << " and " << pos_count.second << " positional arguments and between " << kw_count.first << " and " << kw_count.second << " keyword arguments";
+      throw std::runtime_error(out.str());
+    }
+  }
+};
+
+template <>
+inline json Value::get<json>() const {
+  if (is_primitive()) return primitive_;
+  if (is_null()) return json();
+  if (array_) {
+    std::vector<json> res;
+    for (const auto& item : *array_) {
+      res.push_back(item.get<json>());
+    }
+    return res;
+  }
+  if (object_) {
+    json res = json::object();
+    for (const auto& [key, value] : *object_) {
+      if (key.is_string()) {
+        res[key.get<std::string>()] = value.get<json>();
+      } else if (key.is_primitive()) {
+        res[key.dump()] = value.get<json>();
+      } else {
+        throw std::runtime_error("Invalid key type for conversion to JSON: " + key.dump());
+      }
+    }
+    if (is_callable()) {
+      res["__callable__"] = true;
+    }
+    return res;
+  }
+  throw std::runtime_error("get<json> not defined for this value type: " + dump());
+}
+
+} // namespace minja
+
+namespace std {
+  template <>
+  struct hash<minja::Value> {
+    size_t operator()(const minja::Value & v) const {
+      if (!v.is_hashable())
+        throw std::runtime_error("Unsupported type for hashing: " + v.dump());
+      return std::hash<json>()(v.get<json>());
+    }
+  };
+} // namespace std
+
+namespace minja {
+
+static std::string error_location_suffix(const std::string & source, size_t pos) {
+  auto get_line = [&](size_t line) {
+    auto start = source.begin();
+    for (size_t i = 1; i < line; ++i) {
+      start = std::find(start, source.end(), '\n') + 1;
+    }
+    auto end = std::find(start, source.end(), '\n');
+    return std::string(start, end);
+  };
+  auto start = source.begin();
+  auto end = source.end();
+  auto it = start + pos;
+  auto line = std::count(start, it, '\n') + 1;
+  auto max_line = std::count(start, end, '\n') + 1;
+  auto col = pos - std::string(start, it).rfind('\n');
+  std::ostringstream out;
+  out << " at row " << line << ", column " << col << ":\n";
+  if (line > 1) out << get_line(line - 1) << "\n";
+  out << get_line(line) << "\n";
+  out << std::string(col - 1, ' ') << "^\n";
+  if (line < max_line) out << get_line(line + 1) << "\n";
+
+  return out.str();
+}
+
+class Context : public std::enable_shared_from_this<Context> {
+  protected:
+    Value values_;
+    std::shared_ptr<Context> parent_;
+  public:
+    Context(Value && values, const std::shared_ptr<Context> & parent = nullptr) : values_(std::move(values)), parent_(parent) {
+        if (!values_.is_object()) throw std::runtime_error("Context values must be an object: " + values_.dump());
+    }
+    virtual ~Context() {}
+
+    static std::shared_ptr<Context> builtins();
+    static std::shared_ptr<Context> make(Value && values, const std::shared_ptr<Context> & parent = builtins());
+
+    std::vector<Value> keys() {
+        return values_.keys();
+    }
+    virtual Value get(const Value & key) {
+        if (values_.contains(key)) return values_.at(key);
+        if (parent_) return parent_->get(key);
+        return Value();
+    }
+    virtual Value & at(const Value & key) {
+        if (values_.contains(key)) return values_.at(key);
+        if (parent_) return parent_->at(key);
+        throw std::runtime_error("Undefined variable: " + key.dump());
+    }
+    virtual bool contains(const Value & key) {
+        if (values_.contains(key)) return true;
+        if (parent_) return parent_->contains(key);
+        return false;
+    }
+    virtual void set(const Value & key, Value & value) {
+        values_.set(key, value);
+    }
+};
+
+struct Location {
+    std::shared_ptr<std::string> source;
+    size_t pos;
+};
+
+class Expression {
+protected:
+    virtual Value do_evaluate(const std::shared_ptr<Context> & context) const = 0;
+public:
+    using Parameters = std::vector<std::pair<std::string, std::shared_ptr<Expression>>>;
+
+    Location location;
+
+    Expression(const Location & location) : location(location) {}
+    virtual ~Expression() = default;
+
+    Value evaluate(const std::shared_ptr<Context> & context) const {
+        try {
+            return do_evaluate(context);
+        } catch (const std::exception & e) {
+            std::ostringstream out;
+            out << e.what();
+            if (location.source) out << error_location_suffix(*location.source, location.pos);
+            throw std::runtime_error(out.str());
+        }
+    }
+};
+
+class VariableExpr : public Expression {
+    std::string name;
+public:
+    VariableExpr(const Location & location, const std::string& n)
+      : Expression(location), name(n) {}
+    std::string get_name() const { return name; }
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!context->contains(name)) {
+            return Value();
+        }
+        return context->at(name);
+    }
+};
+
+static void destructuring_assign(const std::vector<std::string> & var_names, const std::shared_ptr<Context> & context, Value& item) {
+  if (var_names.size() == 1) {
+      Value name(var_names[0]);
+      context->set(name, item);
+  } else {
+      if (!item.is_array() || item.size() != var_names.size()) {
+          throw std::runtime_error("Mismatched number of variables and items in destructuring assignment");
+      }
+      for (size_t i = 0; i < var_names.size(); ++i) {
+          context->set(var_names[i], item.at(i));
+      }
+  }
+}
+
+enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline };
+
+class TemplateToken {
+public:
+    enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter };
+
+    static std::string typeToString(Type t) {
+        switch (t) {
+            case Type::Text: return "text";
+            case Type::Expression: return "expression";
+            case Type::If: return "if";
+            case Type::Else: return "else";
+            case Type::Elif: return "elif";
+            case Type::EndIf: return "endif";
+            case Type::For: return "for";
+            case Type::EndFor: return "endfor";
+            case Type::Set: return "set";
+            case Type::EndSet: return "endset";
+            case Type::Comment: return "comment";
+            case Type::Macro: return "macro";
+            case Type::EndMacro: return "endmacro";
+            case Type::Filter: return "filter";
+            case Type::EndFilter: return "endfilter";
+        }
+        return "Unknown";
+    }
+
+    TemplateToken(Type type, const Location & location, SpaceHandling pre, SpaceHandling post) : type(type), location(location), pre_space(pre), post_space(post) {}
+    virtual ~TemplateToken() = default;
+
+    Type type;
+    Location location;
+    SpaceHandling pre_space = SpaceHandling::Keep;
+    SpaceHandling post_space = SpaceHandling::Keep;
+};
+
+struct TextTemplateToken : public TemplateToken {
+    std::string text;
+    TextTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Text, location, pre, post), text(t) {}
+};
+
+struct ExpressionTemplateToken : public TemplateToken {
+    std::shared_ptr<Expression> expr;
+    ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {}
+};
+
+struct IfTemplateToken : public TemplateToken {
+    std::shared_ptr<Expression> condition;
+    IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {}
+};
+
+struct ElifTemplateToken : public TemplateToken {
+    std::shared_ptr<Expression> condition;
+    ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {}
+};
+
+struct ElseTemplateToken : public TemplateToken {
+    ElseTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Else, location, pre, post) {}
+};
+
+struct EndIfTemplateToken : public TemplateToken {
+    EndIfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, location, pre, post) {}
+};
+
+struct MacroTemplateToken : public TemplateToken {
+    std::shared_ptr<VariableExpr> name;
+    Expression::Parameters params;
+    MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p)
+      : TemplateToken(Type::Macro, location, pre, post), name(std::move(n)), params(std::move(p)) {}
+};
+
+struct EndMacroTemplateToken : public TemplateToken {
+    EndMacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndMacro, location, pre, post) {}
+};
+
+struct FilterTemplateToken : public TemplateToken {
+    std::shared_ptr<Expression> filter;
+    FilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && filter)
+      : TemplateToken(Type::Filter, location, pre, post), filter(std::move(filter)) {}
+};
+
+struct EndFilterTemplateToken : public TemplateToken {
+    EndFilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFilter, location, pre, post) {}
+};
+
+struct ForTemplateToken : public TemplateToken {
+    std::vector<std::string> var_names;
+    std::shared_ptr<Expression> iterable;
+    std::shared_ptr<Expression> condition;
+    bool recursive;
+    ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector<std::string> & vns, std::shared_ptr<Expression> && iter,
+      std::shared_ptr<Expression> && c, bool r)
+      : TemplateToken(Type::For, location, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {}
+};
+
+struct EndForTemplateToken : public TemplateToken {
+    EndForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, location, pre, post) {}
+};
+
+struct SetTemplateToken : public TemplateToken {
+    std::string ns;
+    std::vector<std::string> var_names;
+    std::shared_ptr<Expression> value;
+    SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v)
+      : TemplateToken(Type::Set, location, pre, post), ns(ns), var_names(vns), value(std::move(v)) {}
+};
+
+struct EndSetTemplateToken : public TemplateToken {
+    EndSetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndSet, location, pre, post) {}
+};
+
+struct CommentTemplateToken : public TemplateToken {
+    std::string text;
+    CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, location, pre, post), text(t) {}
+};
+
+class TemplateNode {
+    Location location_;
+protected:
+    virtual void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const = 0;
+
+public:
+    TemplateNode(const Location & location) : location_(location) {}
+    void render(std::ostringstream & out, const std::shared_ptr<Context> & context) const {
+        try {
+            do_render(out, context);
+        } catch (const std::exception & e) {
+            std::ostringstream err;
+            err << e.what();
+            if (location_.source) err << error_location_suffix(*location_.source, location_.pos);
+            throw std::runtime_error(err.str());
+        }
+    }
+    const Location & location() const { return location_; }
+    virtual ~TemplateNode() = default;
+    std::string render(const std::shared_ptr<Context> & context) const {
+        std::ostringstream out;
+        render(out, context);
+        return out.str();
+    }
+};
+
+class SequenceNode : public TemplateNode {
+    std::vector<std::shared_ptr<TemplateNode>> children;
+public:
+    SequenceNode(const Location & location, std::vector<std::shared_ptr<TemplateNode>> && c)
+      : TemplateNode(location), children(std::move(c)) {}
+    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
+        for (const auto& child : children) child->render(out, context);
+    }
+};
+
+class TextNode : public TemplateNode {
+    std::string text;
+public:
+    TextNode(const Location & location, const std::string& t) : TemplateNode(location), text(t) {}
+    void do_render(std::ostringstream & out, const std::shared_ptr<Context> &) const override {
+      out << text;
+    }
+};
+
+class ExpressionNode : public TemplateNode {
+    std::shared_ptr<Expression> expr;
+public:
+    ExpressionNode(const Location & location, std::shared_ptr<Expression> && e) : TemplateNode(location), expr(std::move(e)) {}
+    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
+      if (!expr) throw std::runtime_error("ExpressionNode.expr is null");
+      auto result = expr->evaluate(context);
+      if (result.is_string()) {
+          out << result.get<std::string>();
+      } else if (result.is_boolean()) {
+          out << (result.get<bool>() ? "True" : "False");
+      } else if (!result.is_null()) {
+          out << result.dump();
+      }
+  }
+};
+
+class IfNode : public TemplateNode {
+    std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> cascade;
+public:
+    IfNode(const Location & location, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> && c)
+        : TemplateNode(location), cascade(std::move(c)) {}
+    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
+      for (const auto& branch : cascade) {
+          auto enter_branch = true;
+          if (branch.first) {
+            enter_branch = branch.first->evaluate(context).to_bool();
+          }
+          if (enter_branch) {
+            if (!branch.second) throw std::runtime_error("IfNode.cascade.second is null");
+              branch.second->render(out, context);
+              return;
+          }
+      }
+    }
+};
+
+class ForNode : public TemplateNode {
+    std::vector<std::string> var_names;
+    std::shared_ptr<Expression> iterable;
+    std::shared_ptr<Expression> condition;
+    std::shared_ptr<TemplateNode> body;
+    bool recursive;
+    std::shared_ptr<TemplateNode> else_body;
+public:
+    ForNode(const Location & location, std::vector<std::string> && var_names, std::shared_ptr<Expression> && iterable,
+      std::shared_ptr<Expression> && condition, std::shared_ptr<TemplateNode> && body, bool recursive, std::shared_ptr<TemplateNode> && else_body)
+            : TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {}
+
+    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
+      // https://jinja.palletsprojects.com/en/3.0.x/templates/#for
+      if (!iterable) throw std::runtime_error("ForNode.iterable is null");
+      if (!body) throw std::runtime_error("ForNode.body is null");
+
+      auto iterable_value = iterable->evaluate(context);
+      Value::CallableType loop_function;
+
+      std::function<void(Value&)> visit = [&](Value& iter) {
+          auto filtered_items = Value::array();
+          if (!iter.is_null()) {
+            if (!iterable_value.is_iterable()) {
+              throw std::runtime_error("For loop iterable must be iterable: " + iterable_value.dump());
+            }
+            iterable_value.for_each([&](Value & item) {
+                destructuring_assign(var_names, context, item);
+                if (!condition || condition->evaluate(context).to_bool()) {
+                  filtered_items.push_back(item);
+                }
+            });
+          }
+          if (filtered_items.empty()) {
+            if (else_body) {
+              else_body->render(out, context);
+            }
+          } else {
+              auto loop = recursive ? Value::callable(loop_function) : Value::object();
+              loop.set("length", (int64_t) filtered_items.size());
+
+              size_t cycle_index = 0;
+              loop.set("cycle", Value::callable([&](const std::shared_ptr<Context> &, ArgumentsValue & args) {
+                  if (args.args.empty() || !args.kwargs.empty()) {
+                      throw std::runtime_error("cycle() expects at least 1 positional argument and no named arg");
+                  }
+                  auto item = args.args[cycle_index];
+                  cycle_index = (cycle_index + 1) % args.args.size();
+                  return item;
+              }));
+              auto loop_context = Context::make(Value::object(), context);
+              loop_context->set("loop", loop);
+              for (size_t i = 0, n = filtered_items.size(); i < n; ++i) {
+                  auto & item = filtered_items.at(i);
+                  destructuring_assign(var_names, loop_context, item);
+                  loop.set("index", (int64_t) i + 1);
+                  loop.set("index0", (int64_t) i);
+                  loop.set("revindex", (int64_t) (n - i));
+                  loop.set("revindex0", (int64_t) (n - i - 1));
+                  loop.set("length", (int64_t) n);
+                  loop.set("first", i == 0);
+                  loop.set("last", i == (n - 1));
+                  loop.set("previtem", i > 0 ? filtered_items.at(i - 1) : Value());
+                  loop.set("nextitem", i < n - 1 ? filtered_items.at(i + 1) : Value());
+                  body->render(out, loop_context);
+              }
+          }
+      };
+
+      if (recursive) {
+        loop_function = [&](const std::shared_ptr<Context> &, ArgumentsValue & args) {
+            if (args.args.size() != 1 || !args.kwargs.empty() || !args.args[0].is_array()) {
+                throw std::runtime_error("loop() expects exactly 1 positional iterable argument");
+            }
+            auto & items = args.args[0];
+            visit(items);
+            return Value();
+        };
+      }
+
+      visit(iterable_value);
+  }
+};
+
+class MacroNode : public TemplateNode {
+    std::shared_ptr<VariableExpr> name;
+    Expression::Parameters params;
+    std::shared_ptr<TemplateNode> body;
+    std::unordered_map<std::string, size_t> named_param_positions;
+public:
+    MacroNode(const Location & location, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p, std::shared_ptr<TemplateNode> && b)
+        : TemplateNode(location), name(std::move(n)), params(std::move(p)), body(std::move(b)) {
+        for (size_t i = 0; i < params.size(); ++i) {
+          const auto & name = params[i].first;
+          if (!name.empty()) {
+            named_param_positions[name] = i;
+          }
+        }
+    }
+    void do_render(std::ostringstream &, const std::shared_ptr<Context> & macro_context) const override {
+        if (!name) throw std::runtime_error("MacroNode.name is null");
+        if (!body) throw std::runtime_error("MacroNode.body is null");
+        auto callable = Value::callable([&](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+            auto call_context = macro_context;
+            std::vector<bool> param_set(params.size(), false);
+            for (size_t i = 0, n = args.args.size(); i < n; i++) {
+                auto & arg = args.args[i];
+                if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name());
+                param_set[i] = true;
+                auto & param_name = params[i].first;
+                call_context->set(param_name, arg);
+            }
+            for (auto & [arg_name, value] : args.kwargs) {
+                auto it = named_param_positions.find(arg_name);
+                if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name);
+
+                call_context->set(arg_name, value);
+                param_set[it->second] = true;
+            }
+            // Set default values for parameters that were not passed
+            for (size_t i = 0, n = params.size(); i < n; i++) {
+                if (!param_set[i] && params[i].second != nullptr) {
+                    auto val = params[i].second->evaluate(context);
+                    call_context->set(params[i].first, val);
+                }
+            }
+            return body->render(call_context);
+        });
+        macro_context->set(name->get_name(), callable);
+    }
+};
+
+class FilterNode : public TemplateNode {
+    std::shared_ptr<Expression> filter;
+    std::shared_ptr<TemplateNode> body;
+
+public:
+    FilterNode(const Location & location, std::shared_ptr<Expression> && f, std::shared_ptr<TemplateNode> && b)
+        : TemplateNode(location), filter(std::move(f)), body(std::move(b)) {}
+
+    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
+        if (!filter) throw std::runtime_error("FilterNode.filter is null");
+        if (!body) throw std::runtime_error("FilterNode.body is null");
+        auto filter_value = filter->evaluate(context);
+        if (!filter_value.is_callable()) {
+            throw std::runtime_error("Filter must be a callable: " + filter_value.dump());
+        }
+        std::string rendered_body = body->render(context);
+
+        ArgumentsValue filter_args = {{Value(rendered_body)}, {}};
+        auto result = filter_value.call(context, filter_args);
+        out << result.to_str();
+    }
+};
+
+class SetNode : public TemplateNode {
+    std::string ns;
+    std::vector<std::string> var_names;
+    std::shared_ptr<Expression> value;
+public:
+    SetNode(const Location & location, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v)
+        : TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)) {}
+    void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
+      if (!value) throw std::runtime_error("SetNode.value is null");
+      if (!ns.empty()) {
+        if (var_names.size() != 1) {
+          throw std::runtime_error("Namespaced set only supports a single variable name");
+        }
+        auto & name = var_names[0];
+        auto ns_value = context->get(ns);
+        if (!ns_value.is_object()) throw std::runtime_error("Namespace '" + ns + "' is not an object");
+        ns_value.set(name, this->value->evaluate(context));
+      } else {
+        auto val = value->evaluate(context);
+        destructuring_assign(var_names, context, val);
+      }
+    }
+};
+
+class SetTemplateNode : public TemplateNode {
+    std::string name;
+    std::shared_ptr<TemplateNode> template_value;
+public:
+    SetTemplateNode(const Location & location, const std::string & name, std::shared_ptr<TemplateNode> && tv)
+        : TemplateNode(location), name(name), template_value(std::move(tv)) {}
+    void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
+      if (!template_value) throw std::runtime_error("SetTemplateNode.template_value is null");
+      Value value { template_value->render(context) };
+      context->set(name, value);
+    }
+};
+
+class IfExpr : public Expression {
+    std::shared_ptr<Expression> condition;
+    std::shared_ptr<Expression> then_expr;
+    std::shared_ptr<Expression> else_expr;
+public:
+    IfExpr(const Location & location, std::shared_ptr<Expression> && c, std::shared_ptr<Expression> && t, std::shared_ptr<Expression> && e)
+        : Expression(location), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {}
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+      if (!condition) throw std::runtime_error("IfExpr.condition is null");
+      if (!then_expr) throw std::runtime_error("IfExpr.then_expr is null");
+      if (condition->evaluate(context).to_bool()) {
+        return then_expr->evaluate(context);
+      }
+      if (else_expr) {
+        return else_expr->evaluate(context);
+      }
+      return nullptr;
+    }
+};
+
+class LiteralExpr : public Expression {
+    Value value;
+public:
+    LiteralExpr(const Location & location, const Value& v)
+      : Expression(location), value(v) {}
+    Value do_evaluate(const std::shared_ptr<Context> &) const override { return value; }
+};
+
+class ArrayExpr : public Expression {
+    std::vector<std::shared_ptr<Expression>> elements;
+public:
+    ArrayExpr(const Location & location, std::vector<std::shared_ptr<Expression>> && e)
+      : Expression(location), elements(std::move(e)) {}
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        auto result = Value::array();
+        for (const auto& e : elements) {
+            if (!e) throw std::runtime_error("Array element is null");
+            result.push_back(e->evaluate(context));
+        }
+        return result;
+    }
+};
+
+class DictExpr : public Expression {
+    std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> elements;
+public:
+    DictExpr(const Location & location, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> && e)
+      : Expression(location), elements(std::move(e)) {}
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        auto result = Value::object();
+        for (const auto& [key, value] : elements) {
+            if (!key) throw std::runtime_error("Dict key is null");
+            if (!value) throw std::runtime_error("Dict value is null");
+            result.set(key->evaluate(context), value->evaluate(context));
+        }
+        return result;
+    }
+};
+
+class SliceExpr : public Expression {
+public:
+    std::shared_ptr<Expression> start, end;
+    SliceExpr(const Location & location, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e)
+      : Expression(location), start(std::move(s)), end(std::move(e)) {}
+    Value do_evaluate(const std::shared_ptr<Context> &) const override {
+        throw std::runtime_error("SliceExpr not implemented");
+    }
+};
+
+class SubscriptExpr : public Expression {
+    std::shared_ptr<Expression> base;
+    std::shared_ptr<Expression> index;
+public:
+    SubscriptExpr(const Location & location, std::shared_ptr<Expression> && b, std::shared_ptr<Expression> && i)
+        : Expression(location), base(std::move(b)), index(std::move(i)) {}
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!base) throw std::runtime_error("SubscriptExpr.base is null");
+        if (!index) throw std::runtime_error("SubscriptExpr.index is null");
+        auto target_value = base->evaluate(context);
+        if (auto slice = dynamic_cast<SliceExpr*>(index.get())) {
+          auto start = slice->start ? slice->start->evaluate(context).get<int64_t>() : 0;
+          auto end = slice->end ? slice->end->evaluate(context).get<int64_t>() : (int64_t) target_value.size();
+          if (target_value.is_string()) {
+            std::string s = target_value.get<std::string>();
+            if (start < 0) start = s.size() + start;
+            if (end < 0) end = s.size() + end;
+            return s.substr(start, end - start);
+          } else if (target_value.is_array()) {
+            if (start < 0) start = target_value.size() + start;
+            if (end < 0) end = target_value.size() + end;
+            auto result = Value::array();
+            for (auto i = start; i < end; ++i) {
+              result.push_back(target_value.at(i));
+            }
+            return result;
+          } else {
+            throw std::runtime_error(target_value.is_null() ? "Cannot subscript null" : "Subscripting only supported on arrays and strings");
+          }
+        } else {
+          auto index_value = index->evaluate(context);
+          if (target_value.is_null()) {
+            if (auto t = dynamic_cast<VariableExpr*>(base.get())) {
+              throw std::runtime_error("'" + t->get_name() + "' is " + (context->contains(t->get_name()) ? "null" : "not defined"));
+            }
+            throw std::runtime_error("Trying to access property '" +  index_value.dump() + "' on null!");
+          }
+          return target_value.get(index_value);
+        }
+    }
+};
+
+class UnaryOpExpr : public Expression {
+public:
+    enum class Op { Plus, Minus, LogicalNot, Expansion, ExpansionDict };
+    std::shared_ptr<Expression> expr;
+    Op op;
+    UnaryOpExpr(const Location & location, std::shared_ptr<Expression> && e, Op o)
+      : Expression(location), expr(std::move(e)), op(o) {}
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!expr) throw std::runtime_error("UnaryOpExpr.expr is null");
+        auto e = expr->evaluate(context);
+        switch (op) {
+            case Op::Plus: return e;
+            case Op::Minus: return -e;
+            case Op::LogicalNot: return !e.to_bool();
+            case Op::Expansion:
+            case Op::ExpansionDict:
+                throw std::runtime_error("Expansion operator is only supported in function calls and collections");
+
+        }
+        throw std::runtime_error("Unknown unary operator");
+    }
+};
+
+class BinaryOpExpr : public Expression {
+public:
+    enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot };
+private:
+    std::shared_ptr<Expression> left;
+    std::shared_ptr<Expression> right;
+    Op op;
+public:
+    BinaryOpExpr(const Location & location, std::shared_ptr<Expression> && l, std::shared_ptr<Expression> && r, Op o)
+        : Expression(location), left(std::move(l)), right(std::move(r)), op(o) {}
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!left) throw std::runtime_error("BinaryOpExpr.left is null");
+        if (!right) throw std::runtime_error("BinaryOpExpr.right is null");
+        auto l = left->evaluate(context);
+
+        auto do_eval = [&](const Value & l) -> Value {
+          if (op == Op::Is || op == Op::IsNot) {
+            auto t = dynamic_cast<VariableExpr*>(right.get());
+            if (!t) throw std::runtime_error("Right side of 'is' operator must be a variable");
+
+            auto eval = [&]() {
+              const auto & name = t->get_name();
+              if (name == "none") return l.is_null();
+              if (name == "boolean") return l.is_boolean();
+              if (name == "integer") return l.is_number_integer();
+              if (name == "float") return l.is_number_float();
+              if (name == "number") return l.is_number();
+              if (name == "string") return l.is_string();
+              if (name == "mapping") return l.is_object();
+              if (name == "iterable") return l.is_iterable();
+              if (name == "sequence") return l.is_array();
+              if (name == "defined") return !l.is_null();
+              throw std::runtime_error("Unknown type for 'is' operator: " + name);
+            };
+            auto value = eval();
+            return Value(op == Op::Is ? value : !value);
+          }
+
+          if (op == Op::And) {
+            if (!l.to_bool()) return Value(false);
+            return right->evaluate(context).to_bool();
+          } else if (op == Op::Or) {
+            if (l.to_bool()) return l;
+            return right->evaluate(context);
+          }
+
+          auto r = right->evaluate(context);
+          switch (op) {
+              case Op::StrConcat: return l.to_str() + r.to_str();
+              case Op::Add:       return l + r;
+              case Op::Sub:       return l - r;
+              case Op::Mul:       return l * r;
+              case Op::Div:       return l / r;
+              case Op::MulMul:    return std::pow(l.get<double>(), r.get<double>());
+              case Op::DivDiv:    return l.get<int64_t>() / r.get<int64_t>();
+              case Op::Mod:       return l.get<int64_t>() % r.get<int64_t>();
+              case Op::Eq:        return l == r;
+              case Op::Ne:        return l != r;
+              case Op::Lt:        return l < r;
+              case Op::Gt:        return l > r;
+              case Op::Le:        return l <= r;
+              case Op::Ge:        return l >= r;
+              case Op::In:        return (r.is_array() || r.is_object()) && r.contains(l);
+              case Op::NotIn:     return !(r.is_array() && r.contains(l));
+              default:            break;
+          }
+          throw std::runtime_error("Unknown binary operator");
+        };
+
+        if (l.is_callable()) {
+          return Value::callable([l, do_eval](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+            auto ll = l.call(context, args);
+            return do_eval(ll); //args[0].second);
+          });
+        } else {
+          return do_eval(l);
+        }
+    }
+};
+
+struct ArgumentsExpression {
+    std::vector<std::shared_ptr<Expression>> args;
+    std::vector<std::pair<std::string, std::shared_ptr<Expression>>> kwargs;
+
+    ArgumentsValue evaluate(const std::shared_ptr<Context> & context) const {
+        ArgumentsValue vargs;
+        for (const auto& arg : this->args) {
+            if (auto un_expr = std::dynamic_pointer_cast<UnaryOpExpr>(arg)) {
+                if (un_expr->op == UnaryOpExpr::Op::Expansion) {
+                    auto array = un_expr->expr->evaluate(context);
+                    if (!array.is_array()) {
+                        throw std::runtime_error("Expansion operator only supported on arrays");
+                    }
+                    array.for_each([&](Value & value) {
+                        vargs.args.push_back(value);
+                    });
+                    continue;
+                } else if (un_expr->op == UnaryOpExpr::Op::ExpansionDict) {
+                    auto dict = un_expr->expr->evaluate(context);
+                    if (!dict.is_object()) {
+                        throw std::runtime_error("ExpansionDict operator only supported on objects");
+                    }
+                    dict.for_each([&](const Value & key) {
+                        vargs.kwargs.push_back({key.get<std::string>(), dict.at(key)});
+                    });
+                    continue;
+                }
+            }
+            vargs.args.push_back(arg->evaluate(context));
+        }
+        for (const auto& [name, value] : this->kwargs) {
+            vargs.kwargs.push_back({name, value->evaluate(context)});
+        }
+        return vargs;
+    }
+};
+
+static std::string strip(const std::string & s) {
+  auto start = s.find_first_not_of(" \t\n\r");
+  if (start == std::string::npos) return "";
+  auto end = s.find_last_not_of(" \t\n\r");
+  return s.substr(start, end - start + 1);
+}
+
+static std::string html_escape(const std::string & s) {
+  std::string result;
+  result.reserve(s.size());
+  for (const auto & c : s) {
+    switch (c) {
+      case '&': result += "&amp;"; break;
+      case '<': result += "&lt;"; break;
+      case '>': result += "&gt;"; break;
+      case '"': result += "&#34;"; break;
+      case '\'': result += "&apos;"; break;
+      default: result += c; break;
+    }
+  }
+  return result;
+}
+
+class MethodCallExpr : public Expression {
+    std::shared_ptr<Expression> object;
+    std::shared_ptr<VariableExpr> method;
+    ArgumentsExpression args;
+public:
+    MethodCallExpr(const Location & location, std::shared_ptr<Expression> && obj, std::shared_ptr<VariableExpr> && m, ArgumentsExpression && a)
+        : Expression(location), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {}
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!object) throw std::runtime_error("MethodCallExpr.object is null");
+        if (!method) throw std::runtime_error("MethodCallExpr.method is null");
+        auto obj = object->evaluate(context);
+        auto vargs = args.evaluate(context);
+        if (obj.is_null()) {
+          throw std::runtime_error("Trying to call method '" + method->get_name() + "' on null");
+        }
+        if (obj.is_array()) {
+          if (method->get_name() == "append") {
+              vargs.expectArgs("append method", {1, 1}, {0, 0});
+              obj.push_back(vargs.args[0]);
+              return Value();
+          } else if (method->get_name() == "pop") {
+              vargs.expectArgs("pop method", {0, 1}, {0, 0});
+              return obj.pop(vargs.args.empty() ? Value() : vargs.args[0]);
+          } else if (method->get_name() == "insert") {
+              vargs.expectArgs("insert method", {2, 2}, {0, 0});
+              auto index = vargs.args[0].get<int64_t>();
+              if (index < 0 || index > (int64_t) obj.size()) throw std::runtime_error("Index out of range for insert method");
+              obj.insert(index, vargs.args[1]);
+              return Value();
+          }
+        } else if (obj.is_object()) {
+          if (method->get_name() == "items") {
+            vargs.expectArgs("items method", {0, 0}, {0, 0});
+            auto result = Value::array();
+            for (const auto& key : obj.keys()) {
+              result.push_back(Value::array({key, obj.at(key)}));
+            }
+            return result;
+          } else if (method->get_name() == "pop") {
+            vargs.expectArgs("pop method", {1, 1}, {0, 0});
+            return obj.pop(vargs.args[0]);
+          } else if (method->get_name() == "get") {
+            vargs.expectArgs("get method", {1, 2}, {0, 0});
+            auto key = vargs.args[0];
+            if (vargs.args.size() == 1) {
+              return obj.contains(key) ? obj.at(key) : Value();
+            } else {
+              return obj.contains(key) ? obj.at(key) : vargs.args[1];
+            }
+          } else if (obj.contains(method->get_name())) {
+            auto callable = obj.at(method->get_name());
+            if (!callable.is_callable()) {
+              throw std::runtime_error("Property '" + method->get_name() + "' is not callable");
+            }
+            return callable.call(context, vargs);
+          }
+        } else if (obj.is_string()) {
+          auto str = obj.get<std::string>();
+          if (method->get_name() == "strip") {
+            vargs.expectArgs("strip method", {0, 0}, {0, 0});
+            return Value(strip(str));
+          } else if (method->get_name() == "endswith") {
+            vargs.expectArgs("endswith method", {1, 1}, {0, 0});
+            auto suffix = vargs.args[0].get<std::string>();
+            return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
+          } else if (method->get_name() == "title") {
+            vargs.expectArgs("title method", {0, 0}, {0, 0});
+            auto res = str;
+            for (size_t i = 0, n = res.size(); i < n; ++i) {
+              if (i == 0 || std::isspace(res[i - 1])) res[i] = std::toupper(res[i]);
+              else res[i] = std::tolower(res[i]);
+            }
+            return res;
+          }
+        }
+        throw std::runtime_error("Unknown method: " + method->get_name());
+    }
+};
+
+class CallExpr : public Expression {
+public:
+    std::shared_ptr<Expression> object;
+    ArgumentsExpression args;
+    CallExpr(const Location & location, std::shared_ptr<Expression> && obj, ArgumentsExpression && a)
+        : Expression(location), object(std::move(obj)), args(std::move(a)) {}
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!object) throw std::runtime_error("CallExpr.object is null");
+        auto obj = object->evaluate(context);
+        if (!obj.is_callable()) {
+          throw std::runtime_error("Object is not callable: " + obj.dump(2));
+        }
+        auto vargs = args.evaluate(context);
+        return obj.call(context, vargs);
+    }
+};
+
+class FilterExpr : public Expression {
+    std::vector<std::shared_ptr<Expression>> parts;
+public:
+    FilterExpr(const Location & location, std::vector<std::shared_ptr<Expression>> && p)
+      : Expression(location), parts(std::move(p)) {}
+    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        Value result;
+        bool first = true;
+        for (const auto& part : parts) {
+          if (!part) throw std::runtime_error("FilterExpr.part is null");
+          if (first) {
+            first = false;
+            result = part->evaluate(context);
+          } else {
+            if (auto ce = dynamic_cast<CallExpr*>(part.get())) {
+              auto target = ce->object->evaluate(context);
+              ArgumentsValue args = ce->args.evaluate(context);
+              args.args.insert(args.args.begin(), result);
+              result = target.call(context, args);
+            } else {
+              auto callable = part->evaluate(context);
+              ArgumentsValue args;
+              args.args.insert(args.args.begin(), result);
+              result = callable.call(context, args);
+            }
+          }
+        }
+        return result;
+    }
+
+    void prepend(std::shared_ptr<Expression> && e) {
+        parts.insert(parts.begin(), std::move(e));
+    }
+};
+
+class Parser {
+private:
+    using CharIterator = std::string::const_iterator;
+
+    std::shared_ptr<std::string> template_str;
+    CharIterator start, end, it;
+    Options options;
+
+    Parser(const std::shared_ptr<std::string>& template_str, const Options & options) : template_str(template_str), options(options) {
+      if (!template_str) throw std::runtime_error("Template string is null");
+      start = it = this->template_str->begin();
+      end = this->template_str->end();
+    }
+
+    bool consumeSpaces(SpaceHandling space_handling = SpaceHandling::Strip) {
+      if (space_handling == SpaceHandling::Strip) {
+        while (it != end && std::isspace(*it)) ++it;
+      }
+      return true;
+    }
+
+    std::unique_ptr<std::string> parseString() {
+      auto doParse = [&](char quote) -> std::unique_ptr<std::string> {
+        if (it == end || *it != quote) return nullptr;
+        std::string result;
+        bool escape = false;
+        for (++it; it != end; ++it) {
+          if (escape) {
+            escape = false;
+            switch (*it) {
+              case 'n': result += '\n'; break;
+              case 'r': result += '\r'; break;
+              case 't': result += '\t'; break;
+              case 'b': result += '\b'; break;
+              case 'f': result += '\f'; break;
+              case '\\': result += '\\'; break;
+              default:
+                if (*it == quote) {
+                  result += quote;
+                } else {
+                  result += *it;
+                }
+                break;
+            }
+          } else if (*it == '\\') {
+            escape = true;
+          } else if (*it == quote) {
+              ++it;
+            return std::make_unique<std::string>(std::move(result));
+          } else {
+            result += *it;
+          }
+        }
+        return nullptr;
+      };
+
+      consumeSpaces();
+      if (it == end) return nullptr;
+      if (*it == '"') return doParse('"');
+      if (*it == '\'') return doParse('\'');
+      return nullptr;
+    }
+
+    json parseNumber(CharIterator& it, const CharIterator& end) {
+        auto before = it;
+        consumeSpaces();
+        auto start = it;
+        bool hasDecimal = false;
+        bool hasExponent = false;
+
+        if (it != end && (*it == '-' || *it == '+')) ++it;
+
+        while (it != end) {
+          if (std::isdigit(*it)) {
+            ++it;
+          } else if (*it == '.') {
+            if (hasDecimal) throw std::runtime_error("Multiple decimal points");
+            hasDecimal = true;
+            ++it;
+          } else if (it != start && (*it == 'e' || *it == 'E')) {
+            if (hasExponent) throw std::runtime_error("Multiple exponents");
+            hasExponent = true;
+            ++it;
+          } else {
+            break;
+          }
+        }
+        if (start == it) {
+          it = before;
+          return json(); // No valid characters found
+        }
+
+        std::string str(start, it);
+        try {
+          return json::parse(str);
+        } catch (json::parse_error& e) {
+          throw std::runtime_error("Failed to parse number: '" + str + "' (" + std::string(e.what()) + ")");
+          return json();
+        }
+    }
+
+    /** integer, float, bool, string */
+    std::shared_ptr<Value> parseConstant() {
+      auto start = it;
+      consumeSpaces();
+      if (it == end) return nullptr;
+      if (*it == '"' || *it == '\'') {
+        auto str = parseString();
+        if (str) return std::make_shared<Value>(*str);
+      }
+      static std::regex prim_tok(R"(true\b|True\b|false\b|False\b|None\b)");
+      auto token = consumeToken(prim_tok);
+      if (!token.empty()) {
+        if (token == "true" || token == "True") return std::make_shared<Value>(true);
+        if (token == "false" || token == "False") return std::make_shared<Value>(false);
+        if (token == "None") return std::make_shared<Value>(nullptr);
+        throw std::runtime_error("Unknown constant token: " + token);
+      }
+
+      auto number = parseNumber(it, end);
+      if (!number.is_null()) return std::make_shared<Value>(number);
+
+      it = start;
+      return nullptr;
+    }
+
+    class expression_parsing_error : public std::runtime_error {
+        const CharIterator it;
+      public:
+        expression_parsing_error(const std::string & message, const CharIterator & it)
+            : std::runtime_error(message), it(it) {}
+        size_t get_pos(const CharIterator & begin) const {
+            return std::distance(begin, it);
+      }
+    };
+
+    bool peekSymbols(const std::vector<std::string> & symbols) const {
+        for (const auto & symbol : symbols) {
+            if (std::distance(it, end) >= (int64_t) symbol.size() && std::string(it, it + symbol.size()) == symbol) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    std::vector<std::string> consumeTokenGroups(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) {
+        auto start = it;
+        consumeSpaces(space_handling);
+        std::smatch match;
+        if (std::regex_search(it, end, match, regex) && match.position() == 0) {
+            it += match[0].length();
+            std::vector<std::string> ret;
+            for (size_t i = 0, n = match.size(); i < n; ++i) {
+                ret.push_back(match[i].str());
+            }
+            return ret;
+        }
+        it = start;
+        return {};
+    }
+    std::string consumeToken(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) {
+        auto start = it;
+        consumeSpaces(space_handling);
+        std::smatch match;
+        if (std::regex_search(it, end, match, regex) && match.position() == 0) {
+            it += match[0].length();
+            return match[0].str();
+        }
+        it = start;
+        return "";
+    }
+
+    std::string consumeToken(const std::string & token, SpaceHandling space_handling = SpaceHandling::Strip) {
+        auto start = it;
+        consumeSpaces(space_handling);
+        if (std::distance(it, end) >= (int64_t) token.size() && std::string(it, it + token.size()) == token) {
+            it += token.size();
+            return token;
+        }
+        it = start;
+        return "";
+    }
+
+    std::shared_ptr<Expression> parseExpression(bool allow_if_expr = true) {
+        auto left = parseLogicalOr();
+        if (it == end) return left;
+
+        if (!allow_if_expr) return left;
+
+        static std::regex if_tok(R"(if\b)");
+        if (consumeToken(if_tok).empty()) {
+          return left;
+        }
+
+        auto location = get_location();
+        auto [condition, else_expr] = parseIfExpression();
+        return std::make_shared<IfExpr>(location, std::move(condition), std::move(left), std::move(else_expr));
+    }
+
+    Location get_location() const {
+        return {template_str, (size_t) std::distance(start, it)};
+    }
+
+    std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>> parseIfExpression() {
+        auto condition = parseLogicalOr();
+        if (!condition) throw std::runtime_error("Expected condition expression");
+
+        static std::regex else_tok(R"(else\b)");
+        std::shared_ptr<Expression> else_expr;
+        if (!consumeToken(else_tok).empty()) {
+          else_expr = parseExpression();
+          if (!else_expr) throw std::runtime_error("Expected 'else' expression");
+        }
+        return std::pair(std::move(condition), std::move(else_expr));
+    }
+
+    std::shared_ptr<Expression> parseLogicalOr() {
+        auto left = parseLogicalAnd();
+        if (!left) throw std::runtime_error("Expected left side of 'logical or' expression");
+
+        static std::regex or_tok(R"(or\b)");
+        auto location = get_location();
+        while (!consumeToken(or_tok).empty()) {
+            auto right = parseLogicalAnd();
+            if (!right) throw std::runtime_error("Expected right side of 'or' expression");
+            left = std::make_shared<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::Or);
+        }
+        return left;
+    }
+
+    std::shared_ptr<Expression> parseLogicalNot() {
+        static std::regex not_tok(R"(not\b)");
+        auto location = get_location();
+
+        if (!consumeToken(not_tok).empty()) {
+          auto sub = parseLogicalNot();
+          if (!sub) throw std::runtime_error("Expected expression after 'not' keyword");
+          return std::make_shared<UnaryOpExpr>(location, std::move(sub), UnaryOpExpr::Op::LogicalNot);
+        }
+        return parseLogicalCompare();
+    }
+
+    std::shared_ptr<Expression> parseLogicalAnd() {
+        auto left = parseLogicalNot();
+        if (!left) throw std::runtime_error("Expected left side of 'logical and' expression");
+
+        static std::regex and_tok(R"(and\b)");
+        auto location = get_location();
+        while (!consumeToken(and_tok).empty()) {
+            auto right = parseLogicalNot();
+            if (!right) throw std::runtime_error("Expected right side of 'and' expression");
+            left = std::make_shared<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::And);
+        }
+        return left;
+    }
+
+    std::shared_ptr<Expression> parseLogicalCompare() {
+        auto left = parseStringConcat();
+        if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression");
+
+        static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not[\r\n\s]+in\b)");
+        static std::regex not_tok(R"(not\b)");
+        std::string op_str;
+        while (!(op_str = consumeToken(compare_tok)).empty()) {
+            auto location = get_location();
+            if (op_str == "is") {
+              auto negated = !consumeToken(not_tok).empty();
+
+              auto identifier = parseIdentifier();
+              if (!identifier) throw std::runtime_error("Expected identifier after 'is' keyword");
+
+              return std::make_shared<BinaryOpExpr>(
+                  left->location,
+                  std::move(left), std::move(identifier),
+                  negated ? BinaryOpExpr::Op::IsNot : BinaryOpExpr::Op::Is);
+            }
+            auto right = parseStringConcat();
+            if (!right) throw std::runtime_error("Expected right side of 'logical compare' expression");
+            BinaryOpExpr::Op op;
+            if (op_str == "==") op = BinaryOpExpr::Op::Eq;
+            else if (op_str == "!=") op = BinaryOpExpr::Op::Ne;
+            else if (op_str == "<") op = BinaryOpExpr::Op::Lt;
+            else if (op_str == ">") op = BinaryOpExpr::Op::Gt;
+            else if (op_str == "<=") op = BinaryOpExpr::Op::Le;
+            else if (op_str == ">=") op = BinaryOpExpr::Op::Ge;
+            else if (op_str == "in") op = BinaryOpExpr::Op::In;
+            else if (op_str.substr(0, 3) == "not") op = BinaryOpExpr::Op::NotIn;
+            else throw std::runtime_error("Unknown comparison operator: " + op_str);
+            left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
+        }
+        return left;
+    }
+
+    Expression::Parameters parseParameters() {
+        consumeSpaces();
+        if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in param list");
+
+        Expression::Parameters result;
+
+        while (it != end) {
+            if (!consumeToken(")").empty()) {
+                return result;
+            }
+            auto expr = parseExpression();
+            if (!expr) throw std::runtime_error("Expected expression in call args");
+
+            if (auto ident = dynamic_cast<VariableExpr*>(expr.get())) {
+                if (!consumeToken("=").empty()) {
+                    auto value = parseExpression();
+                    if (!value) throw std::runtime_error("Expected expression in for named arg");
+                    result.emplace_back(ident->get_name(), std::move(value));
+                } else {
+                    result.emplace_back(ident->get_name(), nullptr);
+                }
+            } else {
+                result.emplace_back(std::string(), std::move(expr));
+            }
+            if (consumeToken(",").empty()) {
+              if (consumeToken(")").empty()) {
+                throw std::runtime_error("Expected closing parenthesis in call args");
+              }
+              return result;
+            }
+        }
+        throw std::runtime_error("Expected closing parenthesis in call args");
+    }
+
+    ArgumentsExpression parseCallArgs() {
+        consumeSpaces();
+        if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in call args");
+
+        ArgumentsExpression result;
+
+        while (it != end) {
+            if (!consumeToken(")").empty()) {
+                return result;
+            }
+            auto expr = parseExpression();
+            if (!expr) throw std::runtime_error("Expected expression in call args");
+
+            if (auto ident = dynamic_cast<VariableExpr*>(expr.get())) {
+                if (!consumeToken("=").empty()) {
+                    auto value = parseExpression();
+                    if (!value) throw std::runtime_error("Expected expression in for named arg");
+                    result.kwargs.emplace_back(ident->get_name(), std::move(value));
+                } else {
+                    result.args.emplace_back(std::move(expr));
+                }
+            } else {
+                result.args.emplace_back(std::move(expr));
+            }
+            if (consumeToken(",").empty()) {
+              if (consumeToken(")").empty()) {
+                throw std::runtime_error("Expected closing parenthesis in call args");
+              }
+              return result;
+            }
+        }
+        throw std::runtime_error("Expected closing parenthesis in call args");
+    }
+
+    std::shared_ptr<VariableExpr> parseIdentifier() {
+        static std::regex ident_regex(R"((?!(?:not|is|and|or|del)\b)[a-zA-Z_]\w*)");
+        auto location = get_location();
+        auto ident = consumeToken(ident_regex);
+        if (ident.empty())
+          return nullptr;
+        return std::make_shared<VariableExpr>(location, ident);
+    }
+
+    std::shared_ptr<Expression> parseStringConcat() {
+        auto left = parseMathPow();
+        if (!left) throw std::runtime_error("Expected left side of 'string concat' expression");
+
+        static std::regex concat_tok(R"(~(?!\}))");
+        if (!consumeToken(concat_tok).empty()) {
+            auto right = parseLogicalAnd();
+            if (!right) throw std::runtime_error("Expected right side of 'string concat' expression");
+            left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::StrConcat);
+        }
+        return left;
+    }
+
+    std::shared_ptr<Expression> parseMathPow() {
+        auto left = parseMathPlusMinus();
+        if (!left) throw std::runtime_error("Expected left side of 'math pow' expression");
+
+        while (!consumeToken("**").empty()) {
+            auto right = parseMathPlusMinus();
+            if (!right) throw std::runtime_error("Expected right side of 'math pow' expression");
+            left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::MulMul);
+        }
+        return left;
+    }
+
+    std::shared_ptr<Expression> parseMathPlusMinus() {
+        static std::regex plus_minus_tok(R"(\+|-(?![}%#]\}))");
+
+        auto left = parseMathMulDiv();
+        if (!left) throw std::runtime_error("Expected left side of 'math plus/minus' expression");
+        std::string op_str;
+        while (!(op_str = consumeToken(plus_minus_tok)).empty()) {
+            auto right = parseMathMulDiv();
+            if (!right) throw std::runtime_error("Expected right side of 'math plus/minus' expression");
+            auto op = op_str == "+" ? BinaryOpExpr::Op::Add : BinaryOpExpr::Op::Sub;
+            left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
+        }
+        return left;
+    }
+
+    std::shared_ptr<Expression> parseMathMulDiv() {
+        auto left = parseMathUnaryPlusMinus();
+        if (!left) throw std::runtime_error("Expected left side of 'math mul/div' expression");
+
+        static std::regex mul_div_tok(R"(\*\*?|//?|%(?!\}))");
+        std::string op_str;
+        while (!(op_str = consumeToken(mul_div_tok)).empty()) {
+            auto right = parseMathUnaryPlusMinus();
+            if (!right) throw std::runtime_error("Expected right side of 'math mul/div' expression");
+            auto op = op_str == "*" ? BinaryOpExpr::Op::Mul
+                : op_str == "**" ? BinaryOpExpr::Op::MulMul
+                : op_str == "/" ? BinaryOpExpr::Op::Div
+                : op_str == "//" ? BinaryOpExpr::Op::DivDiv
+                : BinaryOpExpr::Op::Mod;
+            left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
+        }
+
+        if (!consumeToken("|").empty()) {
+            auto expr = parseMathMulDiv();
+            if (auto filter = dynamic_cast<FilterExpr*>(expr.get())) {
+                filter->prepend(std::move(left));
+                return expr;
+            } else {
+                std::vector<std::shared_ptr<Expression>> parts;
+                parts.emplace_back(std::move(left));
+                parts.emplace_back(std::move(expr));
+                return std::make_shared<FilterExpr>(get_location(), std::move(parts));
+            }
+        }
+        return left;
+    }
+
+    std::shared_ptr<Expression> call_func(const std::string & name, ArgumentsExpression && args) const {
+        return std::make_shared<CallExpr>(get_location(), std::make_shared<VariableExpr>(get_location(), name), std::move(args));
+    }
+
+    std::shared_ptr<Expression> parseMathUnaryPlusMinus() {
+        static std::regex unary_plus_minus_tok(R"(\+|-(?![}%#]\}))");
+        auto op_str = consumeToken(unary_plus_minus_tok);
+        auto expr = parseExpansion();
+        if (!expr) throw std::runtime_error("Expected expr of 'unary plus/minus/expansion' expression");
+
+        if (!op_str.empty()) {
+            auto op = op_str == "+" ? UnaryOpExpr::Op::Plus : UnaryOpExpr::Op::Minus;
+            return std::make_shared<UnaryOpExpr>(get_location(), std::move(expr), op);
+        }
+        return expr;
+    }
+
+    std::shared_ptr<Expression> parseExpansion() {
+      static std::regex expansion_tok(R"(\*\*?)");
+      auto op_str = consumeToken(expansion_tok);
+      auto expr = parseValueExpression();
+      if (op_str.empty()) return expr;
+      if (!expr) throw std::runtime_error("Expected expr of 'expansion' expression");
+      return std::make_shared<UnaryOpExpr>(get_location(), std::move(expr), op_str == "*" ? UnaryOpExpr::Op::Expansion : UnaryOpExpr::Op::ExpansionDict);
+    }
+
+    std::shared_ptr<Expression> parseValueExpression() {
+      auto parseValue = [&]() -> std::shared_ptr<Expression> {
+        auto location = get_location();
+        auto constant = parseConstant();
+        if (constant) return std::make_shared<LiteralExpr>(location, *constant);
+
+        static std::regex null_regex(R"(null\b)");
+        if (!consumeToken(null_regex).empty()) return std::make_shared<LiteralExpr>(location, Value());
+
+        auto identifier = parseIdentifier();
+        if (identifier) return identifier;
+
+        auto braced = parseBracedExpressionOrArray();
+        if (braced) return braced;
+
+        auto array = parseArray();
+        if (array) return array;
+
+        auto dictionary = parseDictionary();
+        if (dictionary) return dictionary;
+
+        throw std::runtime_error("Expected value expression");
+      };
+
+      auto value = parseValue();
+
+      while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
+        if (!consumeToken("[").empty()) {
+            std::shared_ptr<Expression> index;
+            if (!consumeToken(":").empty()) {
+              auto slice_end = parseExpression();
+              index = std::make_shared<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
+            } else {
+              auto slice_start = parseExpression();
+              if (!consumeToken(":").empty()) {
+                consumeSpaces();
+                if (peekSymbols({ "]" })) {
+                  index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
+                } else {
+                  auto slice_end = parseExpression();
+                  index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
+                }
+              } else {
+                index = std::move(slice_start);
+              }
+            }
+            if (!index) throw std::runtime_error("Empty index in subscript");
+            if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
+
+            value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
+        } else if (!consumeToken(".").empty()) {
+            auto identifier = parseIdentifier();
+            if (!identifier) throw std::runtime_error("Expected identifier in subscript");
+
+            consumeSpaces();
+            if (peekSymbols({ "(" })) {
+              auto callParams = parseCallArgs();
+              value = std::make_shared<MethodCallExpr>(identifier->location, std::move(value), std::move(identifier), std::move(callParams));
+            } else {
+              auto key = std::make_shared<LiteralExpr>(identifier->location, Value(identifier->get_name()));
+              value = std::make_shared<SubscriptExpr>(identifier->location, std::move(value), std::move(key));
+            }
+        }
+        consumeSpaces();
+      }
+
+      if (peekSymbols({ "(" })) {
+        auto location = get_location();
+        auto callParams = parseCallArgs();
+        value = std::make_shared<CallExpr>(location, std::move(value), std::move(callParams));
+      }
+      return value;
+    }
+
+    std::shared_ptr<Expression> parseBracedExpressionOrArray() {
+        if (consumeToken("(").empty()) return nullptr;
+
+        auto expr = parseExpression();
+        if (!expr) throw std::runtime_error("Expected expression in braced expression");
+
+        if (!consumeToken(")").empty()) {
+            return expr;  // Drop the parentheses
+        }
+
+        std::vector<std::shared_ptr<Expression>> tuple;
+        tuple.emplace_back(std::move(expr));
+
+        while (it != end) {
+          if (consumeToken(",").empty()) throw std::runtime_error("Expected comma in tuple");
+          auto next = parseExpression();
+          if (!next) throw std::runtime_error("Expected expression in tuple");
+          tuple.push_back(std::move(next));
+
+          if (!consumeToken(")").empty()) {
+              return std::make_shared<ArrayExpr>(get_location(), std::move(tuple));
+          }
+        }
+        throw std::runtime_error("Expected closing parenthesis");
+    }
+
+    std::shared_ptr<Expression> parseArray() {
+        if (consumeToken("[").empty()) return nullptr;
+
+        std::vector<std::shared_ptr<Expression>> elements;
+        if (!consumeToken("]").empty()) {
+            return std::make_shared<ArrayExpr>(get_location(), std::move(elements));
+        }
+        auto first_expr = parseExpression();
+        if (!first_expr) throw std::runtime_error("Expected first expression in array");
+        elements.push_back(std::move(first_expr));
+
+        while (it != end) {
+            if (!consumeToken(",").empty()) {
+              auto expr = parseExpression();
+              if (!expr) throw std::runtime_error("Expected expression in array");
+              elements.push_back(std::move(expr));
+            } else if (!consumeToken("]").empty()) {
+                return std::make_shared<ArrayExpr>(get_location(), std::move(elements));
+            } else {
+                throw std::runtime_error("Expected comma or closing bracket in array");
+            }
+        }
+        throw std::runtime_error("Expected closing bracket");
+    }
+
+    std::shared_ptr<Expression> parseDictionary() {
+        if (consumeToken("{").empty()) return nullptr;
+
+        std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> elements;
+        if (!consumeToken("}").empty()) {
+            return std::make_shared<DictExpr>(get_location(), std::move(elements));
+        }
+
+        auto parseKeyValuePair = [&]() {
+            auto key = parseExpression();
+            if (!key) throw std::runtime_error("Expected key in dictionary");
+            if (consumeToken(":").empty()) throw std::runtime_error("Expected colon betweek key & value in dictionary");
+            auto value = parseExpression();
+            if (!value) throw std::runtime_error("Expected value in dictionary");
+            elements.emplace_back(std::pair(std::move(key), std::move(value)));
+        };
+
+        parseKeyValuePair();
+
+        while (it != end) {
+            if (!consumeToken(",").empty()) {
+                parseKeyValuePair();
+            } else if (!consumeToken("}").empty()) {
+                return std::make_shared<DictExpr>(get_location(), std::move(elements));
+            } else {
+                throw std::runtime_error("Expected comma or closing brace in dictionary");
+            }
+        }
+        throw std::runtime_error("Expected closing brace");
+    }
+
+    SpaceHandling parsePreSpace(const std::string& s) const {
+        if (s == "-")
+          return SpaceHandling::Strip;
+        return SpaceHandling::Keep;
+    }
+
+    SpaceHandling parsePostSpace(const std::string& s) const {
+        if (s == "-") return SpaceHandling::Strip;
+        return SpaceHandling::Keep;
+    }
+
+    using TemplateTokenVector = std::vector<std::unique_ptr<TemplateToken>>;
+    using TemplateTokenIterator = TemplateTokenVector::const_iterator;
+
+    std::vector<std::string> parseVarNames() {
+      static std::regex varnames_regex(R"(((?:\w+)(?:[\r\n\s]*,[\r\n\s]*(?:\w+))*)[\r\n\s]*)");
+
+      std::vector<std::string> group;
+      if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names");
+      std::vector<std::string> varnames;
+      std::istringstream iss(group[1]);
+      std::string varname;
+      while (std::getline(iss, varname, ',')) {
+        varnames.push_back(strip(varname));
+      }
+      return varnames;
+    }
+
+    std::runtime_error unexpected(const TemplateToken & token) const {
+      return std::runtime_error("Unexpected " + TemplateToken::typeToString(token.type)
+        + error_location_suffix(*template_str, token.location.pos));
+    }
+    std::runtime_error unterminated(const TemplateToken & token) const {
+      return std::runtime_error("Unterminated " + TemplateToken::typeToString(token.type)
+        + error_location_suffix(*template_str, token.location.pos));
+    }
+
+    TemplateTokenVector tokenize() {
+      static std::regex comment_tok(R"(\{#([-~]?)(.*?)([-~]?)#\})");
+      static std::regex expr_open_regex(R"(\{\{([-~])?)");
+      static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)");
+      static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|set|endset|block|endblock|macro|endmacro|filter|endfilter)\b)");
+      static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
+      static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})");
+      static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})");
+
+      TemplateTokenVector tokens;
+      std::vector<std::string> group;
+      std::string text;
+      std::smatch match;
+
+      try {
+        while (it != end) {
+          auto location = get_location();
+
+          if (!(group = consumeTokenGroups(comment_tok, SpaceHandling::Keep)).empty()) {
+            auto pre_space = parsePreSpace(group[1]);
+            auto content = group[2];
+            auto post_space = parsePostSpace(group[3]);
+            tokens.push_back(std::make_unique<CommentTemplateToken>(location, pre_space, post_space, content));
+          } else if (!(group = consumeTokenGroups(expr_open_regex, SpaceHandling::Keep)).empty()) {
+            auto pre_space = parsePreSpace(group[1]);
+            auto expr = parseExpression();
+
+            if ((group = consumeTokenGroups(expr_close_regex)).empty()) {
+              throw std::runtime_error("Expected closing expression tag");
+            }
+
+            auto post_space = parsePostSpace(group[1]);
+            tokens.push_back(std::make_unique<ExpressionTemplateToken>(location, pre_space, post_space, std::move(expr)));
+          } else if (!(group = consumeTokenGroups(block_open_regex, SpaceHandling::Keep)).empty()) {
+            auto pre_space = parsePreSpace(group[1]);
+
+            std::string keyword;
+
+            auto parseBlockClose = [&]() -> SpaceHandling {
+              if ((group = consumeTokenGroups(block_close_regex)).empty()) throw std::runtime_error("Expected closing block tag");
+              return parsePostSpace(group[1]);
+            };
+
+            if ((keyword = consumeToken(block_keyword_tok)).empty()) throw std::runtime_error("Expected block keyword");
+
+            if (keyword == "if") {
+              auto condition = parseExpression();
+              if (!condition) throw std::runtime_error("Expected condition in if block");
+
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<IfTemplateToken>(location, pre_space, post_space, std::move(condition)));
+            } else if (keyword == "elif") {
+              auto condition = parseExpression();
+              if (!condition) throw std::runtime_error("Expected condition in elif block");
+
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<ElifTemplateToken>(location, pre_space, post_space, std::move(condition)));
+            } else if (keyword == "else") {
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<ElseTemplateToken>(location, pre_space, post_space));
+            } else if (keyword == "endif") {
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<EndIfTemplateToken>(location, pre_space, post_space));
+            } else if (keyword == "for") {
+              static std::regex recursive_tok(R"(recursive\b)");
+              static std::regex if_tok(R"(if\b)");
+
+              auto varnames = parseVarNames();
+              static std::regex in_tok(R"(in\b)");
+              if (consumeToken(in_tok).empty()) throw std::runtime_error("Expected 'in' keyword in for block");
+              auto iterable = parseExpression(/* allow_if_expr = */ false);
+              if (!iterable) throw std::runtime_error("Expected iterable in for block");
+
+              std::shared_ptr<Expression> condition;
+              if (!consumeToken(if_tok).empty()) {
+                condition = parseExpression();
+              }
+              auto recursive = !consumeToken(recursive_tok).empty();
+
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<ForTemplateToken>(location, pre_space, post_space, std::move(varnames), std::move(iterable), std::move(condition), recursive));
+            } else if (keyword == "endfor") {
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<EndForTemplateToken>(location, pre_space, post_space));
+            } else if (keyword == "set") {
+              static std::regex namespaced_var_regex(R"((\w+)[\s\n\r]*\.[\s\n\r]*(\w+))");
+
+              std::string ns;
+              std::vector<std::string> var_names;
+              std::shared_ptr<Expression> value;
+              if (!(group = consumeTokenGroups(namespaced_var_regex)).empty()) {
+                ns = group[1];
+                var_names.push_back(group[2]);
+
+                if (consumeToken("=").empty()) throw std::runtime_error("Expected equals sign in set block");
+
+                value = parseExpression();
+                if (!value) throw std::runtime_error("Expected value in set block");
+              } else {
+                var_names = parseVarNames();
+
+                if (!consumeToken("=").empty()) {
+                  value = parseExpression();
+                  if (!value) throw std::runtime_error("Expected value in set block");
+                }
+              }
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<SetTemplateToken>(location, pre_space, post_space, ns, var_names, std::move(value)));
+            } else if (keyword == "endset") {
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<EndSetTemplateToken>(location, pre_space, post_space));
+            } else if (keyword == "macro") {
+              auto macroname = parseIdentifier();
+              if (!macroname) throw std::runtime_error("Expected macro name in macro block");
+              auto params = parseParameters();
+
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<MacroTemplateToken>(location, pre_space, post_space, std::move(macroname), std::move(params)));
+            } else if (keyword == "endmacro") {
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<EndMacroTemplateToken>(location, pre_space, post_space));
+            } else if (keyword == "filter") {
+              auto filter = parseExpression();
+              if (!filter) throw std::runtime_error("Expected expression in filter block");
+
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<FilterTemplateToken>(location, pre_space, post_space, std::move(filter)));
+            } else if (keyword == "endfilter") {
+              auto post_space = parseBlockClose();
+              tokens.push_back(std::make_unique<EndFilterTemplateToken>(location, pre_space, post_space));
+            } else {
+              throw std::runtime_error("Unexpected block: " + keyword);
+            }
+          } else if (std::regex_search(it, end, match, non_text_open_regex)) {
+            auto text_end = it + match.position();
+            text = std::string(it, text_end);
+            it = text_end;
+            tokens.push_back(std::make_unique<TextTemplateToken>(location, SpaceHandling::Keep, SpaceHandling::Keep, text));
+          } else {
+            text = std::string(it, end);
+            it = end;
+            tokens.push_back(std::make_unique<TextTemplateToken>(location, SpaceHandling::Keep, SpaceHandling::Keep, text));
+          }
+        }
+        return tokens;
+      } catch (const std::exception & e) {
+        throw std::runtime_error(e.what() + error_location_suffix(*template_str, std::distance(start, it)));
+      }
+    }
+
+    std::shared_ptr<TemplateNode> parseTemplate(
+          const TemplateTokenIterator & begin,
+          TemplateTokenIterator & it,
+          const TemplateTokenIterator & end,
+          bool fully = false) const {
+        std::vector<std::shared_ptr<TemplateNode>> children;
+        while (it != end) {
+          const auto start = it;
+          const auto & token = *(it++);
+          if (auto if_token = dynamic_cast<IfTemplateToken*>(token.get())) {
+              std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> cascade;
+              cascade.emplace_back(std::move(if_token->condition), parseTemplate(begin, it, end));
+
+              while (it != end && (*it)->type == TemplateToken::Type::Elif) {
+                  auto elif_token = dynamic_cast<ElifTemplateToken*>((*(it++)).get());
+                  cascade.emplace_back(std::move(elif_token->condition), parseTemplate(begin, it, end));
+              }
+
+              if (it != end && (*it)->type == TemplateToken::Type::Else) {
+                cascade.emplace_back(nullptr, parseTemplate(begin, ++it, end));
+              }
+              if (it == end || (*(it++))->type != TemplateToken::Type::EndIf) {
+                  throw unterminated(**start);
+              }
+              children.emplace_back(std::make_shared<IfNode>(token->location, std::move(cascade)));
+          } else if (auto for_token = dynamic_cast<ForTemplateToken*>(token.get())) {
+              auto body = parseTemplate(begin, it, end);
+              auto else_body = std::shared_ptr<TemplateNode>();
+              if (it != end && (*it)->type == TemplateToken::Type::Else) {
+                else_body = parseTemplate(begin, ++it, end);
+              }
+              if (it == end || (*(it++))->type != TemplateToken::Type::EndFor) {
+                  throw unterminated(**start);
+              }
+              children.emplace_back(std::make_shared<ForNode>(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body)));
+          } else if (auto text_token = dynamic_cast<TextTemplateToken*>(token.get())) {
+              SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep;
+              SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep;
+
+              auto text = text_token->text;
+              if (post_space == SpaceHandling::Strip) {
+                static std::regex trailing_space_regex(R"((\s|\r|\n)+$)");
+                text = std::regex_replace(text, trailing_space_regex, "");
+              } else if (options.lstrip_blocks && it != end) {
+                auto i = text.size();
+                while (i > 0 && (text[i - 1] == ' ' || text[i - 1] == '\t')) i--;
+                if ((i == 0 && (it - 1) == begin) || (i > 0 && text[i - 1] == '\n')) {
+                  text.resize(i);
+                }
+              }
+              if (pre_space == SpaceHandling::Strip) {
+                static std::regex leading_space_regex(R"(^(\s|\r|\n)+)");
+                text = std::regex_replace(text, leading_space_regex, "");
+              } else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast<ExpressionTemplateToken*>((*(it - 2)).get())) {
+                if (text.length() > 0 && text[0] == '\n') {
+                  text.erase(0, 1);
+                }
+              }
+              if (it == end && !options.keep_trailing_newline) {
+                auto i = text.size();
+                if (i > 0 && text[i - 1] == '\n') {
+                  i--;
+                  if (i > 0 && text[i - 1] == '\r') i--;
+                  text.resize(i);
+                }
+              }
+              children.emplace_back(std::make_shared<TextNode>(token->location, text));
+          } else if (auto expr_token = dynamic_cast<ExpressionTemplateToken*>(token.get())) {
+              children.emplace_back(std::make_shared<ExpressionNode>(token->location, std::move(expr_token->expr)));
+          } else if (auto set_token = dynamic_cast<SetTemplateToken*>(token.get())) {
+            if (set_token->value) {
+              children.emplace_back(std::make_shared<SetNode>(token->location, set_token->ns, set_token->var_names, std::move(set_token->value)));
+            } else {
+              auto value_template = parseTemplate(begin, it, end);
+              if (it == end || (*(it++))->type != TemplateToken::Type::EndSet) {
+                  throw unterminated(**start);
+              }
+              if (!set_token->ns.empty()) throw std::runtime_error("Namespaced set not supported in set with template value");
+              if (set_token->var_names.size() != 1) throw std::runtime_error("Structural assignment not supported in set with template value");
+              auto & name = set_token->var_names[0];
+              children.emplace_back(std::make_shared<SetTemplateNode>(token->location, name, std::move(value_template)));
+            }
+          } else if (auto macro_token = dynamic_cast<MacroTemplateToken*>(token.get())) {
+              auto body = parseTemplate(begin, it, end);
+              if (it == end || (*(it++))->type != TemplateToken::Type::EndMacro) {
+                  throw unterminated(**start);
+              }
+              children.emplace_back(std::make_shared<MacroNode>(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body)));
+          } else if (auto filter_token = dynamic_cast<FilterTemplateToken*>(token.get())) {
+              auto body = parseTemplate(begin, it, end);
+              if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) {
+                  throw unterminated(**start);
+              }
+              children.emplace_back(std::make_shared<FilterNode>(token->location, std::move(filter_token->filter), std::move(body)));
+          } else if (dynamic_cast<CommentTemplateToken*>(token.get())) {
+              // Ignore comments
+          } else if (dynamic_cast<EndForTemplateToken*>(token.get())
+                  || dynamic_cast<EndSetTemplateToken*>(token.get())
+                  || dynamic_cast<EndMacroTemplateToken*>(token.get())
+                  || dynamic_cast<EndFilterTemplateToken*>(token.get())
+                  || dynamic_cast<EndIfTemplateToken*>(token.get())
+                  || dynamic_cast<ElseTemplateToken*>(token.get())
+                  || dynamic_cast<ElifTemplateToken*>(token.get())) {
+              it--;  // unconsume the token
+              break;  // exit the loop
+          } else {
+              throw unexpected(**(it-1));
+          }
+        }
+        if (fully && it != end) {
+            throw unexpected(**it);
+        }
+        if (children.empty()) {
+          return std::make_shared<TextNode>(Location { template_str, 0 }, std::string());
+        } else if (children.size() == 1) {
+          return std::move(children[0]);
+        } else {
+          return std::make_shared<SequenceNode>(children[0]->location(), std::move(children));
+        }
+    }
+
+public:
+
+    static std::shared_ptr<TemplateNode> parse(const std::string& template_str, const Options & options) {
+        Parser parser(std::make_shared<std::string>(normalize_newlines(template_str)), options);
+        auto tokens = parser.tokenize();
+        TemplateTokenIterator begin = tokens.begin();
+        auto it = begin;
+        TemplateTokenIterator end = tokens.end();
+        return parser.parseTemplate(begin, it, end, /* full= */ true);
+    }
+};
+
+static Value simple_function(const std::string & fn_name, const std::vector<std::string> & params, const std::function<Value(const std::shared_ptr<Context> &, Value & args)> & fn) {
+  std::map<std::string, size_t> named_positions;
+  for (size_t i = 0, n = params.size(); i < n; i++) named_positions[params[i]] = i;
+
+  return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) -> Value {
+    auto args_obj = Value::object();
+    std::vector<bool> provided_args(params.size());
+    for (size_t i = 0, n = args.args.size(); i < n; i++) {
+      auto & arg = args.args[i];
+      if (i < params.size()) {
+        args_obj.set(params[i], arg);
+        provided_args[i] = true;
+      } else {
+        throw std::runtime_error("Too many positional params for " + fn_name);
+      }
+    }
+    for (auto & [name, value] : args.kwargs) {
+      auto named_pos_it = named_positions.find(name);
+      if (named_pos_it == named_positions.end()) {
+        throw std::runtime_error("Unknown argument " + name + " for function " + fn_name);
+      }
+      provided_args[named_pos_it->second] = true;
+      args_obj.set(name, value);
+    }
+    return fn(context, args_obj);
+  });
+}
+
+inline std::shared_ptr<Context> Context::builtins() {
+  auto globals = Value::object();
+
+  globals.set("raise_exception", simple_function("raise_exception", { "message" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
+    throw std::runtime_error(args.at("message").get<std::string>());
+  }));
+  globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr<Context> &, Value & args) {
+    return Value(args.at("value").dump(args.get<int64_t>("indent", -1), /* tojson= */ true));
+  }));
+  globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr<Context> &, Value & args) {
+    auto items = Value::array();
+    if (args.contains("object")) {
+      auto & obj = args.at("object");
+      if (obj.is_string()) {
+        auto json_obj = json::parse(obj.get<std::string>());
+        for (const auto & kv : json_obj.items()) {
+          items.push_back(Value::array({kv.key(), kv.value()}));
+        }
+      } else if (!obj.is_null()) {
+        for (auto & key : obj.keys()) {
+          items.push_back(Value::array({key, obj.at(key)}));
+        }
+      }
+    }
+    return items;
+  }));
+  globals.set("last", simple_function("last", { "items" }, [](const std::shared_ptr<Context> &, Value & args) {
+    auto items = args.at("items");
+    if (!items.is_array()) throw std::runtime_error("object is not a list");
+    if (items.size() == 0) return Value();
+    return items.at(items.size() - 1);
+  }));
+  globals.set("trim", simple_function("trim", { "text" }, [](const std::shared_ptr<Context> &, Value & args) {
+    auto & text = args.at("text");
+    return text.is_null() ? text : Value(strip(text.get<std::string>()));
+  }));
+  globals.set("lower", simple_function("lower", { "text" }, [](const std::shared_ptr<Context> &, Value & args) {
+    auto text = args.at("text");
+    if (text.is_null()) return text;
+    std::string res;
+    auto str = text.get<std::string>();
+    std::transform(str.begin(), str.end(), std::back_inserter(res), ::tolower);
+    return Value(res);
+  }));
+  globals.set("default", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
+    args.expectArgs("default", {2, 3}, {0, 1});
+    auto & value = args.args[0];
+    auto & default_value = args.args[1];
+    bool boolean = false;
+    if (args.args.size() == 3) {
+      boolean = args.args[2].get<bool>();
+    } else {
+      Value bv = args.get_named("boolean");
+      if (!bv.is_null()) {
+        boolean = bv.get<bool>();
+      }
+    }
+    return boolean ? (value.to_bool() ? value : default_value) : value.is_null() ? default_value : value;
+  }));
+  auto escape = simple_function("escape", { "text" }, [](const std::shared_ptr<Context> &, Value & args) {
+    return Value(html_escape(args.at("text").get<std::string>()));
+  });
+  globals.set("e", escape);
+  globals.set("escape", escape);
+  globals.set("joiner", simple_function("joiner", { "sep" }, [](const std::shared_ptr<Context> &, Value & args) {
+    auto sep = args.get<std::string>("sep", "");
+    auto first = std::make_shared<bool>(true);
+    return simple_function("", {}, [sep, first](const std::shared_ptr<Context> &, const Value &) -> Value {
+      if (*first) {
+        *first = false;
+        return "";
+      }
+      return sep;
+    });
+    return Value(html_escape(args.at("text").get<std::string>()));
+  }));
+  globals.set("count", simple_function("count", { "items" }, [](const std::shared_ptr<Context> &, Value & args) {
+    return Value((int64_t) args.at("items").size());
+  }));
+  globals.set("dictsort", simple_function("dictsort", { "value" }, [](const std::shared_ptr<Context> &, Value & args) {
+    if (args.size() != 1) throw std::runtime_error("dictsort expects exactly 1 argument (TODO: fix implementation)");
+    auto & value = args.at("value");
+    auto keys = value.keys();
+    std::sort(keys.begin(), keys.end());
+    auto res = Value::array();
+    for (auto & key : keys) {
+      res.push_back(Value::array({key, value.at(key)}));
+    }
+    return res;
+  }));
+  globals.set("join", simple_function("join", { "items", "d" }, [](const std::shared_ptr<Context> &, Value & args) {
+    auto do_join = [](Value & items, const std::string & sep) {
+      std::ostringstream oss;
+      auto first = true;
+      for (size_t i = 0, n = items.size(); i < n; ++i) {
+        if (first) first = false;
+        else oss << sep;
+        oss << items.at(i).to_str();
+      }
+      return Value(oss.str());
+    };
+    auto sep = args.get<std::string>("d", "");
+    if (args.contains("items")) {
+        auto & items = args.at("items");
+        return do_join(items, sep);
+    } else {
+      return simple_function("", {"items"}, [sep, do_join](const std::shared_ptr<Context> &, Value & args) {
+        auto & items = args.at("items");
+        if (!items.to_bool() || !items.is_array()) throw std::runtime_error("join expects an array for items, got: " + items.dump());
+        return do_join(items, sep);
+      });
+    }
+  }));
+  globals.set("namespace", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
+    auto ns = Value::object();
+    args.expectArgs("namespace", {0, 0}, {0, (std::numeric_limits<size_t>::max)()});
+    for (auto & [name, value] : args.kwargs) {
+      ns.set(name, value);
+    }
+    return ns;
+  }));
+  auto equalto = simple_function("equalto", { "expected", "actual" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
+      return args.at("actual") == args.at("expected");
+  });
+  globals.set("equalto", equalto);
+  globals.set("==", equalto);
+  globals.set("length", simple_function("length", { "items" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
+      auto & items = args.at("items");
+      return (int64_t) items.size();
+  }));
+  globals.set("safe", simple_function("safe", { "value" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
+      return args.at("value").to_str();
+  }));
+  globals.set("string", simple_function("string", { "value" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
+      return args.at("value").to_str();
+  }));
+  globals.set("int", simple_function("int", { "value" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
+      return args.at("value").to_int();
+  }));
+  globals.set("list", simple_function("list", { "items" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
+      auto & items = args.at("items");
+      if (!items.is_array()) throw std::runtime_error("object is not iterable");
+      return items;
+  }));
+  globals.set("unique", simple_function("unique", { "items" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
+      auto & items = args.at("items");
+      if (!items.is_array()) throw std::runtime_error("object is not iterable");
+      std::unordered_set<Value> seen;
+      auto result = Value::array();
+      for (size_t i = 0, n = items.size(); i < n; i++) {
+        auto pair = seen.insert(items.at(i));
+        if (pair.second) {
+          result.push_back(items.at(i));
+        }
+      }
+      return result;
+  }));
+  auto make_filter = [](const Value & filter, Value & extra_args) -> Value {
+    return simple_function("", { "value" }, [=](const std::shared_ptr<Context> & context, Value & args) {
+      auto & value = args.at("value");
+      ArgumentsValue actual_args;
+      actual_args.args.emplace_back(value);
+      for (size_t i = 0, n = extra_args.size(); i < n; i++) {
+        actual_args.args.emplace_back(extra_args.at(i));
+      }
+      return filter.call(context, actual_args);
+    });
+  };
+  // https://jinja.palletsprojects.com/en/3.0.x/templates/#jinja-filters.reject
+  globals.set("reject", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+    args.expectArgs("reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
+    auto & items = args.args[0];
+    auto filter_fn = context->get(args.args[1]);
+    if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
+
+    auto filter_args = Value::array();
+    for (size_t i = 2, n = args.args.size(); i < n; i++) {
+      filter_args.push_back(args.args[i]);
+    }
+    auto filter = make_filter(filter_fn, filter_args);
+
+    auto res = Value::array();
+    for (size_t i = 0, n = items.size(); i < n; i++) {
+      auto & item = items.at(i);
+      ArgumentsValue filter_args;
+      filter_args.args.emplace_back(item);
+      auto pred_res = filter.call(context, filter_args);
+      if (!pred_res.to_bool()) {
+        res.push_back(item);
+      }
+    }
+    return res;
+  }));
+  globals.set("map", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+    auto res = Value::array();
+    if (args.args.size() == 1 &&
+      ((args.has_named("attribute") && args.kwargs.size() == 1) || (args.has_named("default") && args.kwargs.size() == 2))) {
+      auto & items = args.args[0];
+      auto attr_name = args.get_named("attribute");
+      auto default_value = args.get_named("default");
+      for (size_t i = 0, n = items.size(); i < n; i++) {
+        auto & item = items.at(i);
+        auto attr = item.get(attr_name);
+        res.push_back(attr.is_null() ? default_value : attr);
+      }
+    } else if (args.kwargs.empty() && args.args.size() >= 2) {
+      auto fn = context->get(args.args[1]);
+      if (fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
+      ArgumentsValue filter_args { {Value()}, {} };
+      for (size_t i = 2, n = args.args.size(); i < n; i++) {
+        filter_args.args.emplace_back(args.args[i]);
+      }
+      for (size_t i = 0, n = args.args[0].size(); i < n; i++) {
+        auto & item = args.args[0].at(i);
+        filter_args.args[0] = item;
+        res.push_back(fn.call(context, filter_args));
+      }
+    } else {
+      throw std::runtime_error("Invalid or unsupported arguments for map");
+    }
+    return res;
+  }));
+  globals.set("indent", simple_function("indent", { "text", "indent", "first" }, [](const std::shared_ptr<Context> &, Value & args) {
+    auto text = args.at("text").get<std::string>();
+    auto first = args.get<bool>("first", false);
+    std::string out;
+    std::string indent(args.get<int64_t>("indent", 0), ' ');
+    std::istringstream iss(text);
+    std::string line;
+    auto is_first = true;
+    while (std::getline(iss, line, '\n')) {
+      auto needs_indent = !is_first || first;
+      if (is_first) is_first = false;
+      else out += "\n";
+      if (needs_indent) out += indent;
+      out += line;
+    }
+    if (!text.empty() && text.back() == '\n') out += "\n";
+    return out;
+  }));
+  globals.set("selectattr", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+    args.expectArgs("selectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
+    auto & items = args.args[0];
+    if (items.is_null())
+      return Value::array();
+    auto attr_name = args.args[1].get<std::string>();
+
+    bool has_test = false;
+    Value test_fn;
+    ArgumentsValue test_args {{Value()}, {}};
+    if (args.args.size() >= 3) {
+      has_test = true;
+      test_fn = context->get(args.args[2]);
+      if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump());
+      for (size_t i = 3, n = args.args.size(); i < n; i++) {
+        test_args.args.emplace_back(args.args[i]);
+      }
+      test_args.kwargs = args.kwargs;
+    }
+
+    auto res = Value::array();
+    for (size_t i = 0, n = items.size(); i < n; i++) {
+      auto & item = items.at(i);
+      auto attr = item.get(attr_name);
+      if (has_test) {
+        test_args.args[0] = attr;
+        if (test_fn.call(context, test_args).to_bool()) {
+          res.push_back(item);
+        }
+      } else {
+        res.push_back(attr);
+      }
+    }
+    return res;
+  }));
+  globals.set("range", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
+    std::vector<int64_t> startEndStep(3);
+    std::vector<bool> param_set(3);
+    if (args.args.size() == 1) {
+      startEndStep[1] = args.args[0].get<int64_t>();
+      param_set[1] = true;
+    } else {
+      for (size_t i = 0; i < args.args.size(); i++) {
+        auto & arg = args.args[i];
+        auto v = arg.get<int64_t>();
+        startEndStep[i] = v;
+        param_set[i] = true;
+        }
+      }
+      for (auto & [name, value] : args.kwargs) {
+        size_t i;
+        if (name == "start") i = 0;
+        else if (name == "end") i = 1;
+        else if (name == "step") i = 2;
+        else throw std::runtime_error("Unknown argument " + name + " for function range");
+
+        if (param_set[i]) {
+          throw std::runtime_error("Duplicate argument " + name + " for function range");
+        }
+        startEndStep[i] = value.get<int64_t>();
+        param_set[i] = true;
+    }
+    if (!param_set[1]) {
+      throw std::runtime_error("Missing required argument 'end' for function range");
+    }
+    int64_t start = param_set[0] ? startEndStep[0] : 0;
+    int64_t end = startEndStep[1];
+    int64_t step = param_set[2] ? startEndStep[2] : 1;
+
+    auto res = Value::array();
+    if (step > 0) {
+      for (int64_t i = start; i < end; i += step) {
+        res.push_back(Value(i));
+      }
+    } else {
+      for (int64_t i = start; i > end; i += step) {
+        res.push_back(Value(i));
+      }
+    }
+    return res;
+  }));
+
+  return std::make_shared<Context>(std::move(globals));
+}
+
+inline std::shared_ptr<Context> Context::make(Value && values, const std::shared_ptr<Context> & parent) {
+  return std::make_shared<Context>(values.is_null() ? Value::object() : std::move(values), parent);
+}
+
+}  // namespace minja

+ 15 - 13
examples/main/main.cpp

@@ -4,6 +4,7 @@
 #include "log.h"
 #include "sampling.h"
 #include "llama.h"
+#include "chat-template.hpp"
 
 #include <cstdio>
 #include <cstring>
@@ -84,14 +85,6 @@ static void sigint_handler(int signo) {
 }
 #endif
 
-static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
-    common_chat_msg new_msg{role, content};
-    auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
-    chat_msgs.push_back({role, content});
-    LOG_DBG("formatted: '%s'\n", formatted.c_str());
-    return formatted;
-}
-
 int main(int argc, char ** argv) {
     common_params params;
     g_params = &params;
@@ -165,6 +158,7 @@ int main(int argc, char ** argv) {
     }
 
     const llama_vocab * vocab = llama_model_get_vocab(model);
+    auto chat_templates = common_chat_templates_from_model(model, params.chat_template);
 
     LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
 
@@ -207,7 +201,7 @@ int main(int argc, char ** argv) {
     }
 
     // auto enable conversation mode if chat template is available
-    const bool has_chat_template = !common_get_builtin_chat_template(model).empty() || !params.chat_template.empty();
+    const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.template_default;
     if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) {
         if (has_chat_template) {
             LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
@@ -225,7 +219,7 @@ int main(int argc, char ** argv) {
     // print chat template example in conversation mode
     if (params.conversation_mode) {
         if (params.enable_chat_template) {
-            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str());
+            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str());
         } else {
             LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
         }
@@ -269,10 +263,18 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd_inp;
 
+    auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
+        common_chat_msg new_msg{role, content};
+        auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja);
+        chat_msgs.push_back({role, content});
+        LOG_DBG("formatted: '%s'\n", formatted.c_str());
+        return formatted;
+    };
+
     {
         auto prompt = (params.conversation_mode && params.enable_chat_template)
             // format the system prompt in conversation mode (fallback to default if empty)
-            ? chat_add_and_format(model, chat_msgs, "system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
+            ? chat_add_and_format("system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
             // otherwise use the prompt as is
             : params.prompt;
         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
@@ -779,7 +781,7 @@ int main(int argc, char ** argv) {
                     }
 
                     if (params.enable_chat_template) {
-                        chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
+                        chat_add_and_format("assistant", assistant_ss.str());
                     }
                     is_interacting = true;
                     LOG("\n");
@@ -844,7 +846,7 @@ int main(int argc, char ** argv) {
 
                     bool format_chat = params.conversation_mode && params.enable_chat_template;
                     std::string user_inp = format_chat
-                        ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
+                        ? chat_add_and_format("user", std::move(buffer))
                         : std::move(buffer);
                     // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
                     const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);

+ 33 - 9
examples/run/run.cpp

@@ -28,6 +28,7 @@
 #include "json.hpp"
 #include "linenoise.cpp/linenoise.h"
 #include "llama-cpp.h"
+#include "chat-template.hpp"
 
 #if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32)
 [[noreturn]] static void sigint_handler(int) {
@@ -105,6 +106,7 @@ class Opt {
     llama_model_params   model_params;
     std::string model_;
     std::string          user;
+    bool                 use_jinja   = false;
     int                  context_size = -1, ngl = -1;
     float                temperature = -1;
     bool                 verbose     = false;
@@ -156,6 +158,8 @@ class Opt {
             } else if (options_parsing &&
                        (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
                 verbose = true;
+            } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
+                use_jinja = true;
             } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
                 help = true;
                 return 0;
@@ -713,13 +717,31 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 }
 
 // Function to apply the chat template and resize `formatted` if needed
-static int apply_chat_template(LlamaData & llama_data, const bool append) {
+static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
+    if (use_jinja) {
+        json messages = json::array();
+        for (const auto & msg : llama_data.messages) {
+            messages.push_back({
+                {"role", msg.role},
+                {"content", msg.content},
+            });
+        }
+        try {
+            auto result = tmpl.apply(messages, /* tools= */ json(), append);
+            llama_data.fmtted.resize(result.size() + 1);
+            memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
+            return result.size();
+        } catch (const std::exception & e) {
+            printe("failed to render the chat template: %s\n", e.what());
+            return -1;
+        }
+    }
     int result = llama_chat_apply_template(
-        llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
+        tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
+        result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
                                            llama_data.messages.size(), append, llama_data.fmtted.data(),
                                            llama_data.fmtted.size());
     }
@@ -871,8 +893,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
 }
 
 // Helper function to apply the chat template and handle errors
-static int apply_chat_template_with_error_handling(LlamaData & llama_data, const bool append, int & output_length) {
-    const int new_len = apply_chat_template(llama_data, append);
+static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
+    const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
     if (new_len < 0) {
         printe("failed to apply the chat template\n");
         return -1;
@@ -931,9 +953,11 @@ static int get_user_input(std::string & user_input, const std::string & user) {
 }
 
 // Main chat loop function
-static int chat_loop(LlamaData & llama_data, const std::string & user) {
+static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
+    auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), "");
+    GGML_ASSERT(chat_templates.template_default);
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
@@ -944,7 +968,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
 
         add_message("user", user.empty() ? user_input : user, llama_data);
         int new_len;
-        if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) {
+        if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) {
             return 1;
         }
 
@@ -959,7 +983,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
         }
 
         add_message("assistant", response, llama_data);
-        if (apply_chat_template_with_error_handling(llama_data, false, prev_len) < 0) {
+        if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) {
             return 1;
         }
     }
@@ -1019,7 +1043,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }
 
-    if (chat_loop(llama_data, opt.user)) {
+    if (chat_loop(llama_data, opt.user, opt.use_jinja)) {
         return 1;
     }
 

+ 1 - 1
examples/server/README.md

@@ -126,7 +126,7 @@ The project is under active development, and we are [looking for feedback and co
 | `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') |
 | `--grammar-file FNAME` | file to read grammar from |
 | `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object<br/>For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead |
-
+| `--jinja` | Enable experimental Jinja templating engine (needed for tool use) |
 
 **Example-specific params**
 

+ 42 - 9
examples/server/server.cpp

@@ -1688,6 +1688,8 @@ struct server_context {
     // Necessary similarity of prompt for slot selection
     float slot_prompt_similarity = 0.0f;
 
+    common_chat_templates chat_templates;
+
     ~server_context() {
         // Clear any sampling context
         for (server_slot & slot : slots) {
@@ -1767,14 +1769,39 @@ struct server_context {
             cparams_dft.type_v = GGML_TYPE_F16;
         }
 
+        chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
+        GGML_ASSERT(chat_templates.template_default.get() != nullptr);
+
         return true;
     }
 
-    bool validate_builtin_chat_template() const {
+    bool validate_builtin_chat_template(bool use_jinja) const {
         llama_chat_message chat[] = {{"user", "test"}};
-        const char * tmpl = llama_model_chat_template(model);
-        const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
-        return chat_res > 0;
+
+        if (use_jinja) {
+            auto templates = common_chat_templates_from_model(model, "");
+            GGML_ASSERT(templates.template_default);
+            try {
+                templates.template_default->apply({{
+                    {"role", "user"},
+                    {"content", "test"},
+                }}, json(), true);
+                if (templates.template_tool_use) {
+                    templates.template_tool_use->apply({{
+                        {"role", "user"},
+                        {"content", "test"},
+                    }}, json(), true);
+                }
+                return true;
+            } catch (const std::exception & e) {
+                SRV_ERR("failed to apply template: %s\n", e.what());
+                return false;
+            }
+        } else {
+            const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
+            const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
+            return chat_res > 0;
+        }
     }
 
     void init() {
@@ -3659,9 +3686,12 @@ int main(int argc, char ** argv) {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots",                 ctx_server.params_base.n_parallel },
             { "model_path",                  ctx_server.params_base.model },
-            { "chat_template",               common_get_builtin_chat_template(ctx_server.model) },
+            { "chat_template",               ctx_server.chat_templates.template_default->source() },
             { "build_info",                  build_info },
         };
+        if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) {
+            data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source();
+        }
 
         res_ok(res, data);
     };
@@ -3889,7 +3919,10 @@ int main(int argc, char ** argv) {
             return;
         }
 
-        json data = oaicompat_chat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
+        auto body = json::parse(req.body);
+        const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
+        json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja);
+
         return handle_completions_impl(
             SERVER_TASK_TYPE_COMPLETION,
             data,
@@ -4299,7 +4332,7 @@ int main(int argc, char ** argv) {
 
     // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
     if (params.chat_template.empty()) {
-        if (!ctx_server.validate_builtin_chat_template()) {
+        if (!ctx_server.validate_builtin_chat_template(params.use_jinja)) {
             LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
             params.chat_template = "chatml";
         }
@@ -4307,8 +4340,8 @@ int main(int argc, char ** argv) {
 
     // print sample chat example to make it clear which template is used
     LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
-        params.chat_template.empty() ? "(built-in)" : params.chat_template.c_str(),
-        common_chat_format_example(ctx_server.model, params.chat_template).c_str());
+        ctx_server.chat_templates.template_default->source().c_str(),
+        common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());
 
     ctx_server.queue_tasks.on_new_task(std::bind(
                 &server_context::process_single_task, &ctx_server, std::placeholders::_1));

+ 10 - 6
examples/server/tests/unit/test_chat_completion.py

@@ -4,22 +4,26 @@ from utils import *
 
 server = ServerPreset.tinyllama2()
 
-
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
 
 
 @pytest.mark.parametrize(
-    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
+    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
     [
-        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
-        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
+        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False, None),
+        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True, None),
+        (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
     ]
 )
-def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
+def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
     global server
+    server.jinja = jinja
+    server.chat_template = chat_template
     server.start()
     res = server.make_request("POST", "/chat/completions", data={
         "model": model,

+ 6 - 1
examples/server/tests/utils.py

@@ -72,13 +72,14 @@ class ServerProcess:
     pooling: str | None = None
     draft: int | None = None
     api_key: str | None = None
-    response_format: str | None = None
     lora_files: List[str] | None = None
     disable_ctx_shift: int | None = False
     draft_min: int | None = None
     draft_max: int | None = None
     no_webui: bool | None = None
+    jinja: bool | None = None
     chat_template: str | None = None
+    chat_template_file: str | None = None
 
     # session variables
     process: subprocess.Popen | None = None
@@ -169,8 +170,12 @@ class ServerProcess:
             server_args.extend(["--draft-min", self.draft_min])
         if self.no_webui:
             server_args.append("--no-webui")
+        if self.jinja:
+            server_args.append("--jinja")
         if self.chat_template:
             server_args.extend(["--chat-template", self.chat_template])
+        if self.chat_template_file:
+            server_args.extend(["--chat-template-file", self.chat_template_file])
 
         args = [str(arg) for arg in [server_path, *server_args]]
         print(f"bench: starting server with: {' '.join(args)}")

+ 27 - 9
examples/server/utils.hpp

@@ -16,6 +16,8 @@
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
+#include "minja.hpp"
+#include "chat-template.hpp"
 
 #include <random>
 #include <sstream>
@@ -349,7 +351,7 @@ static llama_tokens format_infill(
 }
 
 // Format given chat. If tmpl is empty, we take the template from model metadata
-inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
+inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
     std::vector<common_chat_msg> chat;
 
     for (size_t i = 0; i < messages.size(); ++i) {
@@ -377,7 +379,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
         chat.push_back({role, content});
     }
 
-    const auto formatted_chat = common_chat_apply_template(model, tmpl, chat, true);
+    const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
     LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
 
     return formatted_chat;
@@ -576,14 +578,23 @@ static json oaicompat_completion_params_parse(const json & body) {
     return llama_params;
 }
 
-static json oaicompat_chat_completion_params_parse(
-        const struct llama_model * model,
-        const json & body, /* openai api json semantics */
-        const std::string & chat_template) {
+static json oaicompat_completion_params_parse(
+    const json & body, /* openai api json semantics */
+    const common_chat_template & tmpl,
+    bool use_jinja)
+{
     json llama_params;
 
-    // Apply chat template to the list of messages
-    llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
+    auto tools = json_value(body, "tools", json());
+    auto has_tools = tools.is_array() && !tools.empty();
+
+    if (has_tools) {
+        if (use_jinja) {
+            LOG_WRN("tools param is not fully supported yet\n");
+        } else {
+            throw std::runtime_error("tools param requires --jinja flag");
+        }
+    }
 
     // Handle "stop" field
     if (body.contains("stop") && body.at("stop").is_string()) {
@@ -606,6 +617,13 @@ static json oaicompat_chat_completion_params_parse(
         }
     }
 
+    // Apply chat template to the list of messages
+    if (use_jinja) {
+        llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
+    } else {
+        llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
+    }
+
     // Handle "n" field
     int n_choices = json_value(body, "n", 1);
     if (n_choices != 1) {
@@ -621,7 +639,7 @@ static json oaicompat_chat_completion_params_parse(
     }
 
     // Params supported by OAI but unsupported by llama.cpp
-    static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
+    static const std::vector<std::string> unsupported_params { "tool_choice" };
     for (const auto & param : unsupported_params) {
         if (body.contains(param)) {
             throw std::runtime_error("Unsupported param: " + param);

+ 1 - 1
examples/simple-chat/simple-chat.cpp

@@ -163,7 +163,7 @@ int main(int argc, char ** argv) {
             break;
         }
 
-        const char * tmpl = llama_model_chat_template(model);
+        const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
 
         // add the user input to the message list and format it
         messages.push_back({"user", strdup(user.c_str())});

+ 2 - 1
include/llama.h

@@ -510,7 +510,8 @@ extern "C" {
     LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
 
     // Get the default chat template. Returns nullptr if not available
-    LLAMA_API const char * llama_model_chat_template(const struct llama_model * model);
+    // If name is NULL, returns the default chat template
+    LLAMA_API const char * llama_model_chat_template(const struct llama_model * model, const char * name);
 
     // Returns the total number of parameters in the model
     LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);

+ 77 - 0
scripts/get_hf_chat_template.py

@@ -0,0 +1,77 @@
+#!/usr/bin/env python
+'''
+  Fetches the Jinja chat template of a HuggingFace model.
+  If a model has multiple chat templates, you can specify the variant name.
+
+  Syntax:
+    ./scripts/get_hf_chat_template.py model_id [variant]
+
+  Examples:
+    ./scripts/get_hf_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct
+    ./scripts/get_hf_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use
+    ./scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct
+'''
+
+import json
+import re
+import sys
+
+
+def get_hf_chat_template(model_id, variant=None):
+    try:
+        # Use huggingface_hub library if available.
+        # Allows access to gated models if the user has access and ran `huggingface-cli login`.
+        from huggingface_hub import hf_hub_download
+        with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f:
+            config_str = f.read()
+    except ImportError:
+        import requests
+        assert re.match(r"^[\w.-]+/[\w.-]+$", model_id), f"Invalid model ID: {model_id}"
+        response = requests.get(f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json")
+        if response.status_code == 401:
+            raise Exception('Access to this model is gated, please request access, authenticate with `huggingface-cli login` and make sure to run `pip install huggingface_hub`')
+        response.raise_for_status()
+        config_str = response.text
+
+    try:
+        config = json.loads(config_str)
+    except json.JSONDecodeError:
+        # Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
+        # (Remove extra '}' near the end of the file)
+        config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))
+
+    chat_template = config['chat_template']
+    if isinstance(chat_template, str):
+        return chat_template
+    else:
+        variants = {
+            ct['name']: ct['template']
+            for ct in chat_template
+        }
+
+        def format_variants():
+            return ', '.join(f'"{v}"' for v in variants.keys())
+
+        if variant is None:
+            if 'default' not in variants:
+                raise Exception(f'Please specify a chat template variant (one of {format_variants()})')
+            variant = 'default'
+            sys.stderr.write(f'Note: picked "default" chat template variant (out of {format_variants()})\n')
+        elif variant not in variants:
+            raise Exception(f"Variant {variant} not found in chat template (found {format_variants()})")
+
+        return variants[variant]
+
+
+def main(args):
+    if len(args) < 1:
+        raise ValueError("Please provide a model ID and an optional variant name")
+    model_id = args[0]
+    variant = None if len(args) < 2 else args[1]
+
+    template = get_hf_chat_template(model_id, variant)
+    sys.stdout.write(template)
+
+
+if __name__ == '__main__':
+    main(sys.argv[1:])

+ 1 - 1
src/CMakeLists.txt

@@ -29,7 +29,7 @@ add_library(llama
             unicode-data.cpp
             )
 
-target_include_directories(llama PUBLIC . ../include)
+target_include_directories(llama PUBLIC . ../include ../common)
 target_compile_features   (llama PUBLIC cxx_std_17) # don't bump
 
 target_link_libraries(llama PUBLIC ggml)

+ 4 - 2
src/llama-arch.cpp

@@ -179,6 +179,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_HF_JSON,              "tokenizer.huggingface.json"              },
     { LLM_KV_TOKENIZER_RWKV,                 "tokenizer.rwkv.world"                    },
     { LLM_KV_TOKENIZER_CHAT_TEMPLATE,        "tokenizer.chat_template"                 },
+    { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,      "tokenizer.chat_template.%s"              },
     { LLM_KV_TOKENIZER_FIM_PRE_ID,           "tokenizer.ggml.fim_pre_token_id"         },
     { LLM_KV_TOKENIZER_FIM_SUF_ID,           "tokenizer.ggml.fim_suf_token_id"         },
     { LLM_KV_TOKENIZER_FIM_MID_ID,           "tokenizer.ggml.fim_mid_token_id"         },
@@ -1443,10 +1444,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_CONVNEXT_GAMMA,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 };
 
-LLM_KV::LLM_KV(llm_arch arch) : arch(arch) {}
+LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
 
 std::string LLM_KV::operator()(llm_kv kv) const {
-    return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
+    return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
+        : ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
 }
 
 std::string LLM_TN_IMPL::str() const {

+ 3 - 1
src/llama-arch.h

@@ -177,6 +177,7 @@ enum llm_kv {
     LLM_KV_TOKENIZER_HF_JSON,
     LLM_KV_TOKENIZER_RWKV,
     LLM_KV_TOKENIZER_CHAT_TEMPLATE,
+    LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
     LLM_KV_TOKENIZER_FIM_PRE_ID,
     LLM_KV_TOKENIZER_FIM_SUF_ID,
     LLM_KV_TOKENIZER_FIM_MID_ID,
@@ -335,9 +336,10 @@ enum llm_tensor_layer {
 };
 
 struct LLM_KV {
-    LLM_KV(llm_arch arch);
+    LLM_KV(llm_arch arch, const char * suffix = nullptr);
 
     llm_arch arch;
+    const char * suffix;
 
     std::string operator()(llm_kv kv) const;
 };

+ 4 - 2
src/llama-model.cpp

@@ -3955,8 +3955,10 @@ uint64_t llama_model_size(const struct llama_model * model) {
     return model->size();
 }
 
-const char * llama_model_chat_template(const struct llama_model * model) {
-    const auto & it = model->gguf_kv.find(LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE));
+const char * llama_model_chat_template(const struct llama_model * model, const char * name) {
+    const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
+        : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
+    const auto & it = model->gguf_kv.find(key);
     if (it == model->gguf_kv.end()) {
         return nullptr;
     }

Разница между файлами не показана из-за своего большого размера
+ 75 - 12
tests/test-chat-template.cpp


Некоторые файлы не были показаны из-за большого количества измененных файлов