// test-chat.cpp
  1. // Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
  2. //
  3. // Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
  4. // e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
  5. //
  6. // cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
  7. //
  8. #include <fstream>
  9. #include <iostream>
  10. #include <json.hpp>
  11. #include <string>
  12. #include "chat-template.hpp"
  13. #include "chat.hpp"
  14. #include "llama-grammar.h"
  15. #include "unicode.h"
  16. using json = nlohmann::ordered_json;
  17. static common_chat_msg msg_from_json(const json & message) {
  18. common_chat_msg ret;
  19. ret.role = "assistant";
  20. if (message.contains("content") && !message.at("content").is_null()) {
  21. ret.content = message.at("content");
  22. }
  23. if (message.contains("tool_plan")) {
  24. ret.tool_plan = message.at("tool_plan");
  25. }
  26. auto has_tool_calls = message.contains("tool_calls");
  27. if (has_tool_calls) {
  28. for (const auto & tc : message.at("tool_calls")) {
  29. const auto & arguments = tc.at("function").at("arguments");
  30. ret.tool_calls.push_back({
  31. tc.at("function").at("name").get<std::string>(),
  32. arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
  33. tc.contains("id") ? tc.at("id").get<std::string>() : "",
  34. });
  35. }
  36. }
  37. return ret;
  38. }
  39. template <class T> static void assert_equals(const T & expected, const T & actual) {
  40. if (expected != actual) {
  41. std::cerr << "Expected: " << expected << std::endl;
  42. std::cerr << "Actual: " << actual << std::endl;
  43. std::cerr << std::flush;
  44. throw std::runtime_error("Test failed");
  45. }
  46. }
  47. static std::string read_file(const std::string & path) {
  48. std::cerr << "# Reading: " << path << std::endl << std::flush;
  49. std::ifstream fs(path, std::ios_base::binary);
  50. if (!fs.is_open()) {
  51. fs = std::ifstream("../" + path, std::ios_base::binary);
  52. if (!fs.is_open()) {
  53. throw std::runtime_error("Failed to open file: " + path);
  54. }
  55. }
  56. fs.seekg(0, std::ios_base::end);
  57. auto size = fs.tellg();
  58. fs.seekg(0);
  59. std::string out;
  60. out.resize(static_cast<size_t>(size));
  61. fs.read(&out[0], static_cast<std::streamsize>(size));
  62. return out;
  63. }
  64. static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) {
  65. return std::unique_ptr<llama_grammar>(
  66. llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0));
  67. }
// TODO: extract to common helper (copied from test-grammar-integration.cpp)
// Feeds `input` codepoint-by-codepoint into `grammar` and reports whether the
// grammar accepted every character AND reached a completed state at the end.
static bool match_string(const std::string & input, llama_grammar * grammar) {
    const auto cpts = unicode_cpts_from_utf8(input);
    // NOTE: reference into the grammar's internal parse stacks; it is updated
    // in place by each llama_grammar_accept call below.
    auto & stacks_cur = llama_grammar_get_stacks(grammar);
    for (const auto & cpt : cpts) {
        llama_grammar_accept(grammar, cpt);
        if (stacks_cur.empty()) {
            // no stacks means that the grammar failed to match at this point
            return false;
        }
    }
    for (const auto & stack : stacks_cur) {
        if (stack.empty()) {
            // An empty stack means that the grammar has been completed
            return true;
        }
    }
    return false;
}
  87. // Dumps `{"a": 1}` as `"{\"a\": 1}"`, unlike nlohmann::json::dump which would dump it as `"{\"a\":1}"`.
  88. static std::string dump(const json & j) {
  89. return minja::Value(j).dump(-1, /* to_json= */ true);
  90. }
  91. static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
  92. assert_equals(expected.role, actual.role);
  93. assert_equals(expected.content, actual.content);
  94. assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
  95. for (size_t i = 0; i < expected.tool_calls.size(); i++) {
  96. const auto & expected_tool_call = expected.tool_calls[i];
  97. const auto & actual_tool_call = actual.tool_calls[i];
  98. assert_equals(expected_tool_call.name, actual_tool_call.name);
  99. assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments)));
  100. assert_equals(expected_tool_call.id, actual_tool_call.id);
  101. }
  102. }
