Jelajahi Sumber

`server`: streaming of tool calls and thoughts when `--jinja` is on (#12379)

* add common_json w/ support for truncated json healing

* add common_chat_msg_diff

* partial common_chat_parse

* refactor parser w/ optionals

* server: wire chat diffs in stream mode

* fix trigger of thinking models (must happen after thoughts are closed)

* fix functionary v3.2 raw python!

* rename: common_chat_syntax (now contains format)

* rm common_regex.at_start

* don't return empty <think></think>

* accommodate yet another deepseek r1 distill fantasy syntax (`<|tool▁calls|>`)

* fix QwQ 32B tool call parsing after thoughts (hermes2)

* better logs for grammar triggers

* consume spaces after parse_json_tool_calls

* fix required tool calls w/ thinking models that have pre-opened thinking tags

* fix thinking model's initial trigger + test qwq's template

* run most test_tool_call tests in stream + non-stream modes

* make functionary v3.2 parsing more strict (differentiate first match from others)

* send final diff from server, to close off raw python arguments

* support partial content streaming in Generic mode

* tool-call: allow content prelude before hermes2 tool calls (for Qwen2.5)

* Update function-calling.md

* Update tool_bench.py

* chat-parser: remove input from exception (llm output may contain PII)

---------

Co-authored-by: ochafik <ochafik@google.com>
Co-authored-by: Olivier Chafik <ochafik@users.noreply.github.com>
Olivier Chafik 7 bulan lalu
induk
melakukan
f5cd27b71d

+ 4 - 0
common/CMakeLists.txt

@@ -60,12 +60,16 @@ add_library(${TARGET} STATIC
     base64.hpp
     chat.cpp
     chat.h
+    chat-parser.cpp
+    chat-parser.h
     common.cpp
     common.h
     console.cpp
     console.h
     json-schema-to-grammar.cpp
     json.hpp
+    json-partial.h
+    json-partial.cpp
     llguidance.cpp
     log.cpp
     log.h

+ 376 - 0
common/chat-parser.cpp

@@ -0,0 +1,376 @@
+#include "chat-parser.h"
+#include "common.h"
+#include "log.h"
+#include "regex-partial.h"
+
+#include <optional>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+using json = nlohmann::ordered_json;
+
+common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
+    : input_(input), is_partial_(is_partial), syntax_(syntax)
+{
+    result_.role = "assistant";
+
+    while (true) {
+        std::string id = std::to_string(std::rand());
+        if (input.find(id) == std::string::npos) {
+            healing_marker_ = id;
+            break;
+        }
+    }
+}
+
+std::string common_chat_msg_parser::str(const common_string_range & rng) const {
+    GGML_ASSERT(rng.begin <= rng.end);
+    return input_.substr(rng.begin, rng.end - rng.begin);
+}
+
+void common_chat_msg_parser::add_content(const std::string &content) {
+    result_.content += content;
+}
+
+void common_chat_msg_parser::add_reasoning_content(const std::string &reasoning_content) {
+    result_.reasoning_content += reasoning_content;
+}
+
+bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments) {
+    if (name.empty()) {
+        return false;
+    }
+
+    common_chat_tool_call tool_call;
+    tool_call.name = name;
+    tool_call.arguments = arguments;
+    tool_call.id = id;
+
+    // LOG_DBG("Tool call arguments:\n\traw: %s\n\tresult: %s\n", arguments.c_str(), tool_call.arguments.c_str());
+    result_.tool_calls.emplace_back(tool_call);
+    return true;
+}
+bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
+    std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
+    std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
+    std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : "";
+    return add_tool_call(name, id, arguments);
+}
+
+bool common_chat_msg_parser::add_tool_calls(const json & arr) {
+    for (const auto & item : arr) {
+        if (!add_tool_call(item)) {
+            return false;
+        }
+    }
+    return true;
+}
+void common_chat_msg_parser::finish() {
+    if (!is_partial_ && pos_ != input_.size()) {
+        throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_));
+    }
+}
+
+bool common_chat_msg_parser::consume_spaces() {
+    const auto length = input_.size();
+    auto consumed = false;
+    while (pos_ < length && std::isspace(input_[pos_])) {
+        ++pos_;
+        consumed = true;
+    }
+    return consumed;
+}
+
+bool common_chat_msg_parser::try_consume_literal(const std::string & literal) {
+    auto pos = pos_;
+    for (auto i = 0u; i < literal.size(); ++i) {
+        if (pos >= input_.size()) {
+            return false;
+        }
+        if (input_[pos] != literal[i]) {
+            return false;
+        }
+        ++pos;
+    }
+    pos_ = pos;
+    return true;
+}
+
+std::optional<common_chat_msg_parser::find_regex_result>  common_chat_msg_parser::try_find_literal(const std::string & literal) {
+    auto idx = input_.find(literal, pos_);
+    if (idx != std::string::npos) {
+        find_regex_result res;
+        res.prelude = input_.substr(pos_, idx - pos_);
+        auto end = idx + literal.size();
+        res.groups.emplace_back(common_string_range{idx, end});
+        move_to(end);
+        return res;
+    }
+    if (is_partial_) {
+        idx = string_find_partial_stop(input_, literal);
+        if (idx != std::string::npos && idx >= pos_) {
+            find_regex_result res;
+            res.prelude = input_.substr(pos_, idx - pos_);
+            auto end = input_.size();
+            res.groups.emplace_back(common_string_range{idx, end});
+            move_to(end);
+            return res;
+        }
+    }
+    return std::nullopt;
+}
+
+void common_chat_msg_parser::consume_literal(const std::string & literal) {
+    if (!try_consume_literal(literal)) {
+        throw common_chat_msg_partial_exception(literal);
+    }
+}
+
+bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) {
+    auto handle_reasoning = [&](const std::string & reasoning, bool closed) {
+        auto stripped_reasoning = string_strip(reasoning);
+        if (stripped_reasoning.empty()) {
+            return;
+        }
+        if (syntax_.reasoning_in_content) {
+            add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "<think>" : start_think);
+            add_content(stripped_reasoning);
+            if (closed) {
+                add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "</think>" : end_think);
+            }
+        } else {
+            add_reasoning_content(stripped_reasoning);
+        }
+    };
+    if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
+        if (syntax_.thinking_forced_open || try_consume_literal(start_think)) {
+            if (auto res = try_find_literal(end_think)) {
+                handle_reasoning(res->prelude, /* closed */ true);
+                consume_spaces();
+                return true;
+            }
+            auto rest = consume_rest();
+            if (!rest.empty()) {
+                handle_reasoning(rest, /* closed */ !is_partial());
+            }
+            if (!syntax_.thinking_forced_open) {
+                throw common_chat_msg_partial_exception(end_think);
+            }
+            return true;
+        }
+    }
+    return false;
+}
+
+std::string common_chat_msg_parser::consume_rest() {
+    auto rest = input_.substr(pos_);
+    pos_ = input_.size();
+    return rest;
+}
+
+// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback.
+std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from) {
+    auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
+    if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
+        return std::nullopt;
+    }
+    if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
+        if (is_partial()) {
+            throw common_chat_msg_partial_exception(regex.str());
+        }
+        return std::nullopt;
+    }
+    auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
+    pos_ = m.groups[0].end;
+
+    return find_regex_result{prelude, m.groups};
+}
+
+common_chat_msg_parser::find_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) {
+    if (auto result = try_consume_regex(regex)) {
+        return *result;
+    }
+    throw common_chat_msg_partial_exception(regex.str());
+}
+
+std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_consume_regex(const common_regex & regex) {
+    auto m = regex.search(input_, pos_);
+    if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
+        return std::nullopt;
+    }
+    if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
+        if (is_partial()) {
+            throw common_chat_msg_partial_exception(regex.str());
+        }
+        return std::nullopt;
+    }
+    if (m.groups[0].begin != pos_) {
+        // Didn't match at the current position.
+        return std::nullopt;
+    }
+    pos_ = m.groups[0].end;
+
+    return find_regex_result {
+        /* .prelude = */ "",
+        m.groups,
+    };
+}
+
+std::optional<common_json> common_chat_msg_parser::try_consume_json() {
+    auto it = input_.cbegin() + pos_;
+    const auto end = input_.cend();
+    common_json result;
+    if (!common_json_parse(it, end, healing_marker_, result)) {
+        return std::nullopt;
+    }
+    pos_ = std::distance(input_.cbegin(), it);
+    if (result.healing_marker.marker.empty()) {
+        // No healing marker, just return the parsed json
+        return result;
+    }
+    if (!is_partial()) {
+        throw common_chat_msg_partial_exception("JSON");
+    }
+    return result;
+}
+
+common_json common_chat_msg_parser::consume_json() {
+    if (auto result = try_consume_json()) {
+        return *result;
+    }
+    throw common_chat_msg_partial_exception("JSON");
+}
+
+common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json_with_dumped_args(
+    const std::vector<std::vector<std::string>> & args_paths,
+    const std::vector<std::vector<std::string>> & content_paths
+) {
+    if (auto result = try_consume_json_with_dumped_args(args_paths, content_paths)) {
+        return *result;
+    }
+    throw common_chat_msg_partial_exception("JSON");
+}
+
+std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parser::try_consume_json_with_dumped_args(
+    const std::vector<std::vector<std::string>> & args_paths,
+    const std::vector<std::vector<std::string>> & content_paths
+) {
+    auto partial = try_consume_json();
+    if (!partial) {
+        return std::nullopt;
+    }
+    auto is_arguments_path = [&](const std::vector<std::string> & path) {
+        return std::find(args_paths.begin(), args_paths.end(), path) != args_paths.end();
+    };
+    auto is_content_path = [&](const std::vector<std::string> & path) {
+        return std::find(content_paths.begin(), content_paths.end(), path) != content_paths.end();
+    };
+
+    if (partial->healing_marker.marker.empty()) {
+        if (args_paths.empty()) {
+            // No arguments to dump, and JSON was parsed fully.
+            return consume_json_result {
+                partial->json,
+                /* .is_partial = */ false,
+            };
+        }
+        if (is_arguments_path({})) {
+            // Entire JSON is the arguments and was parsed fully.
+            return consume_json_result {
+                partial->json.dump(),
+                /* .is_partial = */ false,
+            };
+        }
+    }
+
+    LOG_DBG("Parsed partial JSON: %s (json_healing_marker: %s)\n", partial->json.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
+
+    auto found_healing_marker = false;
+    std::vector<std::string> path;
+    std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
+        if (is_arguments_path(path)) {
+            auto arguments = j.dump();
+            if (is_partial() && !partial->healing_marker.marker.empty()) {
+                auto idx = arguments.find(partial->healing_marker.json_dump_marker);
+                if (idx != std::string::npos) {
+                    arguments.resize(idx);
+                    found_healing_marker = true;
+                }
+                if (arguments == "\"") {
+                    // This happens because of completing `:"$magic` after `"arguments"`
+                    arguments = "";
+                }
+            }
+            return arguments;
+        }
+        if (is_content_path(path)) {
+            if (!j.is_string()) {
+                throw std::runtime_error("Content path must be a string");
+            }
+            std::string str = j;
+            auto idx = str.find(partial->healing_marker.marker); // not using json_dump_marker as we're inside a string
+            if (idx != std::string::npos) {
+                str.resize(idx);
+                found_healing_marker = true;
+            }
+            return str;
+        }
+        if (j.is_object()) {
+            auto obj = json::object();
+            for (const auto & p : j.items()) {
+                const auto & key = p.key();
+                const auto & value = p.value();
+                const std::string key_str = key; // NOLINT
+                auto idx = key_str.find(healing_marker_);
+                if (idx != std::string::npos) {
+                    found_healing_marker = true;
+                    break;
+                }
+                path.push_back(key_str);
+                if (value.is_string()) {
+                    const std::string value_str = value;
+                    if (value_str.find(healing_marker_) != std::string::npos) {
+                        found_healing_marker = true;
+                        if (is_content_path(path)) {
+                            if (partial->healing_marker.marker == partial->healing_marker.json_dump_marker) {
+                                // The healing occurred inside the string: good. Otherwise we just ditch the entire key/value pair.
+                                obj[key] = remove_unsupported_healings_and_dump_args(value);
+                            }
+                        }
+                        break;
+                    }
+                    obj[key] = value;
+                } else {
+                    obj[key] = remove_unsupported_healings_and_dump_args(value);
+                }
+                path.pop_back();
+            }
+            return obj;
+        }
+        if (j.is_array()) {
+            auto arr = json::array();
+            for (const auto & value : j) {
+                if (value.is_string()) {
+                    std::string str = value;
+                    auto idx = str.find(healing_marker_);
+                    if (idx != std::string::npos) {
+                        // Don't heal array values that aren't in the arguments.
+                        found_healing_marker = true;
+                        break;
+                    }
+                }
+                arr.push_back(remove_unsupported_healings_and_dump_args(value));
+            }
+            return arr;
+        }
+        return j;
+    };
+
+    auto cleaned = remove_unsupported_healings_and_dump_args(partial->json);
+    LOG_DBG("Cleaned up JSON %s to %s (json_healing_marker : '%s')\n", partial->json.dump().c_str(), cleaned.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
+    return consume_json_result {
+        cleaned,
+        /* .is_partial = */ found_healing_marker,
+    };
+}

+ 116 - 0
common/chat-parser.h

@@ -0,0 +1,116 @@
+#pragma once
+
+#include "chat.h"
+#include "json-partial.h"
+#include "json.hpp"
+#include "regex-partial.h"
+
+#include <optional>
+#include <string>
+#include <vector>
+
+class common_chat_msg_partial_exception : public std::runtime_error {
+  public:
+    common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {}
+};
+
+class common_chat_msg_parser {
+    std::string input_;
+    bool is_partial_;
+    common_chat_syntax syntax_;
+    std::string healing_marker_;
+
+    size_t pos_ = 0;
+    common_chat_msg result_;
+
+  public:
+    common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
+    const std::string & input() const { return input_; }
+    size_t pos() const { return pos_; }
+    const std::string & healing_marker() const { return healing_marker_; }
+    const bool & is_partial() const { return is_partial_; }
+    const common_chat_msg & result() const { return result_; }
+
+    void move_to(size_t pos) {
+        if (pos > input_.size()) {
+            throw std::runtime_error("Invalid position!");
+        }
+        pos_ = pos;
+    }
+    void move_back(size_t n) {
+        if (pos_ < n) {
+            throw std::runtime_error("Can't move back that far!");
+        }
+        pos_ -= n;
+    }
+
+    // Get the substring of the input at the given range
+    std::string str(const common_string_range & rng) const;
+
+    // Appends to the result.content field
+    void add_content(const std::string & content);
+
+    // Appends to the result.reasoning_content field
+    void add_reasoning_content(const std::string & reasoning_content);
+
+    // Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything.
+    bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);
+
+    // Adds a tool call using the "name", "id" and "arguments" fields of the json object
+    bool add_tool_call(const nlohmann::ordered_json & tool_call);
+
+    // Adds an array of tool calls using their "name", "id" and "arguments" fields.
+    bool add_tool_calls(const nlohmann::ordered_json & arr);
+
+    void finish();
+
+    bool consume_spaces();
+
+    void consume_literal(const std::string & literal);
+
+    bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);
+
+    std::string consume_rest();
+
+    struct find_regex_result {
+        std::string prelude;
+        std::vector<common_string_range> groups;
+    };
+
+    std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos);
+
+    bool try_consume_literal(const std::string & literal);
+
+    std::optional<find_regex_result> try_find_literal(const std::string & literal);
+
+    find_regex_result consume_regex(const common_regex & regex);
+
+    std::optional<find_regex_result> try_consume_regex(const common_regex & regex);
+
+    std::optional<common_json> try_consume_json();
+    common_json consume_json();
+
+    struct consume_json_result {
+        nlohmann::ordered_json value;
+        bool is_partial;
+    };
+
+    /*
+        Consume (possibly partial) json and converts specific subtrees to (possibly truncated) JSON strings.
+
+        By default, object keys can't be truncated, nor can string values (their corresponding key is removed,
+        e.g. `{"foo": "bar", "baz": "b` -> `{"foo": "bar"}`
+
+        But one can allow subpaths to be kept truncated, and possibly json-dumped to truncated json strings
+        - with `content_paths={{"foo"}}` -> `{"foo": "b` -> {"foo": "b"}`
+        - with `args_paths={{"foo"}}` -> `{"foo": {"b` -> `{"foo": "{b"}`
+    */
+    consume_json_result consume_json_with_dumped_args(
+        const std::vector<std::vector<std::string>> & args_paths = {},
+        const std::vector<std::vector<std::string>> & content_paths = {}
+    );
+    std::optional<consume_json_result> try_consume_json_with_dumped_args(
+        const std::vector<std::vector<std::string>> & args_paths = {},
+        const std::vector<std::vector<std::string>> & content_paths = {}
+    );
+};

