|
@@ -22,9 +22,13 @@ static common_chat_msg msg_from_json(const json & message) {
|
|
|
"assistant",
|
|
"assistant",
|
|
|
"",
|
|
"",
|
|
|
{},
|
|
{},
|
|
|
|
|
+ /* .tool_plan = */ "",
|
|
|
};
|
|
};
|
|
|
if (message.contains("content") && !message.at("content").is_null()) {
|
|
if (message.contains("content") && !message.at("content").is_null()) {
|
|
|
- ret.content = message.at("content").get<std::string>();
|
|
|
|
|
|
|
+ ret.content = message.at("content");
|
|
|
|
|
+ }
|
|
|
|
|
+ if (message.contains("tool_plan")) {
|
|
|
|
|
+ ret.tool_plan = message.at("tool_plan");
|
|
|
}
|
|
}
|
|
|
auto has_tool_calls = message.contains("tool_calls");
|
|
auto has_tool_calls = message.contains("tool_calls");
|
|
|
if (has_tool_calls) {
|
|
if (has_tool_calls) {
|
|
@@ -171,8 +175,7 @@ const json llama_3_1_tools = { special_function_tool, code_interpreter_too
|
|
|
|
|
|
|
|
struct delta_data {
|
|
struct delta_data {
|
|
|
std::string delta;
|
|
std::string delta;
|
|
|
- std::string grammar;
|
|
|
|
|
- common_chat_format format;
|
|
|
|
|
|
|
+ common_chat_params params;
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
|
|
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
|
|
@@ -214,7 +217,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
|
|
|
break;
|
|
break;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
- return { delta, params_full.grammar, params_full.format };
|
|
|
|
|
|
|
+ return { delta, params_full };
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
/*
|
|
/*
|
|
@@ -224,7 +227,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
|
|
|
*/
|
|
*/
|
|
|
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
|
|
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
|
|
|
const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
|
|
const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
|
|
|
- bool skip_grammar_test = false, bool skip_parser_test = false) {
|
|
|
|
|
|
|
+ bool expect_grammar_triggered = true) {
|
|
|
common_chat_msg expected_msg = msg_from_json(test_message);
|
|
common_chat_msg expected_msg = msg_from_json(test_message);
|
|
|
|
|
|
|
|
auto user_message = json{
|
|
auto user_message = json{
|
|
@@ -238,45 +241,110 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
|
|
assert_equals(expected_delta, data.delta);
|
|
assert_equals(expected_delta, data.delta);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- if (!skip_parser_test) {
|
|
|
|
|
- const auto msg = common_chat_parse(data.delta, data.format);
|
|
|
|
|
|
|
+ if (expect_grammar_triggered) {
|
|
|
|
|
+ const auto msg = common_chat_parse(data.delta, data.params.format);
|
|
|
assert_msg_equals(expected_msg, msg);
|
|
assert_msg_equals(expected_msg, msg);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if (!expected_msg.tool_calls.empty()) {
|
|
if (!expected_msg.tool_calls.empty()) {
|
|
|
- GGML_ASSERT(!data.grammar.empty());
|
|
|
|
|
|
|
+ GGML_ASSERT(!data.params.grammar.empty());
|
|
|
}
|
|
}
|
|
|
- if (!data.grammar.empty()) {
|
|
|
|
|
- auto grammar = build_grammar(data.grammar);
|
|
|
|
|
|
|
+ if (!data.params.grammar.empty()) {
|
|
|
|
|
+ auto grammar = build_grammar(data.params.grammar);
|
|
|
if (!grammar) {
|
|
if (!grammar) {
|
|
|
throw std::runtime_error("Failed to build grammar");
|
|
throw std::runtime_error("Failed to build grammar");
|
|
|
}
|
|
}
|
|
|
- // TODO: exercice lazy grammars + triggers here, instead of skipping the test
|
|
|
|
|
- if (!skip_grammar_test) {
|
|
|
|
|
- if (!match_string(data.delta, grammar.get())) {
|
|
|
|
|
- throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
|
|
|
|
|
- "\n\nGrammar: " + data.grammar);
|
|
|
|
|
|
|
+ auto earliest_trigger_pos = std::string::npos;
|
|
|
|
|
+ auto constrained = data.delta;
|
|
|
|
|
+ for (const auto & trigger : data.params.grammar_triggers) {
|
|
|
|
|
+ auto pos = constrained.find(trigger.word);
|
|
|
|
|
+ if (pos == std::string::npos) {
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (pos > 0 && trigger.at_start) {
|
|
|
|
|
+ fprintf(stderr, "Trigger %s not at start of message, skipping:\n\n%s\n\n", trigger.word.c_str(), constrained.c_str());
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
|
|
|
|
|
+ earliest_trigger_pos = pos;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+ auto grammar_triggered = false;
|
|
|
|
|
+ if (earliest_trigger_pos != std::string::npos) {
|
|
|
|
|
+ constrained = constrained.substr(earliest_trigger_pos);
|
|
|
|
|
+ grammar_triggered = true;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (data.params.grammar_lazy) {
|
|
|
|
|
+ assert_equals(expect_grammar_triggered, grammar_triggered);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (grammar_triggered && !match_string(constrained, grammar.get())) {
|
|
|
|
|
+ throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
|
|
|
|
|
+ "\n\nGrammar: " + data.params.grammar);
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
static void test_template_output_parsers() {
|
|
static void test_template_output_parsers() {
|
|
|
- auto text_message = json{
|
|
|
|
|
|
|
+ json text_message {
|
|
|
{ "role", "assistant" },
|
|
{ "role", "assistant" },
|
|
|
{ "content", "Hello, world!" },
|
|
{ "content", "Hello, world!" },
|
|
|
};
|
|
};
|
|
|
- auto tool_call_message = json{
|
|
|
|
|
|
|
+ json tool_calls = json::array({{
|
|
|
|
|
+ { "type", "function" },
|
|
|
|
|
+ { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
|
|
|
|
|
+ }});
|
|
|
|
|
+
|
|
|
|
|
+ json tool_call_message {
|
|
|
|
|
+ { "role", "assistant"},
|
|
|
|
|
+ { "content", {}},
|
|
|
|
|
+ { "tool_calls", {
|
|
|
|
|
+ {
|
|
|
|
|
+ { "type", "function" },
|
|
|
|
|
+ { "function", {
|
|
|
|
|
+ { "name", "special_function" },
|
|
|
|
|
+ { "arguments", "{\"arg1\": 1}" },
|
|
|
|
|
+ }},
|
|
|
|
|
+ },
|
|
|
|
|
+ }},
|
|
|
|
|
+ };
|
|
|
|
|
+ json tool_call_message_with_id {
|
|
|
|
|
+ { "role", "assistant"},
|
|
|
|
|
+ { "content", {}},
|
|
|
|
|
+ { "tool_calls", {
|
|
|
|
|
+ {
|
|
|
|
|
+ { "type", "function" },
|
|
|
|
|
+ { "function", {
|
|
|
|
|
+ { "name", "special_function" },
|
|
|
|
|
+ { "arguments", "{\"arg1\": 1}" },
|
|
|
|
|
+ }},
|
|
|
|
|
+ {"id", "123456789"},
|
|
|
|
|
+ },
|
|
|
|
|
+ }},
|
|
|
{ "role", "assistant" },
|
|
{ "role", "assistant" },
|
|
|
{ "content", {} },
|
|
{ "content", {} },
|
|
|
- { "tool_calls", json{ {
|
|
|
|
|
- { "type", "function" },
|
|
|
|
|
- { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
|
|
|
|
|
- } } }
|
|
|
|
|
|
|
+ { "tool_calls", tool_calls }
|
|
|
|
|
+ };
|
|
|
|
|
+ json tool_call_plan_message_with_idx {
|
|
|
|
|
+ { "role", "assistant"},
|
|
|
|
|
+ { "content", {}},
|
|
|
|
|
+ { "tool_plan", "I'm not so sure"},
|
|
|
|
|
+ { "tool_calls", {
|
|
|
|
|
+ {
|
|
|
|
|
+ { "type", "function" },
|
|
|
|
|
+ { "function", {
|
|
|
|
|
+ { "name", "special_function" },
|
|
|
|
|
+ { "arguments", "{\"arg1\": 1}" },
|
|
|
|
|
+ }},
|
|
|
|
|
+ // Index of the tool call in the tool_calls array
|
|
|
|
|
+ {"id", "0"},
|
|
|
|
|
+ },
|
|
|
|
|
+ }},
|
|
|
|
|
+ { "role", "assistant" },
|
|
|
|
|
+ { "content", {} },
|
|
|
|
|
+ { "tool_calls", tool_calls }
|
|
|
};
|
|
};
|
|
|
- auto tool_call_message_with_id = json::parse(tool_call_message.dump());
|
|
|
|
|
- tool_call_message_with_id["tool_calls"][0]["id"] = "123456789";
|
|
|
|
|
|
|
|
|
|
auto python_tool_call_message = json{
|
|
auto python_tool_call_message = json{
|
|
|
{ "role", "assistant" },
|
|
{ "role", "assistant" },
|
|
@@ -322,6 +390,27 @@ static void test_template_output_parsers() {
|
|
|
inputs_tools_builtin.tools = json::array();
|
|
inputs_tools_builtin.tools = json::array();
|
|
|
inputs_tools_builtin.tools.push_back(python_tool);
|
|
inputs_tools_builtin.tools.push_back(python_tool);
|
|
|
|
|
|
|
|
|
|
+ {
|
|
|
|
|
+ // Not supported yet
|
|
|
|
|
+ const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "<s>", "</s>");
|
|
|
|
|
+ assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
|
|
|
|
|
+ }
|
|
|
|
|
+ {
|
|
|
|
|
+ const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "<s>", "</s>");
|
|
|
|
|
+ std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" };
|
|
|
|
|
+
|
|
|
|
|
+ assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
|
|
|
|
|
+ assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
|
|
|
|
|
+
|
|
|
|
|
+ test_template(tmpl, end_tokens, tool_call_plan_message_with_idx, tools,
|
|
|
|
|
+ "<|START_THINKING|>I'm not so sure<|END_THINKING|>"
|
|
|
|
|
+ "<|START_ACTION|>[\n"
|
|
|
|
|
+ " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
|
|
|
|
|
+ "]<|END_ACTION|>");
|
|
|
|
|
+ test_template(tmpl, end_tokens, text_message, tools,
|
|
|
|
|
+ "<|START_RESPONSE|>Hello, world!<|END_RESPONSE|>",
|
|
|
|
|
+ /* expect_grammar_triggered= */ false);
|
|
|
|
|
+ }
|
|
|
{
|
|
{
|
|
|
const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
|
|
const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
|
|
|
std::vector<std::string> end_tokens{ "<end_of_turn>" };
|
|
std::vector<std::string> end_tokens{ "<end_of_turn>" };
|
|
@@ -362,11 +451,10 @@ static void test_template_output_parsers() {
|
|
|
|
|
|
|
|
assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
|
|
assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
|
|
|
|
|
|
|
|
- test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
|
|
|
|
|
|
+ test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
|
|
|
test_template(
|
|
test_template(
|
|
|
tmpl, end_tokens, tool_call_message_with_id, tools,
|
|
tmpl, end_tokens, tool_call_message_with_id, tools,
|
|
|
- "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
|
|
|
|
|
- /* skip_grammar_test= */ true);
|
|
|
|
|
|
|
+ "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
|
|
|
}
|
|
}
|
|
|
{
|
|
{
|
|
|
const common_chat_template tmpl(
|
|
const common_chat_template tmpl(
|
|
@@ -388,7 +476,7 @@ static void test_template_output_parsers() {
|
|
|
inputs_tools)
|
|
inputs_tools)
|
|
|
.format);
|
|
.format);
|
|
|
|
|
|
|
|
- test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
|
|
|
|
|
|
+ test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
|
|
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
|
|
"<tool_call>\n"
|
|
"<tool_call>\n"
|
|
|
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
|
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
|
@@ -413,7 +501,7 @@ static void test_template_output_parsers() {
|
|
|
inputs_tools_builtin)
|
|
inputs_tools_builtin)
|
|
|
.format);
|
|
.format);
|
|
|
|
|
|
|
|
- // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
|
|
|
|
|
|
|
+ // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* expect_grammar_triggered= */ false);
|
|
|
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
|
|
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
|
|
|
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
|
|
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
|
|
|
test_template(tmpl, end_tokens, python_tool_call_message, tools,
|
|
test_template(tmpl, end_tokens, python_tool_call_message, tools,
|
|
@@ -428,7 +516,7 @@ static void test_template_output_parsers() {
|
|
|
|
|
|
|
|
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
|
|
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
|
|
|
|
|
|
|
|
- test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
|
|
|
|
|
|
+ test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
|
|
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
|
|
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
|
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
|
|
}
|
|
}
|
|
@@ -440,7 +528,7 @@ static void test_template_output_parsers() {
|
|
|
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
|
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
|
|
common_chat_params_init(tmpl, inputs_tools).format);
|
|
common_chat_params_init(tmpl, inputs_tools).format);
|
|
|
|
|
|
|
|
- test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
|
|
|
|
|
|
+ test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
|
|
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
|
|
"<function=special_function>{\"arg1\": 1}</function>");
|
|
"<function=special_function>{\"arg1\": 1}</function>");
|
|
|
}
|
|
}
|
|
@@ -455,7 +543,7 @@ static void test_template_output_parsers() {
|
|
|
test_template(tmpl, end_tokens, text_message, {},
|
|
test_template(tmpl, end_tokens, text_message, {},
|
|
|
"all\n"
|
|
"all\n"
|
|
|
"Hello, world!",
|
|
"Hello, world!",
|
|
|
- /* skip_grammar_test= */ true);
|
|
|
|
|
|
|
+ /* expect_grammar_triggered= */ false);
|
|
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
|
|
"special_function\n"
|
|
"special_function\n"
|
|
|
"{\"arg1\": 1}");
|
|
"{\"arg1\": 1}");
|
|
@@ -467,7 +555,7 @@ static void test_template_output_parsers() {
|
|
|
|
|
|
|
|
assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
|
|
assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
|
|
|
|
|
|
|
|
- test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
|
|
|
|
|
|
+ test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
|
|
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
|
|
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
|
|
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
|
|
|
}
|
|
}
|
|
@@ -478,7 +566,7 @@ static void test_template_output_parsers() {
|
|
|
|
|
|
|
|
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
|
|
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
|
|
|
|
|
|
|
|
- test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
|
|
|
|
|
|
+ test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
|
|
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
|
|
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
|
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
|
|
"```json\n"
|
|
"```json\n"
|