// Tool fixture used by most tests: a function taking one required integer
// argument ("arg1").
const auto special_function_tool = json::parse(R"({
  "type": "function",
  "function": {
    "name": "special_function",
    "description": "I'm special",
    "parameters": {
      "type": "object",
      "properties": {
        "arg1": {
          "type": "integer",
          "description": "The arg."
        }
      },
      "required": ["arg1"]
    }
  }
})");
// Tool fixture named "python": exercises templates with a built-in code
// interpreter tool (required string argument "code").
const auto python_tool = json::parse(R"({
  "type": "function",
  "function": {
    "name": "python",
    "description": "an ipython interpreter",
    "parameters": {
      "type": "object",
      "properties": {
        "code": {
          "type": "string",
          "description": "Python code to execute."
        }
      },
      "required": ["code"]
    }
  }
})");
// Same schema as python_tool but named "code_interpreter" — used for the
// Llama 3.1 built-in tools path.
const auto code_interpreter_tool = json::parse(R"({
  "type": "function",
  "function": {
    "name": "code_interpreter",
    "description": "an ipython interpreter",
    "parameters": {
      "type": "object",
      "properties": {
        "code": {
          "type": "string",
          "description": "Python code to execute."
        }
      },
      "required": ["code"]
    }
  }
})");
// Default tool set passed to test_template.
const json tools = { special_function_tool, python_tool };
// Tool set for the Llama 3.1 built-in code_interpreter tests.
const json llama_3_1_tools = { special_function_tool, code_interpreter_tool };
// Result of rendering a template with and without the test message:
// the text the assistant would have generated, plus the full chat params.
struct delta_data {
    std::string delta;          // generated-text diff between the two prompts
    common_chat_params params;  // params (format, grammar, triggers, ...) of the full render
};
  160. static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
  161. const json & user_message, const json & delta_message, const json & tools,
  162. const json & tool_choice) {
  163. common_chat_inputs inputs;
  164. inputs.parallel_tool_calls = true;
  165. inputs.messages = json::array();
  166. inputs.messages.push_back(user_message);
  167. inputs.tools = tools;
  168. inputs.tool_choice = tool_choice;
  169. auto params_prefix = common_chat_params_init(tmpl, inputs);
  170. inputs.messages.push_back(delta_message);
  171. inputs.add_generation_prompt = false;
  172. auto params_full = common_chat_params_init(tmpl, inputs);
  173. std::string prefix = params_prefix.prompt;
  174. std::string full = params_full.prompt;
  175. // Check full starts with prefix
  176. if (full.find(prefix) != 0) {
  177. fprintf(stderr, "Full:\n%s\n\nPrefix:\n%s\n\n", full.c_str(), prefix.c_str());
  178. throw std::runtime_error("Full message does not start with prefix");
  179. }
  180. if (full == prefix) {
  181. throw std::runtime_error("Full message is the same as the prefix");
  182. }
  183. auto delta = full.substr(prefix.size());
  184. // Strip end tokens
  185. for (const auto & end_token : end_tokens) {
  186. // rfind to find the last occurrence
  187. auto pos = delta.rfind(end_token);
  188. if (pos != std::string::npos) {
  189. delta = delta.substr(0, pos);
  190. break;
  191. }
  192. }
  193. return { delta, params_full };
  194. }
  195. /*
  196. Applies the template to 1 user message w/ add_generation_prompt=true, then w/ the test message w/ add_generation_prompt=false,
  197. gets the diff, removes any end tokens and parses the result w/ the grammar, checking that
  198. the parsed message is the same as the test_message
  199. */