File diff ditekan karena terlalu besar
+ 472 - 384
common/chat.cpp


+ 67 - 5
common/chat.h

@@ -3,6 +3,7 @@
 #pragma once
 
 #include "common.h"
+#include <functional>
 #include <chrono>
 #include <string>
 #include <vector>
@@ -13,11 +14,19 @@ struct common_chat_tool_call {
     std::string name;
     std::string arguments;
     std::string id;
+
+    bool operator==(const common_chat_tool_call & other) const {
+        return name == other.name && arguments == other.arguments && id == other.id;
+    }
 };
 
 struct common_chat_msg_content_part {
     std::string type;
     std::string text;
+
+    bool operator==(const common_chat_msg_content_part & other) const {
+        return type == other.type && text == other.text;
+    }
 };
 
 struct common_chat_msg {
@@ -28,6 +37,51 @@ struct common_chat_msg {
     std::string reasoning_content;
     std::string tool_name;
     std::string tool_call_id;
+
+    template <class T> T to_json_oaicompat() const;
+
+    bool empty() const {
+        return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
+    }
+    void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
+        for (auto i = 0u; i < tool_calls.size(); i++) {
+            if (ids_cache.size() <= i) {
+                auto id = tool_calls[i].id;
+                if (id.empty()) {
+                    id = gen_tool_call_id();
+                }
+                ids_cache.push_back(id);
+            }
+            tool_calls[i].id = ids_cache[i];
+        }
+    }
+    bool operator==(const common_chat_msg & other) const {
+        return role == other.role
+            && content == other.content
+            && content_parts == other.content_parts
+            && tool_calls == other.tool_calls
+            && reasoning_content == other.reasoning_content
+            && tool_name == other.tool_name
+            && tool_call_id == other.tool_call_id;
+    }
+    bool operator!=(const common_chat_msg & other) const {
+        return !(*this == other);
+    }
+};
+
+struct common_chat_msg_diff {
+    // std::string reasoning_content_delta;
+    std::string content_delta;
+    size_t tool_call_index = std::string::npos;
+    common_chat_tool_call tool_call_delta;
+
+    static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg);
+
+    bool operator==(const common_chat_msg_diff & other) const {
+        return content_delta == other.content_delta
+        && tool_call_index == other.tool_call_index
+        && tool_call_delta == other.tool_call_delta;
+    }
 };
 
 struct common_chat_tool {
@@ -49,14 +103,11 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_LLAMA_3_X,
     COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
     COMMON_CHAT_FORMAT_DEEPSEEK_R1,
-    COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
     COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
     COMMON_CHAT_FORMAT_HERMES_2_PRO,
-    COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING,
     COMMON_CHAT_FORMAT_COMMAND_R7B,
-    COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
 
     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
 };
@@ -71,7 +122,7 @@ struct common_chat_templates_inputs {
     std::vector<common_chat_tool> tools;
     common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
     bool parallel_tool_calls = false;
-    bool extract_reasoning     = true;
+    common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
     std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
 };
 
@@ -80,11 +131,20 @@ struct common_chat_params {
     std::string                         prompt;
     std::string                         grammar;
     bool                                grammar_lazy = false;
+    bool                                thinking_forced_open = false;
     std::vector<common_grammar_trigger> grammar_triggers;
     std::vector<std::string>            preserved_tokens;
     std::vector<std::string>            additional_stops;
 };
 
+struct common_chat_syntax {
+    common_chat_format       format                = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    common_reasoning_format  reasoning_format      = COMMON_REASONING_FORMAT_NONE;
+    // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
+    bool                     reasoning_in_content  = false;
+    bool                     thinking_forced_open  = false;
+};
+
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
 bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
 
@@ -122,7 +182,7 @@ std::string common_chat_format_example(
     bool use_jinja);
 
 std::string               common_chat_format_name(common_chat_format format);
-common_chat_msg           common_chat_parse(      const std::string & input, common_chat_format format);
+common_chat_msg           common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
 
 common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
 
@@ -135,3 +195,5 @@ template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common
 // T can be std::string containing JSON or nlohmann::ordered_json
 template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
 template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
+
+template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);

+ 1 - 1
common/common.h

@@ -115,7 +115,7 @@ enum common_grammar_trigger_type {
     COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
     COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
     COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
-    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
+    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
 };
 
 struct common_grammar_trigger {

+ 255 - 0
common/json-partial.cpp

@@ -0,0 +1,255 @@
+#include <json-partial.h>
+#include "ggml.h"
+#include "log.h"
+#include <string>
+
+#include <json.hpp>
+
+using json = nlohmann::ordered_json;
+
+enum common_json_stack_element_type {
+    COMMON_JSON_STACK_ELEMENT_OBJECT,
+    COMMON_JSON_STACK_ELEMENT_KEY,
+    COMMON_JSON_STACK_ELEMENT_ARRAY,
+};
+
+struct common_json_stack_element {
+    common_json_stack_element_type type;
+    std::string key;
+};
+
+bool common_json_parse(
+    const std::string & input,
+    const std::string & healing_marker,
+    common_json & out)
+{
+    std::string::const_iterator it = input.begin();
+    const auto end = input.end();
+    return common_json_parse(it, end, healing_marker, out);
+}
+
+bool common_json_parse(
+    std::string::const_iterator & it,
+    const std::string::const_iterator & end,
+    const std::string & healing_marker,
+    common_json & out)
+{
+    // // https://json.nlohmann.me/features/parsing/sax_interface/
+    struct json_error_locator : public nlohmann::json_sax<json> {
+        std::size_t position;
+        bool found_error;
+        std::string last_token;
+        std::string exception_message;
+        std::vector<common_json_stack_element> stack;
+
+        json_error_locator() : position(0), found_error(false) {}
+
+        bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
+            this->position = position - 1;
+            this->found_error = true;
+            this->last_token = last_token;
+            this->exception_message = ex.what();
+            return false;
+        }
+        void close_value() {
+            if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
+                stack.pop_back();
+            }
+        }
+        bool null() override { // NOLINT
+            close_value();
+            return true;
+        }
+        bool boolean(bool) override { // NOLINT
+            close_value();
+            return true;
+        }
+        bool number_integer(number_integer_t) override { // NOLINT
+            close_value();
+            return true;
+        }
+        bool number_unsigned(number_unsigned_t) override { // NOLINT
+            close_value();
+            return true;
+        }
+        bool number_float(number_float_t, const string_t &) override { // NOLINT
+            close_value();
+            return true;
+        }
+        bool string(string_t &) override { // NOLINT
+            close_value();
+            return true;
+        }
+        bool binary(binary_t &) override { // NOLINT
+            close_value();
+            return true;
+        }
+        bool start_object(std::size_t) override { // NOLINT
+            stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
+            return true;
+        }
+        bool end_object() override {
+            GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
+            stack.pop_back();
+            close_value();
+            return true;
+        }
+        bool key(string_t & key) override { // NOLINT
+            stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
+            return true;
+        }
+        bool start_array(std::size_t) override { // NOLINT
+            stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
+            return true;
+        }
+        bool end_array() override {
+            GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
+            stack.pop_back();
+            close_value();
+            return true;
+        }
+    };
+    json_error_locator err_loc;
+    auto start = it;
+    json::sax_parse(it, end, &err_loc);
+
+    if (err_loc.found_error) {
+        it = start;
+        auto temptative_end = it + err_loc.position;
+        // LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());
+
+        auto input = std::string(it, temptative_end);
+        try {
+            out.json = json::parse(input);
+            // out.json = json::parse(it, temptative_end);
+            it = temptative_end;
+            return true;
+        } catch (const std::exception & ex) {
+            // No, needs healing.
+            LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
+        }
+        auto can_parse = [](const std::string & str) {
+            try {
+                auto _ = json::parse(str); // NOLINT
+                return true;
+            } catch (const std::exception &) {
+                return false;
+            }
+        };
+        if (!healing_marker.empty() && !err_loc.stack.empty()) {
+            std::string str(it, temptative_end);
+            auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
+            if (last_non_sp_pos == std::string::npos) {
+                throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
+            }
+            auto last_non_sp_char = str[last_non_sp_pos];
+            // Used to detect stops on a number, which may not be complete.
+            auto was_maybe_number = [&]() {
+                if (!str.empty() && std::isspace(str.back())) {
+                    return false;
+                }
+                return std::isdigit(last_non_sp_char) ||
+                    last_non_sp_char == '.' ||
+                    last_non_sp_char == 'e' ||
+                    last_non_sp_char == 'E' ||
+                    last_non_sp_char == '-';
+            };
+
+            std::string closing;
+            for (size_t i = err_loc.stack.size(); i > 0; i--) {
+                auto & el = err_loc.stack[i - 1];
+                if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
+                    closing += "}";
+                } else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
+                    closing += "]";
+                } else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
+                    throw std::runtime_error("Unexpected stack element type");
+                }
+            }
+
+            const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
+
+            if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
+                // We're inside an object value
+                if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
+                    // Was about to create an object value
+                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
+                } else if (can_parse(str + ": 1" + closing)) {
+                    str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
+                } else if (last_non_sp_char == '{' && can_parse(str + closing)) {
+                    // Was about to create an object
+                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
+                } else if (can_parse(str + "\"" + closing)) {
+                    // Was inside an object value string
+                    str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
+                } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
+                    // Was inside an object value string after an escape
+                    str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
+                } else {
+                    // find last :
+                    auto last_pos = str.find_last_of(':');
+                    if (last_pos == std::string::npos) {
+                        throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
+                    }
+                    // Cutting back to opening : for object value
+                    str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
+                }
+            } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
+                if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
+                    // Was about to create an array value
+                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
+                } else if (can_parse(str + "\"" + closing)) {
+                    // Was inside an array value string
+                    str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
+                } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
+                    // Was inside an array value string after an escape
+                    str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
+                } else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
+                    // Had just finished a value
+                    str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
+                } else {
+                    auto last_pos = str.find_last_of("[,");
+                    if (last_pos == std::string::npos) {
+                        throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
+                    }
+                    // Cutting back to last [ or , for array value
+                    str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
+                }
+            } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
+                if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
+                        (last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
+                    // Was about to create an object key+value
+                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
+                } else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
+                    // Was about to create an object key+value
+                    str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
+                } else if (can_parse(str + "\": 1" + closing)) {
+                    // Was inside an object key string
+                    str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
+                } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
+                    // Was inside an object key string after an escape
+                    str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
+                } else {
+                    auto last_pos = str.find_last_of(':');
+                    if (last_pos == std::string::npos) {
+                        throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
+                    }
+                    // fprintf(stderr, "Cutting back to last : for object key+value\n");
+                    str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
+                }
+            } else {
+                throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
+            }
+            // fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
+            out.json = json::parse(str);
+            it = temptative_end;
+            return true;
+        }
+        // TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
+        // fprintf(stderr, "Closing: TODO\n");
+        return false;
+    }
+    out.json = json::parse(it, end);
+    it = end;
+    return true;
+}

+ 37 - 0
common/json-partial.h

@@ -0,0 +1,37 @@
+#pragma once
+#include <json.hpp>
+
+// Healing marker (empty if the JSON was fully parsed / wasn't healed).
+struct common_healing_marker {
+    // Raw marker.
+    std::string marker;
+
+    // Cutting the `common_json.json.dump()` string at the (only) occurrence of this marker should yield the original partial JSON string (modulo spaces / if it had the same dump format).
+    std::string json_dump_marker;
+};
+
+// Represents a parsed JSON object, with its optional healing marker (a JSON dump fragment that can be used to find the position of healing in the JSON dump string)
+struct common_json {
+    nlohmann::ordered_json json;
+
+    common_healing_marker healing_marker;
+};
+
+// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty.
+//
+// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON.
+// This allows to parse the resulting healed JSON string, yet be able to cut it again if needed at the healing marker.
+// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format).
+//
+// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again).
+bool common_json_parse(
+    const std::string & input,
+    const std::string & healing_marker,
+    common_json & out);
+
+// Parse the JSON string (see overload above), but advancing an iterator to the end of the input when the (potentially partial) parsing succeeds.
+bool common_json_parse(
+    std::string::const_iterator & it,
+    const std::string::const_iterator & end,
+    const std::string & healing_marker,
+    common_json & out);

+ 7 - 8
common/sampling.cpp

@@ -161,7 +161,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
         GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
 #endif // LLAMA_USE_LLGUIDANCE
     } else {
-        std::vector<std::string> patterns_at_start;
+        std::vector<std::string> trigger_patterns;
         std::vector<std::string> patterns_anywhere;
         std::vector<llama_token> trigger_tokens;
         for (const auto & trigger : params.grammar_triggers) {
@@ -173,10 +173,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                     break;
                 }
                 case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
-                case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
                 {
-                    const auto & pattern = trigger.value;
-                    (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
+                    patterns_anywhere.push_back(trigger.value);
+                    break;
+                }
+                case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
+                {
+                    trigger_patterns.push_back(trigger.value);
                     break;
                 }
                 case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
@@ -190,10 +193,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             }
         }
 
-        std::vector<std::string> trigger_patterns;
-        if (!patterns_at_start.empty()) {
-            trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
-        }
         if (!patterns_anywhere.empty()) {
             trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
         }

+ 53 - 24
docs/function-calling.md