// Renders `test_message` through the template, checks the generated delta
// against `expected_delta` (when provided), parses it back into a message and
// compares with the original, and finally validates the delta against the
// grammar produced by common_chat_params_init (honoring lazy triggers).
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
                          const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
                          bool expect_grammar_triggered = true) {
    common_chat_msg expected_msg = msg_from_json(test_message);

    auto user_message = json{
        { "role", "user" },
        { "content", "Hello, world!" }
    };

    // Run the same checks for both tool_choice modes.
    for (const auto & tool_choice : json({ "auto", "required" })) {
        auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice);
        if (!expected_delta.empty()) {
            assert_equals(expected_delta, data.delta);
        }

        if (expect_grammar_triggered) {
            // The delta must round-trip: parsing it yields the original message.
            const auto msg = common_chat_parse(data.delta, data.params.format);
            assert_msg_equals(expected_msg, msg);
        }

        if (!expected_msg.tool_calls.empty()) {
            // Any message with tool calls must come with a constraining grammar.
            GGML_ASSERT(!data.params.grammar.empty());
        }
        if (!data.params.grammar.empty()) {
            auto grammar = build_grammar(data.params.grammar);
            if (!grammar) {
                throw std::runtime_error("Failed to build grammar");
            }
            // Find the earliest trigger word inside the delta; the grammar is
            // only supposed to constrain output from that point onwards.
            auto earliest_trigger_pos = std::string::npos;
            auto constrained = data.delta;
            for (const auto & trigger : data.params.grammar_triggers) {
                auto pos = constrained.find(trigger.word);
                if (pos == std::string::npos) {
                    continue;
                }
                if (pos > 0 && trigger.at_start) {
                    // Triggers marked at_start only count at position 0.
                    fprintf(stderr, "Trigger %s not at start of message, skipping:\n\n%s\n\n", trigger.word.c_str(), constrained.c_str());
                    continue;
                }
                if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
                    earliest_trigger_pos = pos;
                }
            }
            auto grammar_triggered = false;
            if (earliest_trigger_pos != std::string::npos) {
                constrained = constrained.substr(earliest_trigger_pos);
                grammar_triggered = true;
            }
            if (data.params.grammar_lazy) {
                // For lazy grammars, whether a trigger fired must match expectations.
                assert_equals(expect_grammar_triggered, grammar_triggered);
            }
            if (grammar_triggered && !match_string(constrained, grammar.get())) {
                throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
                                         "\n\nGrammar: " + data.params.grammar);
            }
        }
    }
}
// Exercises grammar generation and output parsing for each supported chat
// template family: asserts the detected format, the exact rendered deltas,
// and (via test_template) the grammar/parse round-trips.
static void test_template_output_parsers() {
    json text_message {
        { "role", "assistant" },
        { "content", "Hello, world!\nWhat's up?" },
    };
    json tool_calls = json::array({{
        { "type", "function" },
        { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
    }});

    json tool_call_message {
        { "role", "assistant"},
        { "content", {}},
        { "tool_calls", {
            {
                { "type", "function" },
                { "function", {
                    { "name", "special_function" },
                    { "arguments", "{\"arg1\": 1}" },
                }},
            },
        }},
    };
    json tool_call_message_with_id {
        { "role", "assistant"},
        { "content", {}},
        { "tool_calls", {
            {
                { "type", "function" },
                { "function", {
                    { "name", "special_function" },
                    { "arguments", "{\"arg1\": 1}" },
                }},
                {"id", "123456789"},
            },
        }},
        // NOTE(review): the three entries below duplicate keys already set
        // above; they look like copy/paste leftovers (with ordered_json the
        // first occurrence appears to win) — confirm before cleaning up.
        { "role", "assistant" },
        { "content", {} },
        { "tool_calls", tool_calls }
    };
    json tool_call_plan_message_with_idx {
        { "role", "assistant"},
        { "content", {}},
        { "tool_plan", "I'm not so sure"},
        { "tool_calls", {
            {
                { "type", "function" },
                { "function", {
                    { "name", "special_function" },
                    { "arguments", "{\"arg1\": 1}" },
                }},
                // Index of the tool call in the tool_calls array
                {"id", "0"},
            },
        }},
        // NOTE(review): duplicate keys, same situation as tool_call_message_with_id.
        { "role", "assistant" },
        { "content", {} },
        { "tool_calls", tool_calls }
    };

    auto python_tool_call_message = json{
        { "role", "assistant" },
        { "content", {} },
        { "tool_calls", json{ {
            { "type", "function" },
            { "function",
              {
                  { "name", "python" },
                  { "arguments",
                    {
                        { "code", "print('hey')" },
                    } },
              } },
        } } }
    };
    auto code_interpreter_tool_call_message = json{
        { "role", "assistant" },
        { "content", {} },
        { "tool_calls", json{ {
            { "type", "function" },
            { "function",
              {
                  { "name", "code_interpreter" },
                  { "arguments",
                    {
                        { "code", "print('hey')" },
                    } },
              } },
        } } }
    };

    common_chat_inputs inputs_no_tools;
    inputs_no_tools.messages = {
        { { "role", "user" }, { "content", "Hey\nThere" } }
    };

    common_chat_inputs inputs_tools = inputs_no_tools;
    inputs_tools.tools = json::array();
    inputs_tools.tools.push_back(special_function_tool);

    common_chat_inputs inputs_tools_builtin = inputs_no_tools;
    inputs_tools_builtin.tools = json::array();
    inputs_tools_builtin.tools.push_back(python_tool);

    {
        // Not supported yet
        const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "<s>", "</s>");
        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
    }
    {
        // Command R7B: thinking/action/response markers.
        const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "<s>", "</s>");
        std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" };

        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, tool_call_plan_message_with_idx, tools,
                      "<|START_THINKING|>I'm not so sure<|END_THINKING|>"
                      "<|START_ACTION|>[\n"
                      " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
                      "]<|END_ACTION|>");
        test_template(tmpl, end_tokens, text_message, tools,
                      "<|START_RESPONSE|>Hello, world!\n"
                      "What's up?<|END_RESPONSE|>",
                      /* expect_grammar_triggered= */ false);
    }
    {
        // Generic JSON-schema format (Gemma; Phi detects the same format).
        const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
        std::vector<std::string> end_tokens{ "<end_of_turn>" };

        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_GENERIC,
                      common_chat_params_init(
                          common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"),
                                               "<s>", "</s>"),
                          inputs_tools)
                          .format);

        // Generic tool calls doesn't generate / parse content-only messages symmetrically.
        assert_msg_equals(msg_from_json(text_message),
                          common_chat_parse("{\n"
                                            " \"response\": \"Hello, world!\\nWhat's up?\"\n"
                                            "}",
                                            common_chat_params_init(tmpl, inputs_tools).format));
        test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
                      "{\n"
                      " \"tool_calls\": [\n"
                      " {\n"
                      " \"name\": \"special_function\",\n"
                      " \"arguments\": {\n"
                      " \"arg1\": 1\n"
                      " },\n"
                      " \"id\": \"123456789\"\n"
                      " }\n"
                      " ]\n"
                      "}");
    }
    {
        // Mistral Nemo: [TOOL_CALLS] prefix with JSON array payload.
        const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "</s>" };

        assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_template(
            tmpl, end_tokens, tool_call_message_with_id, tools,
            "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
    }
    {
        // Hermes 2 Pro (Hermes 3 and Qwen 2.5 detect the same format): <tool_call> XML-ish tags.
        const common_chat_template tmpl(
            read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
        std::vector<std::string> end_tokens{ "<|im_end|>" };

        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(tmpl, inputs_tools).format);
        assert_equals(
            COMMON_CHAT_FORMAT_HERMES_2_PRO,
            common_chat_params_init(
                common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"),
                                     "<s>", "</s>"),
                inputs_tools)
                .format);
        assert_equals(
            COMMON_CHAT_FORMAT_HERMES_2_PRO,
            common_chat_params_init(
                common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"),
                inputs_tools)
                .format);

        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "<tool_call>\n"
                      "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
                      "</tool_call>");
        test_template(tmpl, end_tokens, python_tool_call_message, tools,
                      "<tool_call>\n"
                      "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
                      "</tool_call>");
    }
    {
        // Llama 3.1: plain JSON tool calls, plus <|python_tag|> built-in tools.
        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
                      common_chat_params_init(tmpl, inputs_tools_builtin).format);
        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
                      common_chat_params_init(
                          common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"),
                                               "<s>", "</s>"),
                          inputs_tools_builtin)
                          .format);

        // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
                      "<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
        test_template(tmpl, end_tokens, python_tool_call_message, tools,
                      "<|python_tag|>python.call(code=\"print('hey')\")");
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
    }
    {
        // Llama 3.2: same JSON tool-call format, no built-in tools path tested.
        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
    }
    {
        // Functionary v3.1 (Llama 3.1 based): <function=...> tags.
        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
                      common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "<function=special_function>{\"arg1\": 1}</function>");
    }
    {
        // Functionary v3.2: recipient-line format ("all" for plain content).
        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, text_message, {},
                      "all\n"
                      "Hello, world!\n"
                      "What's up?",
                      /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "special_function\n"
                      "{\"arg1\": 1}");
    }
    {
        // FireFunction v2: " functools[...]" prefix.
        const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
    }
    {
        // DeepSeek R1: tool call markers with U+2581 separators and a ```json block.
        const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"),
                                        "<s>", "</s>");
        std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };

        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
                      "```json\n"
                      "{\"arg1\": 1}\n"
                      "```<|tool▁call▁end|>");
    }
}
// With Jinja template paths as arguments (non-Windows only): prints a Markdown
// table mapping each template file to its detected chat format.
// With no arguments: runs the template output parser test suite.
int main(int argc, char ** argv) {
#ifndef _WIN32
    if (argc > 1) {
        common_chat_inputs inputs;
        inputs.messages = {
            { { "role", "user" }, { "content", "Hey" } }
        };
        inputs.tools = json::array({ special_function_tool });
        std::cout << "| Template | Format |\n";
        std::cout << "|----------|--------|\n";
        for (int i = 1; i < argc; i++) {
            std::string path = argv[i];
            // ".jinja" suffix check; for paths shorter than 6 chars the
            // size()-6 expression wraps around, but rfind then returns npos so
            // such paths are (correctly) skipped anyway.
            if (path.rfind(".jinja") != path.size() - 6) {
                std::cerr << "Skipping non-jinja file: " << path << std::endl;
                continue;
            }
            common_chat_template tmpl(read_file(path), "", "");
            // Use the last path component as the display name.
            auto parts = string_split(path, "/");
            auto name = parts[parts.size() - 1];
            std::cout << "| " << name << " | " << common_chat_format_name(common_chat_params_init(tmpl, inputs).format)
                      << " |\n";
        }
    } else
#endif
    {
        test_template_output_parsers();
        std::cout << "\n[chat] All tests passed!" << std::endl;
    }
    return 0;
}