@@ -325,36 +325,65 @@ To get the official template from original HuggingFace repos, you can use [scrip
 > [!TIP]
 > If there is no official `tool_use` Jinja template, you may want to set `--chat-template chatml` to use a default that works with many models (YMMV!), or write your own (e.g. we provide a custom [llama-cpp-deepseek-r1.jinja](../models/templates/llama-cpp-deepseek-r1.jinja) for DeepSeek R1 distills)
 
+> [!CAUTION]
+> Beware of extreme KV quantizations (e.g. `-ctk q4_0`), they can substantially degrade the model's tool calling performance.
+
 Test in CLI (or with any library / software that can use OpenAI-compatible API backends):
 
 ```bash
 curl http://localhost:8080/v1/chat/completions -d '{
-"model": "gpt-3.5-turbo",
-"tools": [
-    {
-    "type":"function",
-    "function":{
-        "name":"python",
-        "description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
-        "parameters":{
-        "type":"object",
-        "properties":{
-            "code":{
-            "type":"string",
-            "description":"The code to run in the ipython interpreter."
+    "model": "gpt-3.5-turbo",
+    "tools": [
+        {
+        "type":"function",
+        "function":{
+            "name":"python",
+            "description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
+            "parameters":{
+            "type":"object",
+            "properties":{
+                "code":{
+                "type":"string",
+                "description":"The code to run in the ipython interpreter."
+                }
+            },
+            "required":["code"]
             }
-        },
-        "required":["code"]
         }
-    }
-    }
-],
-"messages": [
-    {
-    "role": "user",
-    "content": "Print a hello world message with python."
-    }
-]
+        }
+    ],
+    "messages": [
+        {
+        "role": "user",
+        "content": "Print a hello world message with python."
+        }
+    ]
+}'
+
+
+curl http://localhost:8080/v1/chat/completions -d '{
+    "model": "gpt-3.5-turbo",
+    "messages": [
+        {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
+        {"role": "user", "content": "What is the weather in Istanbul?"}
+    ],
+    "tools": [{
+        "type":"function",
+        "function":{
+            "name":"get_current_weather",
+            "description":"Get the current weather in a given location",
+            "parameters":{
+                "type":"object",
+                "properties":{
+                    "location":{
+                        "type":"string",
+                        "description":"The city and country/state, e.g. `San Francisco, CA`, or `Paris, France`"
+                    }
+                },
+                "required":["location"]
+            }
+        }
+    }]
 }'
 ```
 

+ 62 - 0
models/templates/Qwen-QwQ-32B.jinja

@@ -0,0 +1,62 @@
+{%- if tools %}
+    {{- '<|im_start|>system\n' }}
+    {%- if messages[0]['role'] == 'system' %}
+        {{- messages[0]['content'] }}
+    {%- else %}
+        {{- '' }}
+    {%- endif %}
+    {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
+    {%- for tool in tools %}
+        {{- "\n" }}
+        {{- tool | tojson }}
+    {%- endfor %}
+    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
+{%- else %}
+    {%- if messages[0]['role'] == 'system' %}
+        {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
+  {%- endif %}
+{%- endif %}
+{%- for message in messages %}
+    {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
+    {%- elif message.role == "assistant" and not message.tool_calls %}
+        {%- set content = message.content %}
+        {%- if not loop.last %}
+            {%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
+        {%- endif %}
+        {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
+    {%- elif message.role == "assistant" %}
+        {%- set content = message.content %}
+        {%- if not loop.last %}
+            {%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
+        {%- endif %}
+        {{- '<|im_start|>' + message.role }}
+        {%- if message.content %}
+            {{- '\n' + content }}
+        {%- endif %}
+        {%- for tool_call in message.tool_calls %}
+            {%- if tool_call.function is defined %}
+                {%- set tool_call = tool_call.function %}
+            {%- endif %}
+            {{- '\n<tool_call>\n{"name": "' }}
+            {{- tool_call.name }}
+            {{- '", "arguments": ' }}
+            {{- tool_call.arguments | tojson }}
+            {{- '}\n</tool_call>' }}
+        {%- endfor %}
+        {{- '<|im_end|>\n' }}
+    {%- elif message.role == "tool" %}
+        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
+            {{- '<|im_start|>user' }}
+        {%- endif %}
+        {{- '\n<tool_response>\n' }}
+        {{- message.content }}
+        {{- '\n</tool_response>' }}
+        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+            {{- '<|im_end|>\n' }}
+        {%- endif %}
+    {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+    {{- '<|im_start|>assistant\n<think>\n' }}
+{%- endif %}

+ 1 - 0
models/templates/README.md

@@ -19,4 +19,5 @@ These templates can be updated with the following commands:
 ./scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use > models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja
 ./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use   > models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja
 ./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct                      > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja
+./scripts/get_chat_template.py Qwen/QwQ-32B                                  > models/templates/Qwen-QwQ-32B.jinja
 ```

+ 11 - 0
scripts/tool_bench.py

@@ -12,6 +12,7 @@
         export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server
         export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp}
 
+        ./scripts/tool_bench.py run --n 10 --temp -1 --temp 0 --temp 1 --temp 2 --temp 5 --llama-baseline $PWD/buildMaster/bin/llama-server --output qwen14b.jsonl --hf bartowski/Qwen2.5-14B-Instruct-GGUF:Q4_K_L
         ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 1.5B Q4_K_M"      --output qwen1.5b.jsonl  --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF      --ollama qwen2.5:1.5b-instruct-q4_K_M
         ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 Coder 7B Q4_K_M"  --output qwenc7b.jsonl   --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF  --ollama qwen2.5-coder:7b
 
@@ -205,6 +206,7 @@ def run(
     model: Annotated[Optional[str], typer.Option(help="Name of the model to test (server agnostic)")] = None,
     hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None,
     chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None,
+    chat_template_file: Annotated[Optional[str], typer.Option(help="Chat template file override for llama-server")] = None,
     ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None,
     llama_baseline: Annotated[Optional[str], typer.Option(help="llama-server baseline binary path to use as baseline")] = None,
     n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10,
@@ -229,6 +231,12 @@ def run(
     # n_ctx = 8192
     n_ctx = 2048
 
+    if model is None:
+        if hf is not None:
+            model = hf.split("/")[-1]
+        elif ollama is not None:
+            model = ollama
+
     assert force or append or not output.exists(), f"Output file already exists: {output}; use --force to overwrite"
 
     with output.open('a' if append else 'w') as output_file:
@@ -320,6 +328,7 @@ def run(
                     server.model_hf_repo = hf
                     server.model_hf_file = None
                     server.chat_template = chat_template
+                    server.chat_template_file = chat_template_file
                     server.server_path = server_path
                     if port is not None:
                         server.server_port = port
@@ -335,6 +344,7 @@ def run(
                                 temp=t,
                                 output_kwargs=dict(
                                     chat_template=chat_template,
+                                    chat_template_file=chat_template_file,
                                 ),
                                 request_kwargs=dict(
                                     ignore_chat_grammar=ignore_chat_grammar,
@@ -355,6 +365,7 @@ def run(
                         temp=t,
                         output_kwargs=dict(
                             chat_template=None,
+                            chat_template_file=None,
                         ),
                         request_kwargs=dict(
                             model=ollama,

+ 12 - 2
src/llama-grammar.cpp

@@ -1177,8 +1177,18 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
             for (const auto & trigger_pattern : grammar.trigger_patterns) {
                 if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
                     grammar.awaiting_trigger = false;
-                    // get from the first match to the end of the string
-                    auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
+                    // get from the first matched capturing group to the end of the string
+                    size_t start = std::string::npos;
+                    for (auto i = 1u; i < match.size(); i++) {
+                        if (match.length(i) > 0) {
+                            start = match.position(i);
+                            break;
+                        }
+                    }
+                    if (start == std::string::npos) {
+                        start = match.position(0);
+                    }
+                    auto constrained_str = grammar.trigger_buffer.substr(start);
                     // std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
                     grammar.trigger_buffer.clear();
                     llama_grammar_accept_str(grammar, constrained_str);

+ 3 - 1
tests/CMakeLists.txt

@@ -142,8 +142,10 @@ if (NOT WIN32)
     # llama_build_and_test(test-double-float.cpp) # SLOW
 endif()
 
-llama_build_and_test(test-log.cpp)
+llama_build_and_test(test-chat-parser.cpp)
 llama_build_and_test(test-chat-template.cpp)
+llama_build_and_test(test-json-partial.cpp)
+llama_build_and_test(test-log.cpp)
 llama_build_and_test(test-regex-partial.cpp)
 
 # this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)

+ 355 - 0
tests/test-chat-parser.cpp

@@ -0,0 +1,355 @@
+//  Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
+//
+//  Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
+//  e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
+//
+//    cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
+//
+#include <exception>
+#include <iostream>
+#include <json.hpp>
+#include <string>
+
+#include "chat-parser.h"
+#include "common.h"
+#include "log.h"
+#include "regex-partial.h"
+
+using json = nlohmann::ordered_json;
+
+template <class T>
+static void assert_equals(const T & expected, const T & actual) {
+    if (expected != actual) {
+        std::cerr << "Expected: " << expected << std::endl;
+        std::cerr << "Actual: " << actual << std::endl;
+        std::cerr << std::flush;
+        throw std::runtime_error("Test failed");
+    }
+}
+static void assert_equals(const char * expected, const std::string & actual) {
+  return assert_equals<std::string>(expected, actual);
+}
+
+static void assert_throws(const std::function<void()> & fn, const std::string & expected_exception_pattern = "") {
+    try {
+        fn();
+    } catch (const std::exception & e) {
+      if (expected_exception_pattern.empty()) {
+          return;
+        }
+        std::regex expected_exception_regex(expected_exception_pattern);
+        std::string actual_message = e.what();
+        if (std::regex_search(actual_message, expected_exception_regex)) {
+            return;
+        }
+        throw std::runtime_error("Exception doesn't match expected pattern: " + actual_message + " (pattern: " + expected_exception_pattern + ")");
+        throw std::runtime_error("Exception of unexpected type: " + std::string(e.what()));
+    }
+    throw std::runtime_error("Exception was expected but not thrown");
+}
+
+static void test_reasoning() {
+  {
+    common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
+        /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+        /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
+        /* .reasoning_in_content = */ false,
+        /* .thinking_forced_open = */ false,
+    });
+    assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
+    assert_equals("<tnk>Cogito</tnk>Ergo sum", builder.consume_rest());
+  }
+  {
+    common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
+        /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+        /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+        /* .reasoning_in_content = */ false,
+        /* .thinking_forced_open = */ false,
+    });
+    assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
+    assert_equals(std::string("Cogito"), builder.result().reasoning_content);
+    assert_equals("Ergo sum", builder.consume_rest());
+  }
+  {
+    common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
+        /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+        /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
+        /* .reasoning_in_content = */ false,
+        /* .thinking_forced_open = */ false,
+    });
+    assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
+    assert_equals("Cogito</tnk>Ergo sum", builder.consume_rest());
+  }
+  {
+    common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
+        /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+        /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+        /* .reasoning_in_content = */ false,
+        /* .thinking_forced_open = */ true,
+    });
+    assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
+    assert_equals(std::string("Cogito"), builder.result().reasoning_content);
+    assert_equals("Ergo sum", builder.consume_rest());
+  }
+  {
+    common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
+        /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+        /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+        /* .reasoning_in_content = */ true,
+        /* .thinking_forced_open = */ true,
+    });
+    assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
+    assert_equals("<think>Cogito</think>", builder.result().content);
+    assert_equals("Ergo sum", builder.consume_rest());
+  }
+}
+
+static void test_regex() {
+  auto test_throws = [](const std::string & input, const std::string & regex, const std::string & expected_exception_pattern = "") {
+    common_chat_msg_parser builder(input, /* is_partial= */ false, {});
+    assert_throws([&]() { builder.consume_regex(common_regex(regex)); }, expected_exception_pattern);
+  };
+
+  test_throws("Hello, world!", "abc", "^abc$");
+  test_throws("Hello, world!", "e", "^e$");
+
+  {
+    common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
+    builder.consume_regex(common_regex("Hello"));
+    assert_equals(", world!", builder.consume_rest());
+  }
+
+  {
+    // When in non partial mode, we can say whether the regex was consumed or not.
+    common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
+    assert_equals(false, builder.try_consume_regex(common_regex("Hello, world!")).has_value());
+  }
+  {
+    common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
+    auto res = builder.try_consume_regex(common_regex("H(el)l(?:o, world!)?"));
+    assert_equals(true, res.has_value());
+    // Verify captures
+    assert_equals<size_t>(2, res->groups.size());
+    assert_equals("Hell", builder.str(res->groups[0]));
+    assert_equals("el", builder.str(res->groups[1]));
+    // Verify position is after the match
+    assert_equals<size_t>(4, builder.pos());
+    assert_equals("o,", builder.consume_rest());
+  }
+  {
+    // But in partial mode, we have a partial final match / can't decide, so we throw a partial exception.
+    common_chat_msg_parser builder("Hello,", /* is_partial= */ true, {});
+    assert_throws([&]() {
+      builder.try_consume_regex(common_regex("Hello, world!"));
+    }, "^Hello, world!$");
+  }
+
+  // Now regardless of the mode, we can tell these aren't a match.
+  for (const auto is_partial : {false, true}) {
+    common_chat_msg_parser builder("Hello,", is_partial, {});
+    assert_equals(false, builder.try_consume_regex(common_regex("a(b|c)(d|e)f")).has_value());
+  }
+  for (const auto is_partial : {false, true}) {
+    common_chat_msg_parser builder("Hello,", is_partial, {});
+    assert_equals(false, builder.try_consume_literal("Oh"));
+  }
+}
+
+const std::vector<std::string> barely_healable_jsons = {
+  "{",
+  "{\"",
+  "{\"\\",
+  "{\"n",
+  "{\"name\"",
+  "{\"name\":",
+  "{\"name\":\"",
+  "{\"name\":\"\\",
+  "{\"name\":\"python",
+  "{\"name\":\"python\\",
+  "{\",",
+  "{\":",
+  "{\"[",
+  "{\"]",
+  "{\"{",
+  "{\"}",
+  "{\"1",
+  "{\"name\":\",",
+  "{\"name\":\":",
+  "{\"name\":\"[",
+  "{\"name\":\"]",
+  "{\"name\":\"{",
+  "{\"name\":\"}",
+  "{\"name\":\"1",
+};
+
+static void test(const std::string & input, bool is_partial, const std::vector<std::vector<std::string>> & args_paths, const std::vector<std::vector<std::string>> & content_paths, const std::string & expected) {
+  common_chat_msg_parser builder(input, is_partial, {});
+  auto js = builder.try_consume_json_with_dumped_args(args_paths, content_paths);
+  assert_equals(true, js.has_value());
+  assert_equals(is_partial, js->is_partial);
+  assert_equals(expected, args_paths.size() == 1 && args_paths[0].empty() ? js->value.get<std::string>() : js->value.dump());
+}
+static void test_with_args(const std::string & input, const std::string & expected, bool parse_as_partial = true, bool is_partial = true) {
+  common_chat_msg_parser builder(input, parse_as_partial, {});
+  auto js = builder.try_consume_json_with_dumped_args({{"args"}}, {});
+  assert_equals(true, js.has_value());
+  assert_equals(is_partial, js->is_partial);
+  assert_equals(expected, js->value.dump());
+}
+
+static void test_json_with_dumped_args_no_args() {
+  // Normal JSON, nothing to heal, nothing to dump
+  test("{\"name\": \"python\"}", false, {}, {}, "{\"name\":\"python\"}");
+  // Full json is args
+  test("{\"name\": \"python\"}", false, {{}}, {}, "{\"name\":\"python\"}");
+
+  // If the arguments are further down, don't heal partial content.
+  for (const auto & src : barely_healable_jsons) {
+    test(src, true, {{"arguments"}}, {}, "{}");
+  }
+  // But heal content that isn't partial.
+  test("{\"name\": \"python\"", true, {{"arguments"}}, {}, "{\"name\":\"python\"}");
+}
+
+static void test_json_with_dumped_args() {
+
+  // Partial content.
+  test("{\"content\": \"t", true, {}, {{"content"}}, "{\"content\":\"t\"}");
+  test("{\"content\": \"", true, {}, {{"content"}}, "{\"content\":\"\"}");
+  test("{\"content\": ", true, {}, {{"content"}}, "{}");
+
+  // If the entire JSON is the arguments, healing it them dumping it produces the same output as the input (just reformatted).
+  test("{\"name\": \"python", true, {{}}, {}, "{\"name\":\"python");
+  for (const auto & src : barely_healable_jsons) {
+    test(src, true, {{}}, {}, src);
+  }
+
+  // Full JSON w/ args
+  for (auto parse_as_partial : {true, false}) {
+    test_with_args(
+      R"({"name": "python", "args": {"arg1": 1}})",
+      R"({"name":"python","args":"{\"arg1\":1}"})",
+      parse_as_partial,
+      /* is_partial= */ false
+    );
+  }
+
+  // Partial JSON w/ partial args
+  test_with_args(
+    R"({"foo": "bar", "args": {")",
+    R"({"foo":"bar","args":"{\""})"
+  );
+  // Partial args broken in object key
+  test_with_args(
+    R"({"foo": "bar", "args": {"ar)",
+    R"({"foo":"bar","args":"{\"ar"})"
+  );
+  // Partial args broken after object key
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1")",
+    R"({"foo":"bar","args":"{\"arg1\""})"
+  );
+  // Partial args broken before object value
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1":)",
+    R"({"foo":"bar","args":"{\"arg1\":"})"
+  );
+  // Partial args broken before object value (space)
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": )",
+    R"({"foo":"bar","args":"{\"arg1\":"})"
+  );
+  // Partial args broken in object value that may not be complete (int)
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": 1)",
+    R"({"foo":"bar","args":"{\"arg1\":"})"
+  );
+  // Partial args broken in object value that is complete (int)
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": 1 )",
+    R"({"foo":"bar","args":"{\"arg1\":1"})"
+  );
+  // Partial args broken in object value that is incomplete (string)
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": ")",
+    R"({"foo":"bar","args":"{\"arg1\":\""})"
+  );
+  // Partial args broken in object value that is complete (string)
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": "1")",
+    R"({"foo":"bar","args":"{\"arg1\":\"1\""})"
+  );
+  // Partial args broken on array opening
+  test_with_args(
+    R"({"foo": "bar", "args": [)",
+    R"({"foo":"bar","args":"["})"
+  );
+  // Partial args broken on array value that is incomplete (int)
+  test_with_args(
+    R"({"foo": "bar", "args": [1)",
+    R"({"foo":"bar","args":"["})"
+  );
+  // Partial args broken on array value that is complete (int)
+  test_with_args(
+    R"({"foo": "bar", "args": [1 )",
+    R"({"foo":"bar","args":"[1"})"
+  );
+  // Partial args broken on array value that is complete (string)
+  test_with_args(
+    R"({"foo": "bar", "args": ["1")",
+    R"({"foo":"bar","args":"[\"1\""})"
+  );
+  // Partial args broken after array value
+  test_with_args(
+    R"({"foo": "bar", "args": [1,)",
+    R"({"foo":"bar","args":"[1,"})"
+  );
+  // Partial args broken on nested array
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": [)",
+    R"({"foo":"bar","args":"{\"arg1\":["})"
+  );
+}
+
+static void test_positions() {
+  {
+    common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
+    assert_equals<size_t>(0, builder.pos());
+    assert_throws([&]() { builder.move_to(100); });
+    assert_equals<size_t>(0, builder.pos());
+    assert_throws([&]() { builder.move_back(1); });
+    assert_equals<size_t>(0, builder.pos());
+
+    builder.move_to(8);
+    assert_equals<size_t>(8, builder.pos());
+    builder.move_back(1);
+    assert_equals<size_t>(7, builder.pos());
+    assert_equals("world!", builder.consume_rest());
+
+    builder.move_to(0);
+    assert_equals<size_t>(0, builder.pos());
+
+    assert_throws([&]() { builder.finish(); });
+    assert_equals<size_t>(0, builder.pos());
+
+    builder.move_to(builder.input().size());
+    builder.finish();
+  }
+  {
+    common_chat_msg_parser builder("Hello, world!", /* is_partial= */ true, {});
+
+    builder.move_to(builder.input().size());
+    assert_equals<size_t>(builder.input().size(), builder.pos());
+    builder.finish();
+  }
+}
+
+int main() {
+    test_positions();
+    test_json_with_dumped_args_no_args();
+    test_json_with_dumped_args();
+    test_reasoning();
+    test_regex();
+    std::cout << "All tests passed!\n";
+    return 0;
+}

File diff ditekan karena terlalu besar
+ 597 - 264
tests/test-chat.cpp


+ 237 - 0
tests/test-json-partial.cpp

@@ -0,0 +1,237 @@
+#include "common.h"
+#include "json-partial.h"
+#include <exception>
+#include <iostream>
+#include <stdexcept>
+
+template <class T> static void assert_equals(const T & expected, const T & actual) {
+  if (expected != actual) {
+      std::cerr << "Expected: " << expected << std::endl;
+      std::cerr << "Actual: " << actual << std::endl;
+      std::cerr << std::flush;
+      throw std::runtime_error("Test failed");
+  }
+}
+
+static void test_json_healing() {
+  auto parse = [](const std::string & str) {
+      std::cerr << "# Parsing: " << str << '\n';
+      std::string::const_iterator it = str.begin();
+      const auto end = str.end();
+      common_json out;
+      std::string healing_marker = "$llama.cpp.json$";
+      if (common_json_parse(it, end, healing_marker, out)) {
+          auto dump = out.json.dump();
+          std::cerr << "Parsed: " << dump << '\n';
+          std::cerr << "Magic: " << out.healing_marker.json_dump_marker << '\n';
+          std::string result;
+          if (!out.healing_marker.json_dump_marker.empty()) {
+              auto i = dump.find(out.healing_marker.json_dump_marker);
+              if (i == std::string::npos) {
+                  throw std::runtime_error("Failed to find magic in dump " + dump + " (magic: " + out.healing_marker.json_dump_marker + ")");
+              }
+              result = dump.substr(0, i);
+          } else {
+            result = dump;
+          }
+          std::cerr << "Result: " << result << '\n';
+          if (string_starts_with(str, result)) {
+            std::cerr << "Failure!\n";
+          }
+        //   return dump;
+      } else {
+        throw std::runtime_error("Failed to parse: " + str);
+      }
+
+  };
+  auto parse_all = [&](const std::string & str) {
+      for (size_t i = 1; i < str.size(); i++) {
+          parse(str.substr(0, i));
+      }
+  };
+  parse_all("{\"a\": \"b\"}");
+  parse_all("{\"hey\": 1, \"ho\\\"ha\": [1]}");
+
+  parse_all("[{\"a\": \"b\"}]");
+
+  auto test = [&](const std::vector<std::string> & inputs, const std::string & expected, const std::string & expected_marker) {
+      for (const auto & input : inputs) {
+        common_json out;
+        assert_equals(true, common_json_parse(input, "$foo", out));
+        assert_equals<std::string>(expected, out.json.dump());
+        assert_equals<std::string>(expected_marker, out.healing_marker.json_dump_marker);
+      }
+  };
+  // No healing needed:
+  test(
+    {
+      R"([{"a":"b"}, "y"])",
+    },
+    R"([{"a":"b"},"y"])",
+    ""
+  );
+  // Partial literals can't be healed:
+  test(
+    {
+      R"([1)",
+      R"([tru)",
+      R"([n)",
+      R"([nul)",
+      R"([23.2)",
+    },
+    R"(["$foo"])",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"({"a": 1)",
+      R"({"a": tru)",
+      R"({"a": n)",
+      R"({"a": nul)",
+      R"({"a": 23.2)",
+    },
+    R"({"a":"$foo"})",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"({)",
+    },
+    R"({"$foo":1})",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"([)",
+    },
+    R"(["$foo"])",
+    R"("$foo)"
+  );
+  // Healing right after a full literal
+  test(
+    {
+      R"(1 )",
+    },
+    R"(1)",
+    ""
+  );
+  test(
+    {
+      R"(true)",
+      R"(true )",
+    },
+    R"(true)",
+    ""
+  );
+  test(
+    {
+      R"(null)",
+      R"(null )",
+    },
+    R"(null)",
+    ""
+  );
+  test(
+    {
+      R"([1 )",
+    },
+    R"([1,"$foo"])",
+    R"(,"$foo)"
+  );
+  test(
+    {
+      R"([{})",
+      R"([{} )",
+    },
+    R"([{},"$foo"])",
+    R"(,"$foo)"
+  );
+  test(
+    {
+      R"([true)",
+    },
+    // TODO: detect the true/false/null literal was complete
+    R"(["$foo"])",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"([true )",
+    },
+    R"([true,"$foo"])",
+    R"(,"$foo)"
+  );
+  test(
+    {
+      R"([true,)",
+    },
+    R"([true,"$foo"])",
+    R"("$foo)"
+  );
+  // Test nesting
+  test(
+    {
+      R"([{"a": [{"b": [{)",
+    },
+    R"([{"a":[{"b":[{"$foo":1}]}]}])",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"([{"a": [{"b": [)",
+    },
+    R"([{"a":[{"b":["$foo"]}]}])",
+    R"("$foo)"
+  );
+
+  test(
+    {
+      R"([{"a": "b"})",
+      R"([{"a": "b"} )",
+    },
+    R"([{"a":"b"},"$foo"])",
+    R"(,"$foo)"
+  );
+  test(
+    {
+      R"([{"a": "b"},)",
+      R"([{"a": "b"}, )",
+    },
+    R"([{"a":"b"},"$foo"])",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"({ "code)",
+    },
+    R"({"code$foo":1})",
+    R"($foo)"
+  );
+  test(
+    {
+      R"({ "code\)",
+    },
+    R"({"code\\$foo":1})",
+    R"(\$foo)"
+  );
+  test(
+    {
+      R"({ "code")",
+    },
+    R"({"code":"$foo"})",
+    R"(:"$foo)"
+  );
+  test(
+    {
+      R"({ "key")",
+    },
+    R"({"key":"$foo"})",
+    R"(:"$foo)"
+  );
+}
+
+int main() {
+    test_json_healing();
+    std::cerr << "All tests passed.\n";
+    return 0;
+}

+ 139 - 142
tools/server/server.cpp

@@ -1,3 +1,4 @@
+#include "chat.h"
 #include "utils.hpp"
 
 #include "arg.h"
@@ -114,11 +115,11 @@ struct slot_params {
     struct common_params_speculative speculative;
 
     // OAI-compat fields
-    bool                  verbose                   = false;
-    oaicompat_type        oaicompat                 = OAICOMPAT_TYPE_NONE;
-    std::string           oaicompat_model;
-    std::string           oaicompat_cmpl_id;
-    common_chat_format    oaicompat_chat_format     = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    bool                         verbose                   = false;
+    oaicompat_type               oaicompat                 = OAICOMPAT_TYPE_NONE;
+    std::string                  oaicompat_model;
+    std::string                  oaicompat_cmpl_id;
+    common_chat_syntax           oaicompat_chat_syntax;
 
     json to_json() const {
         std::vector<std::string> samplers;
@@ -176,7 +177,10 @@ struct slot_params {
             {"grammar_lazy",              sampling.grammar_lazy},
             {"grammar_triggers",          grammar_triggers},
             {"preserved_tokens",          sampling.preserved_tokens},
-            {"chat_format",               common_chat_format_name(oaicompat_chat_format)},
+            {"chat_format",               common_chat_format_name(oaicompat_chat_syntax.format)},
+            {"reasoning_format",          (oaicompat_chat_syntax.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "deepseek" : "none")},
+            {"reasoning_in_content",      oaicompat_chat_syntax.reasoning_in_content},
+            {"thinking_forced_open",      oaicompat_chat_syntax.thinking_forced_open},
             {"samplers",                  samplers},
             {"speculative.n_max",         speculative.n_max},
             {"speculative.n_min",         speculative.n_min},
@@ -352,11 +356,14 @@ struct server_task {
         {
             auto it = data.find("chat_format");
             if (it != data.end()) {
-                params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<int>());
-                SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
+                params.oaicompat_chat_syntax.format = static_cast<common_chat_format>(it->get<int>());
+                SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format).c_str());
             } else {
-                params.oaicompat_chat_format = defaults.oaicompat_chat_format;
+                params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
             }
+            params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format;
+            params.oaicompat_chat_syntax.reasoning_in_content = params.stream;
+            params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
         }
 
         {
@@ -396,7 +403,14 @@ struct server_task {
                             params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
                         }
                     } else {
-                        params.sampling.grammar_triggers.push_back(std::move(ct.value));
+                        if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
+                            SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
+                        } else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
+                            SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
+                        } else {
+                            throw std::runtime_error("Unknown grammar trigger type");
+                        }
+                        params.sampling.grammar_triggers.emplace_back(std::move(ct.value));
                     }
                 }
             }
@@ -639,11 +653,12 @@ struct server_task_result_cmpl_final : server_task_result {
     slot_params generation_params;
 
     // OAI-compat fields
-    bool                  verbose                  = false;
-    oaicompat_type        oaicompat                = OAICOMPAT_TYPE_NONE;
-    std::string           oaicompat_model;
-    std::string           oaicompat_cmpl_id;
-    common_chat_format    oaicompat_chat_format    = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    bool               verbose                  = false;
+    oaicompat_type     oaicompat                = OAICOMPAT_TYPE_NONE;
+    std::string        oaicompat_model;
+    std::string        oaicompat_cmpl_id;
+    common_chat_msg    oaicompat_msg;
+    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
 
     virtual int get_index() override {
         return index;
@@ -738,47 +753,20 @@ struct server_task_result_cmpl_final : server_task_result {
     json to_json_oaicompat_chat() {
         std::string finish_reason = "length";
         common_chat_msg msg;
-        if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
-            SRV_DBG("Parsing chat message: %s\n", content.c_str());
-            msg = common_chat_parse(content, oaicompat_chat_format);
-            finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
+        if (!oaicompat_msg.empty()) {
+            msg = oaicompat_msg;
         } else {
+            msg.role = "assistant";
             msg.content = content;
         }
-
-        json message {
-            {"role", "assistant"},
-        };
-        if (!msg.reasoning_content.empty()) {
-            message["reasoning_content"] = msg.reasoning_content;
-        }
-        if (msg.content.empty() && !msg.tool_calls.empty()) {
-            message["content"] = json();
-        } else {
-            message["content"] = msg.content;
-        }
-        if (!msg.tool_calls.empty()) {
-            auto tool_calls = json::array();
-            for (const auto & tc : msg.tool_calls) {
-                tool_calls.push_back({
-                    {"type", "function"},
-                    {"function", {
-                        {"name", tc.name},
-                        {"arguments", tc.arguments},
-                    }},
-                    // Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
-                    // We only generate a random id for the ones that don't generate one by themselves
-                    // (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
-                    {"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
-                });
-            }
-            message["tool_calls"] = tool_calls;
+        if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
+            finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
         }
 
         json choice {
             {"finish_reason", finish_reason},
             {"index", 0},
-            {"message", message},
+            {"message", msg.to_json_oaicompat<json>()},
         };
 
         if (!stream && probs_output.size() > 0) {
@@ -818,17 +806,35 @@ struct server_task_result_cmpl_final : server_task_result {
         std::time_t t = std::time(0);
         std::string finish_reason = "length";
         if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
-            finish_reason = "stop";
+            finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls";
+        }
+
+        json deltas = json::array();
+        for (const auto & diff : oaicompat_msg_diffs) {
+            deltas.push_back({
+                {"choices", json::array({
+                    json {
+                        {"finish_reason", nullptr},
+                        {"index", 0},
+                        {"delta", common_chat_msg_diff_to_json_oaicompat<json>(diff)},
+                    },
+                })},
+                {"created", t},
+                {"id", oaicompat_cmpl_id},
+                {"model", oaicompat_model},
+                {"system_fingerprint", build_info},
+                {"object", "chat.completion.chunk"},
+            });
         }
 
-        json choice = json {
-            {"finish_reason", finish_reason},
-            {"index", 0},
-            {"delta", json::object()}
-        };
-
-        json ret = json {
-            {"choices",            json::array({choice})},
+        deltas.push_back({
+            {"choices", json::array({
+                json {
+                    {"finish_reason", finish_reason},
+                    {"index", 0},
+                    {"delta", json::object()},
+                },
+            })},
             {"created",            t},
             {"id",                 oaicompat_cmpl_id},
             {"model",              oaicompat_model},
@@ -839,18 +845,18 @@ struct server_task_result_cmpl_final : server_task_result {
                 {"prompt_tokens",     n_prompt_tokens},
                 {"total_tokens",      n_decoded + n_prompt_tokens},
             }},
-        };
+        });
 
         if (timings.prompt_n >= 0) {
-            ret.push_back({"timings", timings.to_json()});
+            deltas.back().push_back({"timings", timings.to_json()});
         }
 
         // extra fields for debugging purposes
-        if (verbose) {
-            ret["__verbose"] = to_json_non_oaicompat();
+        if (verbose && !deltas.empty()) {
+            deltas.front()["__verbose"] = to_json_non_oaicompat();
         }
 
-        return ret;
+        return deltas;
     }
 };
 
@@ -868,10 +874,11 @@ struct server_task_result_cmpl_partial : server_task_result {
     result_timings timings;
 
     // OAI-compat fields
-    bool           verbose   = false;
-    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
-    std::string    oaicompat_model;
-    std::string    oaicompat_cmpl_id;
+    bool            verbose   = false;
+    oaicompat_type  oaicompat = OAICOMPAT_TYPE_NONE;
+    std::string     oaicompat_model;
+    std::string     oaicompat_cmpl_id;
+    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
 
     virtual int get_index() override {
         return index;
@@ -955,84 +962,50 @@ struct server_task_result_cmpl_partial : server_task_result {
         std::time_t t = std::time(0);
         json choices;
 
+        std::vector<json> deltas;
+        auto add_delta = [&](const json & delta) {
+            deltas.push_back({
+                {"choices", json::array({
+                    json {
+                        {"finish_reason", nullptr},
+                        {"index", 0},
+                        {"delta", delta},
+                    },
+                })},
+                {"created", t},
+                {"id", oaicompat_cmpl_id},
+                {"model", oaicompat_model},
+                {"system_fingerprint", build_info},
+                {"object", "chat.completion.chunk"},
+            });
+        };
+        // We have to send an initial update to conform to openai behavior
         if (first) {
-            if (content.empty()) {
-                choices = json::array({json{{"finish_reason", nullptr},
-                                            {"index", 0},
-                                            {"delta", json{{"role", "assistant"}}}}});
-            } else {
-                // We have to send this as two updates to conform to openai behavior
-                // initial_ret is the role message for stream=True
-                json initial_ret = json{{"choices", json::array({json{
-                                        {"finish_reason", nullptr},
-                                        {"index", 0},
-                                        {"delta", json{
-                                            {"role", "assistant"},
-                                            {"content", ""}
-                                        }}}})},
-                            {"created", t},
-                            {"id", oaicompat_cmpl_id},
-                            {"model", oaicompat_model},
-                            {"system_fingerprint", build_info},
-                            {"object", "chat.completion.chunk"}};
-
-                json second_ret = json{
-                            {"choices", json::array({json{{"finish_reason", nullptr},
-                                                            {"index", 0},
-                                                            {"delta", json {
-                                                            {"content", content}}}
-                                                            }})},
-                            {"created", t},
-                            {"id", oaicompat_cmpl_id},
-                            {"model", oaicompat_model},
-                            {"system_fingerprint", build_info},
-                            {"object", "chat.completion.chunk"}};
-
-                if (prob_output.probs.size() > 0) {
-                    second_ret["choices"][0]["logprobs"] = json{
-                        {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
-                    };
-                }
-
-                if (timings.prompt_n >= 0) {
-                    second_ret.push_back({"timings", timings.to_json()});
-                }
-
-                return std::vector<json>({initial_ret, second_ret});
-            }
-        } else {
-            choices = json::array({json{
-                {"finish_reason", nullptr},
-                {"index", 0},
-                {"delta",
-                json {
-                    {"content", content},
-                }},
-            }});
+            add_delta({
+                {"role", "assistant"},
+                {"content", nullptr},
+            });
         }
 
-        GGML_ASSERT(choices.size() >= 1);
-
-        if (prob_output.probs.size() > 0) {
-            choices[0]["logprobs"] = json{
-                {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
-            };
+        for (const auto & diff : oaicompat_msg_diffs) {
+            add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
         }
 
-        json ret = json {
-            {"choices",            choices},
-            {"created",            t},
-            {"id",                 oaicompat_cmpl_id},
-            {"model",              oaicompat_model},
-            {"system_fingerprint", build_info},
-            {"object",             "chat.completion.chunk"}
-        };
+        if (!deltas.empty()) {
+            GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1);
 
-        if (timings.prompt_n >= 0) {
-            ret.push_back({"timings", timings.to_json()});
+            if (prob_output.probs.size() > 0) {
+                deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json {
+                    {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
+                };
+            }
+
+            if (timings.prompt_n >= 0) {
+                deltas[deltas.size() - 1].push_back({"timings", timings.to_json()});
+            }
         }
 
-        return std::vector<json>({ret});
+        return deltas;
     }
 };
 
@@ -1293,6 +1266,7 @@ struct server_slot {
 
     std::string  generated_text;
     llama_tokens generated_tokens;
+    common_chat_msg chat_msg;
 
     server_tokens cache_tokens;
 
@@ -1313,6 +1287,7 @@ struct server_slot {
     llama_token sampled;
 
     common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    std::vector<std::string> generated_tool_call_ids;
 
     // stats
     size_t n_sent_text        = 0; // number of sent text character
@@ -1342,9 +1317,13 @@ struct server_slot {
         n_past             = 0;
         n_sent_text        = 0;
         task_type          = SERVER_TASK_TYPE_COMPLETION;
+        chat_format        = COMMON_CHAT_FORMAT_CONTENT_ONLY;
 
         generated_tokens.clear();
         generated_token_probs.clear();
+        chat_msg = {};
+        json_schema = json();
+        generated_tool_call_ids.clear();
 
         // clear speculative decoding stats
         n_draft_total = 0;
@@ -1424,6 +1403,21 @@ struct server_slot {
         return timings;
     }
 
+    const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
+        auto previous_msg = chat_msg;
+        SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
+        auto new_msg = common_chat_parse(
+            generated_text,
+            /* is_partial= */ stop != STOP_TYPE_EOS,
+            params.oaicompat_chat_syntax);
+        if (!new_msg.empty()) {
+            new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id);
+            chat_msg = new_msg;
+            diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
+        }
+        return chat_msg;
+    }
+
     size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
         size_t stop_pos = std::string::npos;
 
@@ -2475,10 +2469,12 @@ struct server_context {
         res->n_prompt_tokens     = slot.n_prompt_tokens;
         res->post_sampling_probs = slot.params.post_sampling_probs;
 
-        res->verbose           = slot.params.verbose;
-        res->oaicompat         = slot.params.oaicompat;
-        res->oaicompat_model   = slot.params.oaicompat_model;
-        res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
+        res->verbose               = slot.params.verbose;
+        res->oaicompat             = slot.params.oaicompat;
+        res->oaicompat_model       = slot.params.oaicompat_model;
+        res->oaicompat_cmpl_id     = slot.params.oaicompat_cmpl_id;
+
+        slot.update_chat_msg(res->oaicompat_msg_diffs);
 
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
@@ -2499,7 +2495,7 @@ struct server_context {
         res->id_slot         = slot.id;
 
         res->index           = slot.index;
-        res->content         = std::move(slot.generated_text);
+        res->content         = slot.generated_text;
         res->tokens          = std::move(slot.generated_tokens);
         res->timings         = slot.get_timings();
         res->prompt          = slot.prompt_tokens.detokenize(ctx, true);
@@ -2519,7 +2515,8 @@ struct server_context {
         res->oaicompat             = slot.params.oaicompat;
         res->oaicompat_model       = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id     = slot.params.oaicompat_cmpl_id;
-        res->oaicompat_chat_format = slot.params.oaicompat_chat_format;
+        res->oaicompat_msg         = slot.update_chat_msg(res->oaicompat_msg_diffs);
+
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
             if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {

+ 5 - 4
tools/server/tests/unit/test_chat_completion.py

@@ -75,7 +75,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
         choice = data["choices"][0]
         if i == 0:
             # Check first role message for stream=True
-            assert choice["delta"]["content"] == ""
+            assert choice["delta"]["content"] is None
             assert choice["delta"]["role"] == "assistant"
         else:
             assert "role" not in choice["delta"]
@@ -92,7 +92,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
             assert choice["finish_reason"] == finish_reason
         else:
             assert choice["finish_reason"] is None
-            content += choice["delta"]["content"]
+            content += choice["delta"]["content"] or ''
 
 
 def test_chat_completion_with_openai_library():
@@ -251,8 +251,9 @@ def test_chat_completion_with_timings_per_token():
     for i, data in enumerate(res):
         if i == 0:
             # Check first role message for stream=True
-            assert data["choices"][0]["delta"]["content"] == ""
+            assert data["choices"][0]["delta"]["content"] is None
             assert data["choices"][0]["delta"]["role"] == "assistant"
+            assert "timings" not in data, f'First event should not have timings: {data}'
         else:
             assert "role" not in data["choices"][0]["delta"]
             assert "timings" in data
@@ -311,7 +312,7 @@ def test_logprobs_stream():
         choice = data.choices[0]
         if i == 0:
             # Check first role message for stream=True
-            assert choice.delta.content == ""
+            assert choice.delta.content is None
             assert choice.delta.role == "assistant"
         else:
             assert choice.delta.role is None

+ 84 - 66
tools/server/tests/unit/test_tool_call.py

@@ -8,6 +8,7 @@ path = Path(__file__).resolve().parents[1]
 sys.path.insert(0, str(path))
 
 from utils import *
+from enum import Enum
 
 server: ServerProcess
 
@@ -20,7 +21,11 @@ def create_server():
     server = ServerPreset.tinyllama2()
     server.model_alias = "tinyllama-2-tool-call"
     server.server_port = 8081
+    server.n_slots = 1
 
+class CompletionMode(Enum):
+    NORMAL = "normal"
+    STREAMED = "streamed"
 
 TEST_TOOL = {
     "type":"function",
@@ -73,9 +78,8 @@ WEATHER_TOOL = {
   }
 }
 
-
 def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs):
-    res = server.make_request("POST", "/v1/chat/completions", data={
+    body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
             {"role": "system", "content": "You are a coding assistant."},
@@ -86,13 +90,13 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
         "parallel_tool_calls": False,
         **kwargs,
     })
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
+    # assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
+    choice = body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
     assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
     assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
-    assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
+    # assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
     expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
     assert expected_function_name == tool_call["function"]["name"]
     actual_arguments = tool_call["function"]["arguments"]
@@ -102,12 +106,16 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
         assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
 
 
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("template_name,tool,argument_key", [
     ("google-gemma-2-2b-it",                          TEST_TOOL,            "success"),
+    ("google-gemma-2-2b-it",                          TEST_TOOL,            "success"),
+    ("meta-llama-Llama-3.3-70B-Instruct",             TEST_TOOL,            "success"),
     ("meta-llama-Llama-3.3-70B-Instruct",             TEST_TOOL,            "success"),
     ("meta-llama-Llama-3.3-70B-Instruct",             PYTHON_TOOL,          "code"),
+    ("meta-llama-Llama-3.3-70B-Instruct",             PYTHON_TOOL,          "code"),
 ])
-def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
+def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
     global server
     n_predict = 1024
     # server = ServerPreset.stories15m_moe()
@@ -115,31 +123,43 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict,
     server.n_predict = n_predict
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, temperature=0.0, top_k=1, top_p=1.0)
+    do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0)
 
 
 @pytest.mark.slow
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("template_name,tool,argument_key", [
     ("meta-llama-Llama-3.1-8B-Instruct",              TEST_TOOL,            "success"),
     ("meta-llama-Llama-3.1-8B-Instruct",              PYTHON_TOOL,          "code"),
+
     ("meetkai-functionary-medium-v3.1",               TEST_TOOL,            "success"),
     ("meetkai-functionary-medium-v3.1",               PYTHON_TOOL,          "code"),
+
     ("meetkai-functionary-medium-v3.2",               TEST_TOOL,            "success"),
-    ("meetkai-functionary-medium-v3.2",               PYTHON_TOOL,          "code"),
+    # Functionary v3.2 format supports raw python content, which w/ a dummy stories model will never end on its own.
+    # ("meetkai-functionary-medium-v3.2",               PYTHON_TOOL,          "code"),
+
     ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL,            "success"),
     ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL,          "code"),
+
     ("meta-llama-Llama-3.2-3B-Instruct",              TEST_TOOL,            "success"),
     ("meta-llama-Llama-3.2-3B-Instruct",              PYTHON_TOOL,          "code"),
+
     ("mistralai-Mistral-Nemo-Instruct-2407",          TEST_TOOL,            "success"),
     ("mistralai-Mistral-Nemo-Instruct-2407",          PYTHON_TOOL,          "code"),
+
     ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use",   TEST_TOOL,            "success"),
     ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use",   PYTHON_TOOL,          "code"),
+
     ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B",      TEST_TOOL,            "success"),
     ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B",      PYTHON_TOOL,          "code"),
+
     ("fireworks-ai-llama-3-firefunction-v2",          TEST_TOOL,            "success"),
+    # ("fireworks-ai-llama-3-firefunction-v2",          PYTHON_TOOL,          "codeFalse), True),
     # ("fireworks-ai-llama-3-firefunction-v2",          PYTHON_TOOL,          "code"),
+
 ])
-def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None):
+def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
     global server
     n_predict = 512
     # server = ServerPreset.stories15m_moe()
@@ -147,10 +167,11 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
     server.n_predict = n_predict
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict)
+    do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED)
 
 
 @pytest.mark.slow
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [
     (TEST_TOOL,    "success",  "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
     (PYTHON_TOOL,  "code",     "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
@@ -184,9 +205,9 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
     (PYTHON_TOOL,  "code",     "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M",   ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
     (PYTHON_TOOL,  "code",     "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M",   "chatml"),
 
-    (TEST_TOOL,    "success",  "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
-    (PYTHON_TOOL,  "code",     "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
-    (PYTHON_TOOL,  "code",     "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
+    # (TEST_TOOL,    "success",  "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
+    # (PYTHON_TOOL,  "code",     "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
+    # (PYTHON_TOOL,  "code",     "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
 
     (TEST_TOOL,    "success",  "bartowski/functionary-small-v3.2-GGUF:Q4_K_M",       ("meetkai/functionary-medium-v3.2", None)),
     (PYTHON_TOOL,  "code",     "bartowski/functionary-small-v3.2-GGUF:Q4_K_M",       ("meetkai/functionary-medium-v3.2", None)),
@@ -203,10 +224,9 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
     (TEST_TOOL,    "success",  "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
     (PYTHON_TOOL,  "code",     "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
 ])
-def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
     global server
     n_predict = 512
-    server.n_slots = 1
     server.jinja = True
     server.n_ctx = 8192
     server.n_predict = n_predict
@@ -219,7 +239,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
     elif isinstance(template_override, str):
         server.chat_template = template_override
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    res = server.make_request("POST", "/v1/chat/completions", data={
+    body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
             {"role": "system", "content": "You are a coding assistant."},
@@ -228,12 +248,12 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
         "tool_choice": "required",
         "tools": [tool],
         "parallel_tool_calls": False,
+        "stream": stream == CompletionMode.STREAMED,
         "temperature": 0.0,
         "top_k": 1,
         "top_p": 1.0,
     }, timeout=TIMEOUT_HTTP_REQUEST)
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
+    choice = body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
     assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
@@ -248,7 +268,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
 
 
 def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs):
-    res = server.make_request("POST", "/v1/chat/completions", data={
+    body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
             {"role": "system", "content": "You are a coding assistant."},
@@ -258,26 +278,27 @@ def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int,
         "tool_choice": tool_choice,
         **kwargs,
     }, timeout=TIMEOUT_HTTP_REQUEST)
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
+    choice = body["choices"][0]
     assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
 
 
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
     ("meta-llama-Llama-3.3-70B-Instruct",         128, [],            None),
     ("meta-llama-Llama-3.3-70B-Instruct",         128, [TEST_TOOL],   None),
     ("meta-llama-Llama-3.3-70B-Instruct",         128, [PYTHON_TOOL], 'none'),
 ])
-def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
+def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
     global server
-    server.jinja = True
     server.n_predict = n_predict
+    server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
+    do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
 
 
 @pytest.mark.slow
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
     ("meetkai-functionary-medium-v3.2",               256, [],            None),
     ("meetkai-functionary-medium-v3.2",               256, [TEST_TOOL],   None),
@@ -289,16 +310,17 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
     ("meta-llama-Llama-3.2-3B-Instruct",              256, [TEST_TOOL],   None),
     ("meta-llama-Llama-3.2-3B-Instruct",              256, [PYTHON_TOOL], 'none'),
 ])
-def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
+def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
     global server
-    server.jinja = True
     server.n_predict = n_predict
+    server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
+    do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
 
 
 @pytest.mark.slow
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("hf_repo,template_override", [
     ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
     ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
@@ -321,11 +343,11 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
     ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M",      ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
     ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M",      "chatml"),
 
-    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
-    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
+    # ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
+    # ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
 
-    ("bartowski/functionary-small-v3.2-GGUF:Q8_0",       ("meetkai/functionary-medium-v3.2", None)),
-    ("bartowski/functionary-small-v3.2-GGUF:Q8_0",       "chatml"),
+    # ("bartowski/functionary-small-v3.2-GGUF:Q8_0",       ("meetkai/functionary-medium-v3.2", None)),
+    # ("bartowski/functionary-small-v3.2-GGUF:Q8_0",       "chatml"),
 
     ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M",      ("meta-llama/Llama-3.2-3B-Instruct", None)),
     ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M",      "chatml"),
@@ -339,10 +361,9 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
 
     # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
 ])
-def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
     global server
     n_predict = 512
-    server.n_slots = 1
     server.jinja = True
     server.n_ctx = 8192
     server.n_predict = n_predict
@@ -355,11 +376,11 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
     elif isinstance(template_override, str):
         server.chat_template = template_override
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    do_test_weather(server, max_tokens=n_predict)
+    do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
 
 
 def do_test_weather(server: ServerProcess, **kwargs):
-    res = server.make_request("POST", "/v1/chat/completions", data={
+    body = server.make_any_request("POST", "/v1/chat/completions", data={
         "messages": [
             {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
             {"role": "user", "content": "What is the weather in Istanbul?"},
@@ -367,14 +388,13 @@ def do_test_weather(server: ServerProcess, **kwargs):
         "tools": [WEATHER_TOOL],
         **kwargs,
     }, timeout=TIMEOUT_HTTP_REQUEST)
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
+    choice = body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
     assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
     # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
     assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"], f'Expected weather tool call, got {tool_call["function"]["name"]}'
-    assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
+    # assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
     actual_arguments = json.loads(tool_call["function"]["arguments"])
     assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
     location = actual_arguments["location"]
@@ -383,6 +403,7 @@ def do_test_weather(server: ServerProcess, **kwargs):
 
 
 @pytest.mark.slow
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [
     (None,                                           128,  "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       "chatml"),
     (None,                                           128,  "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
@@ -400,9 +421,8 @@ def do_test_weather(server: ServerProcess, **kwargs):
     # (None,                                           128,  "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M",  None),
     # ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)",            8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
 ])
-def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
     global server
-    server.n_slots = 1
     server.jinja = True
     server.n_ctx = 8192 * 2
     server.n_predict = n_predict
@@ -415,11 +435,11 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
     elif isinstance(template_override, str):
         server.chat_template = template_override
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    do_test_calc_result(server, result_override, n_predict)
+    do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED)
 
 
 def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs):
-    res = server.make_request("POST", "/v1/chat/completions", data={
+    body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
             {"role": "system", "content": "You are a tools-calling assistant. You express numerical values with at most two decimals."},
@@ -466,8 +486,7 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr
         ],
         **kwargs,
     }, timeout=TIMEOUT_HTTP_REQUEST)
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
+    choice = body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
     assert tool_calls is None, f'Expected no tool call in {choice["message"]}'
     content = choice["message"].get("content")
@@ -480,18 +499,18 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr
 
 
 @pytest.mark.slow
-@pytest.mark.parametrize("n_predict,reasoning_format,expect_content,expect_reasoning_content,hf_repo,template_override", [
-    (128, 'deepseek',  "^The sum of 102 and 7 is 109[\\s\\S]*",                        None,                                          "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       None),
-    (128,  None,        "^The sum of 102 and 7 is 109[\\s\\S]*",                       None,                                          "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       None),
-
-    (1024, 'deepseek',  "To find the sum of[\\s\\S]*",                                 "I need to calculate the sum of 102 and 7[\\s\\S]*",  "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
-    (1024, 'none',      "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*",                None,                                          "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
-
-    (1024, 'deepseek',  "To find the sum of[\\s\\S]*",                                 "First, I [\\s\\S]*",                          "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+@pytest.mark.parametrize("n_predict,reasoning_format,stream,expect_reasoning_content,expect_content,hf_repo,template_override", [
+    (128, 'deepseek',   CompletionMode.NORMAL,   None, "^The sum of 102 and 7 is 109[\\s\\S]*",                                       "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       None),
+    (128,  None,        CompletionMode.NORMAL,   None, "^The sum of 102 and 7 is 109[\\s\\S]*",                                       "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M",       None),
+    (1024, 'deepseek',  CompletionMode.NORMAL,   "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*",  "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    (1024, 'deepseek',  CompletionMode.STREAMED, None, "^<think>I need to calculate [\\s\\S]*?</think>To find the sum of [\\s\\S]*",  "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    (1024, 'deepseek',  CompletionMode.NORMAL,   "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*",                                 "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+    (1024, 'deepseek',  CompletionMode.STREAMED, None, "^<think>First, I [\\s\\S]*?</think>To find the sum of[\\s\\S]*",              "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+    # (1024, 'none',      CompletionMode.NORMAL,   None, "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*",                 "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    # (128,  'deepseek',  None, "^Okay, let me figure out the sum of 102 and 7[\\s\\S]*",                      "bartowski/Qwen_QwQ-32B-GGUF:Q4_K_M",                None),
 ])
-def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
     global server
-    server.n_slots = 1
     server.reasoning_format = reasoning_format
     server.jinja = True
     server.n_ctx = 8192 * 2
@@ -505,14 +524,14 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
     elif isinstance(template_override, str):
         server.chat_template = template_override
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    res = server.make_request("POST", "/v1/chat/completions", data={
+    body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
             {"role": "user", "content": "What's the sum of 102 and 7?"},
-        ]
+        ],
+        "stream": stream == CompletionMode.STREAMED,
     }, timeout=TIMEOUT_HTTP_REQUEST)
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
+    choice = body["choices"][0]
     assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
 
     content = choice["message"].get("content")
@@ -529,6 +548,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
 
 
 @pytest.mark.slow
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("hf_repo,template_override", [
     ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
 
@@ -562,10 +582,9 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
     ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M",              None),
     ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M",              "chatml"),
 ])
-def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
     global server
     n_predict = 512 # High because of DeepSeek R1
-    server.n_slots = 1
     server.jinja = True
     server.n_ctx = 8192
     server.n_predict = n_predict
@@ -579,11 +598,11 @@ def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | Non
         server.chat_template = template_override
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
 
-    do_test_hello_world(server, max_tokens=n_predict)
+    do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
 
 
 def do_test_hello_world(server: ServerProcess, **kwargs):
-    res = server.make_request("POST", "/v1/chat/completions", data={
+    body = server.make_any_request("POST", "/v1/chat/completions", data={
         "messages": [
             {"role": "system", "content": "You are a tool-calling agent."},
             {"role": "user", "content": "say hello world with python"},
@@ -591,16 +610,15 @@ def do_test_hello_world(server: ServerProcess, **kwargs):
         "tools": [PYTHON_TOOL],
         **kwargs,
     }, timeout=TIMEOUT_HTTP_REQUEST)
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
+    choice = body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
     assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
     # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
     assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
-    assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
+    # assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
     actual_arguments = json.loads(tool_call["function"]["arguments"])
     assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
     code = actual_arguments["code"]
     assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
-    assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
+    assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', re.sub(r'#.*\n?', '', code)), f'Expected hello world, got {code}'

+ 71 - 0
tools/server/tests/utils.py

@@ -294,6 +294,77 @@ class ServerProcess:
                 print("Partial response from server", json.dumps(data, indent=2))
                 yield data
 
+    def make_any_request(
+        self,
+        method: str,
+        path: str,
+        data: dict | None = None,
+        headers: dict | None = None,
+        timeout: float | None = None,
+    ) -> dict:
+        stream = data.get('stream', False)
+        if stream:
+            content: list[str] = []
+            tool_calls: list[dict] = []
+            finish_reason: Optional[str] = None
+
+            content_parts = 0
+            tool_call_parts = 0
+            arguments_parts = 0
+
+            for chunk in self.make_stream_request(method, path, data, headers):
+                assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
+                choice = chunk['choices'][0]
+                if choice['delta'].get('content') is not None:
+                    assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
+                    content.append(choice['delta']['content'])
+                    content_parts += 1
+                if choice['delta'].get('finish_reason') is not None:
+                    finish_reason = choice['delta']['finish_reason']
+                for tc in choice['delta'].get('tool_calls', []):
+                    if 'function' not in tc:
+                        raise ValueError(f"Expected function type, got {tc['type']}")
+                    if tc['index'] >= len(tool_calls):
+                        tool_calls.append(dict(
+                            id="",
+                            type="function",
+                            function=dict(
+                                name="",
+                                arguments="",
+                            )
+                        ))
+                    tool_call = tool_calls[tc['index']]
+                    if tc.get('id') is not None:
+                        tool_call['id'] = tc['id']
+                    fct = tc['function']
+                    if fct.get('name') is not None:
+                        tool_call['function']['name'] = fct['name']
+                    if fct.get('arguments') is not None:
+                        assert len(fct['arguments']) > 0, f'Expected non empty arguments delta!'
+                        tool_call['function']['arguments'] += fct['arguments']
+
+            print(f'Streamed response had {content_parts} content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
+            result = dict(
+                choices=[
+                    dict(
+                        index=0,
+                        finish_reason=finish_reason,
+                        message=dict(
+                            role='assistant',
+                            content=''.join(content) if content else None,
+                            tool_calls=tool_calls if tool_calls else None,
+                        ),
+                    )
+                ],
+            )
+            print("Final response from server", json.dumps(result, indent=2))
+            return result
+        else:
+            response = self.make_request(method, path, data, headers, timeout=timeout)
+            assert response.status_code == 200, f"Server returned error: {response.status_code}"
+            return response.body
+
+
 
 server_instances: Set[ServerProcess] = set()
 

+ 14 - 34
tools/server/utils.hpp

@@ -474,26 +474,6 @@ static std::string gen_tool_call_id() {
 // other common utils
 //
 
-static bool ends_with(const std::string & str, const std::string & suffix) {
-    return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
-}
-
-static size_t find_partial_stop_string(const std::string &stop, const std::string &text) {
-    if (!text.empty() && !stop.empty()) {
-        const char text_last_char = text.back();
-        for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
-            if (stop[char_index] == text_last_char) {
-                const std::string current_partial = stop.substr(0, char_index + 1);
-                if (ends_with(text, current_partial)) {
-                    return text.size() - char_index - 1;
-                }
-            }
-        }
-    }
-
-    return std::string::npos;
-}
-
 // TODO: reuse llama_detokenize
 template <class Iter>
 static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
@@ -599,19 +579,16 @@ static json oaicompat_chat_params_parse(
     json llama_params;
 
     auto tools = json_value(body, "tools", json());
+    auto has_tools = tools.is_array() && !tools.empty();
     auto stream = json_value(body, "stream", false);
+    auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
 
-    if (tools.is_array() && !tools.empty()) {
-        if (stream) {
-            throw std::runtime_error("Cannot use tools with stream");
-        }
-        if (!opt.use_jinja) {
+    if (!opt.use_jinja) {
+        if (has_tools) {
             throw std::runtime_error("tools param requires --jinja flag");
         }
-    }
-    if (!opt.use_jinja) {
-        if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) {
-            throw std::runtime_error("Unsupported param: tool_choice");
+        if (tool_choice != "auto") {
+            throw std::runtime_error("tool_choice param requires --jinja flag");
         }
     }
 
@@ -749,14 +726,12 @@ static json oaicompat_chat_params_parse(
     common_chat_templates_inputs inputs;
     inputs.messages              = common_chat_msgs_parse_oaicompat(messages);
     inputs.tools                 = common_chat_tools_parse_oaicompat(tools);
-    inputs.tool_choice           = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
+    inputs.tool_choice           = common_chat_tool_choice_parse_oaicompat(tool_choice);
     inputs.json_schema           = json_schema.is_null() ? "" : json_schema.dump();
     inputs.grammar               = grammar;
-    inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
     inputs.use_jinja             = opt.use_jinja;
     inputs.parallel_tool_calls   = json_value(body, "parallel_tool_calls", false);
-    inputs.extract_reasoning     = opt.reasoning_format != COMMON_REASONING_FORMAT_NONE;
-    inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
+    inputs.reasoning_format      = opt.reasoning_format;
     if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
         throw std::runtime_error("Cannot use custom grammar constraints with tools.");
     }
@@ -774,7 +749,8 @@ static json oaicompat_chat_params_parse(
             throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list.");
         }
 
-        inputs.extract_reasoning = false;
+        /* TODO: test this properly */
+        inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE;
         inputs.add_generation_prompt = true;
     }
 
@@ -799,6 +775,7 @@ static json oaicompat_chat_params_parse(
     }
     llama_params["grammar_triggers"] = grammar_triggers;
     llama_params["preserved_tokens"] = chat_params.preserved_tokens;
+    llama_params["thinking_forced_open"]     = chat_params.thinking_forced_open;
     for (const auto & stop : chat_params.additional_stops) {
         llama_params["stop"].push_back(stop);
     }
@@ -812,6 +789,9 @@ static json oaicompat_chat_params_parse(
     // Handle "logprobs" field
     // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
     if (json_value(body, "logprobs", false)) {
+        if (has_tools && stream) {
+            throw std::runtime_error("logprobs is not supported with tools + stream");
+        }
         llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
     } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
         throw std::runtime_error("top_logprobs requires logprobs to be set to true");

Beberapa file tidak ditampilkan karena terlalu banyak file yang berubah dalam diff ini