test-chat.cpp
  1. // Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
  2. //
  3. // Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
  4. // e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
  5. //
  6. // cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
  7. //
  8. #include <fstream>
  9. #include <iostream>
  10. #include <json.hpp>
  11. #include <string>
  12. #include "chat-template.hpp"
  13. #include "chat.hpp"
  14. #include "llama-grammar.h"
  15. #include "unicode.h"
  16. using json = nlohmann::ordered_json;
  17. static common_chat_msg msg_from_json(const json & message) {
  18. common_chat_msg ret{
  19. "assistant",
  20. "",
  21. {},
  22. /* .tool_plan = */ "",
  23. };
  24. if (message.contains("content") && !message.at("content").is_null()) {
  25. ret.content = message.at("content");
  26. }
  27. if (message.contains("tool_plan")) {
  28. ret.tool_plan = message.at("tool_plan");
  29. }
  30. auto has_tool_calls = message.contains("tool_calls");
  31. if (has_tool_calls) {
  32. for (const auto & tc : message.at("tool_calls")) {
  33. const auto & arguments = tc.at("function").at("arguments");
  34. ret.tool_calls.push_back({
  35. tc.at("function").at("name").get<std::string>(),
  36. arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
  37. tc.contains("id") ? tc.at("id").get<std::string>() : "",
  38. });
  39. }
  40. }
  41. return ret;
  42. }
  43. template <class T> static void assert_equals(const T & expected, const T & actual) {
  44. if (expected != actual) {
  45. std::cerr << "Expected: " << expected << std::endl;
  46. std::cerr << "Actual: " << actual << std::endl;
  47. std::cerr << std::flush;
  48. throw std::runtime_error("Test failed");
  49. }
  50. }
  51. static std::string read_file(const std::string & path) {
  52. std::cerr << "# Reading: " << path << std::endl << std::flush;
  53. std::ifstream fs(path, std::ios_base::binary);
  54. if (!fs.is_open()) {
  55. fs = std::ifstream("../" + path, std::ios_base::binary);
  56. if (!fs.is_open()) {
  57. throw std::runtime_error("Failed to open file: " + path);
  58. }
  59. }
  60. fs.seekg(0, std::ios_base::end);
  61. auto size = fs.tellg();
  62. fs.seekg(0);
  63. std::string out;
  64. out.resize(static_cast<size_t>(size));
  65. fs.read(&out[0], static_cast<std::streamsize>(size));
  66. return out;
  67. }
  68. static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) {
  69. return std::unique_ptr<llama_grammar>(
  70. llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0));
  71. }
  72. // TODO: extract to common helper (copied from test-grammar-integration.cpp)
  73. static bool match_string(const std::string & input, llama_grammar * grammar) {
  74. const auto cpts = unicode_cpts_from_utf8(input);
  75. auto & stacks_cur = llama_grammar_get_stacks(grammar);
  76. for (const auto & cpt : cpts) {
  77. llama_grammar_accept(grammar, cpt);
  78. if (stacks_cur.empty()) {
  79. // no stacks means that the grammar failed to match at this point
  80. return false;
  81. }
  82. }
  83. for (const auto & stack : stacks_cur) {
  84. if (stack.empty()) {
  85. // An empty stack means that the grammar has been completed
  86. return true;
  87. }
  88. }
  89. return false;
  90. }
  91. // Dumps `{"a": 1}` as `"{\"a\": 1}"`, unlike nlohmann::json::dump which would dump it as `"{\"a\":1}"`.
  92. static std::string dump(const json & j) {
  93. return minja::Value(j).dump(-1, /* to_json= */ true);
  94. }
  95. static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
  96. assert_equals(expected.role, actual.role);
  97. assert_equals(expected.content, actual.content);
  98. assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
  99. for (size_t i = 0; i < expected.tool_calls.size(); i++) {
  100. const auto & expected_tool_call = expected.tool_calls[i];
  101. const auto & actual_tool_call = actual.tool_calls[i];
  102. assert_equals(expected_tool_call.name, actual_tool_call.name);
  103. assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments)));
  104. assert_equals(expected_tool_call.id, actual_tool_call.id);
  105. }
  106. }
// Tool schema with a single required integer argument ("arg1"); used to
// exercise function-call grammar generation and parsing.
const auto special_function_tool = json::parse(R"({
"type": "function",
"function": {
"name": "special_function",
"description": "I'm special",
"parameters": {
"type": "object",
"properties": {
"arg1": {
"type": "integer",
"description": "The arg."
}
},
"required": ["arg1"]
}
}
})");
// "python" tool: some templates special-case this name as a built-in
// code-interpreter tool (see the Llama 3.x tests below).
const auto python_tool = json::parse(R"({
"type": "function",
"function": {
"name": "python",
"description": "an ipython interpreter",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "Python code to execute."
}
},
"required": ["code"]
}
}
})");
// Same shape as python_tool but under the "code_interpreter" name, to test
// builtin-tool handling in Llama 3.1-style templates.
const auto code_interpreter_tool = json::parse(R"({
"type": "function",
"function": {
"name": "code_interpreter",
"description": "an ipython interpreter",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "Python code to execute."
}
},
"required": ["code"]
}
}
})");
// Default tool set used by most tests.
const json tools = { special_function_tool, python_tool };
// Tool set for Llama 3.1 builtin-tool tests (code_interpreter instead of python).
const json llama_3_1_tools = { special_function_tool, code_interpreter_tool };
// Result of rendering a template twice (with and without the assistant delta
// message): the textual difference plus the parameters of the full render.
struct delta_data {
    std::string delta;          // template output attributable to the delta message, end tokens stripped
    common_chat_params params;  // chat params (grammar, format, triggers) from the full render
};
  164. static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
  165. const json & user_message, const json & delta_message, const json & tools,
  166. const json & tool_choice) {
  167. common_chat_inputs inputs;
  168. inputs.parallel_tool_calls = true;
  169. inputs.messages = json::array();
  170. inputs.messages.push_back(user_message);
  171. inputs.tools = tools;
  172. inputs.tool_choice = tool_choice;
  173. auto params_prefix = common_chat_params_init(tmpl, inputs);
  174. inputs.messages.push_back(delta_message);
  175. inputs.add_generation_prompt = false;
  176. auto params_full = common_chat_params_init(tmpl, inputs);
  177. std::string prefix = params_prefix.prompt;
  178. std::string full = params_full.prompt;
  179. // Check full starts with prefix
  180. if (full.find(prefix) != 0) {
  181. fprintf(stderr, "Full:\n%s\n\nPrefix:\n%s\n\n", full.c_str(), prefix.c_str());
  182. throw std::runtime_error("Full message does not start with prefix");
  183. }
  184. if (full == prefix) {
  185. throw std::runtime_error("Full message is the same as the prefix");
  186. }
  187. auto delta = full.substr(prefix.size());
  188. // Strip end tokens
  189. for (const auto & end_token : end_tokens) {
  190. // rfind to find the last occurrence
  191. auto pos = delta.rfind(end_token);
  192. if (pos != std::string::npos) {
  193. delta = delta.substr(0, pos);
  194. break;
  195. }
  196. }
  197. return { delta, params_full };
  198. }
/*
  Applies the template to 1 user message w/ add_generation_prompt=true, then w/ the test message w/ add_generation_prompt=false,
  gets the diff, removes any end tokens and parses the result w/ the grammar, checking that
  the parsed message is the same as the test_message
*/
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
                          const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
                          bool expect_grammar_triggered = true) {
    common_chat_msg expected_msg = msg_from_json(test_message);
    auto user_message = json{
        { "role", "user" },
        { "content", "Hello, world!" }
    };
    // Run once per tool_choice mode; both modes must yield the same delta.
    for (const auto & tool_choice : json({ "auto", "required" })) {
        auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice);
        if (!expected_delta.empty()) {
            assert_equals(expected_delta, data.delta);
        }
        if (expect_grammar_triggered) {
            // Round-trip: the rendered delta must parse back into test_message.
            const auto msg = common_chat_parse(data.delta, data.params.format);
            assert_msg_equals(expected_msg, msg);
        }
        if (!expected_msg.tool_calls.empty()) {
            // Tool calls must always be constrained by a grammar.
            GGML_ASSERT(!data.params.grammar.empty());
        }
        if (!data.params.grammar.empty()) {
            auto grammar = build_grammar(data.params.grammar);
            if (!grammar) {
                throw std::runtime_error("Failed to build grammar");
            }
            // Find the earliest trigger word present in the delta; for a lazy
            // grammar, the grammar only constrains text from that point on.
            auto earliest_trigger_pos = std::string::npos;
            auto constrained = data.delta;
            for (const auto & trigger : data.params.grammar_triggers) {
                auto pos = constrained.find(trigger.word);
                if (pos == std::string::npos) {
                    continue;
                }
                if (pos > 0 && trigger.at_start) {
                    // at_start triggers only count when they open the message
                    fprintf(stderr, "Trigger %s not at start of message, skipping:\n\n%s\n\n", trigger.word.c_str(), constrained.c_str());
                    continue;
                }
                if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
                    earliest_trigger_pos = pos;
                }
            }
            auto grammar_triggered = false;
            if (earliest_trigger_pos != std::string::npos) {
                constrained = constrained.substr(earliest_trigger_pos);
                grammar_triggered = true;
            }
            if (data.params.grammar_lazy) {
                // A lazy grammar must trigger exactly when the caller expects.
                assert_equals(expect_grammar_triggered, grammar_triggered);
            }
            if (grammar_triggered && !match_string(constrained, grammar.get())) {
                throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
                                         "\n\nGrammar: " + data.params.grammar);
            }
        }
    }
}
// Exercises format detection, delta generation, grammar constraining and
// parsing for each supported chat template family.
static void test_template_output_parsers() {
    // Plain content-only assistant message.
    json text_message {
        { "role", "assistant" },
        { "content", "Hello, world!" },
    };
    // Single special_function call, reused by several message fixtures below.
    json tool_calls = json::array({{
        { "type", "function" },
        { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
    }});
    // Assistant message carrying one tool call and no content.
    json tool_call_message {
        { "role", "assistant"},
        { "content", {}},
        { "tool_calls", {
            {
                { "type", "function" },
                { "function", {
                    { "name", "special_function" },
                    { "arguments", "{\"arg1\": 1}" },
                }},
            },
        }},
    };
    // Same as above but the tool call carries an explicit id.
    // NOTE(review): the trailing role/content/tool_calls entries duplicate the
    // keys above — presumably merge residue; with nlohmann object construction
    // the first occurrence appears to win (TODO confirm), making them dead.
    json tool_call_message_with_id {
        { "role", "assistant"},
        { "content", {}},
        { "tool_calls", {
            {
                { "type", "function" },
                { "function", {
                    { "name", "special_function" },
                    { "arguments", "{\"arg1\": 1}" },
                }},
                {"id", "123456789"},
            },
        }},
        { "role", "assistant" },
        { "content", {} },
        { "tool_calls", tool_calls }
    };
    // Tool call with a tool_plan and an index-style id, as used by Command R7B.
    // NOTE(review): same duplicated trailing keys as tool_call_message_with_id — verify intended payload.
    json tool_call_plan_message_with_idx {
        { "role", "assistant"},
        { "content", {}},
        { "tool_plan", "I'm not so sure"},
        { "tool_calls", {
            {
                { "type", "function" },
                { "function", {
                    { "name", "special_function" },
                    { "arguments", "{\"arg1\": 1}" },
                }},
                // Index of the tool call in the tool_calls array
                {"id", "0"},
            },
        }},
        { "role", "assistant" },
        { "content", {} },
        { "tool_calls", tool_calls }
    };
    // Call to the "python" tool (treated as builtin by some templates).
    auto python_tool_call_message = json{
        { "role", "assistant" },
        { "content", {} },
        { "tool_calls", json{ {
            { "type", "function" },
            { "function",
              {
                  { "name", "python" },
                  { "arguments",
                    {
                        { "code", "print('hey')" },
                    } },
              } },
        } } }
    };
    // Call to the "code_interpreter" builtin (Llama 3.1 style).
    auto code_interpreter_tool_call_message = json{
        { "role", "assistant" },
        { "content", {} },
        { "tool_calls", json{ {
            { "type", "function" },
            { "function",
              {
                  { "name", "code_interpreter" },
                  { "arguments",
                    {
                        { "code", "print('hey')" },
                    } },
              } },
        } } }
    };
    // Shared input fixtures: no tools / special_function only / python only.
    common_chat_inputs inputs_no_tools;
    inputs_no_tools.messages = {
        { { "role", "user" }, { "content", "Hey" } }
    };
    common_chat_inputs inputs_tools = inputs_no_tools;
    inputs_tools.tools = json::array();
    inputs_tools.tools.push_back(special_function_tool);
    common_chat_inputs inputs_tools_builtin = inputs_no_tools;
    inputs_tools_builtin.tools = json::array();
    inputs_tools_builtin.tools.push_back(python_tool);
    {
        // Not supported yet
        const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "<s>", "</s>");
        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
    }
    {
        // Command R7B: THINKING/ACTION/RESPONSE tag format.
        const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "<s>", "</s>");
        std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" };
        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
        test_template(tmpl, end_tokens, tool_call_plan_message_with_idx, tools,
                      "<|START_THINKING|>I'm not so sure<|END_THINKING|>"
                      "<|START_ACTION|>[\n"
                      "    {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
                      "]<|END_ACTION|>");
        test_template(tmpl, end_tokens, text_message, tools,
                      "<|START_RESPONSE|>Hello, world!<|END_RESPONSE|>",
                      /* expect_grammar_triggered= */ false);
    }
    {
        // Templates without native tool support fall back to the generic JSON format.
        const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
        std::vector<std::string> end_tokens{ "<end_of_turn>" };
        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_GENERIC,
                      common_chat_params_init(
                          common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"),
                                               "<s>", "</s>"),
                          inputs_tools)
                          .format);
        // Generic tool calls doesn't generate / parse content-only messages symmetrically.
        assert_msg_equals(msg_from_json(text_message),
                          common_chat_parse("{\n"
                                            "  \"response\": \"Hello, world!\"\n"
                                            "}",
                                            common_chat_params_init(tmpl, inputs_tools).format));
        test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
                      "{\n"
                      "  \"tool_calls\": [\n"
                      "    {\n"
                      "      \"name\": \"special_function\",\n"
                      "      \"arguments\": {\n"
                      "        \"arg1\": 1\n"
                      "      },\n"
                      "      \"id\": \"123456789\"\n"
                      "    }\n"
                      "  ]\n"
                      "}");
    }
    {
        // Mistral Nemo: [TOOL_CALLS] JSON array format.
        const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "</s>" };
        assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
        test_template(
            tmpl, end_tokens, tool_call_message_with_id, tools,
            "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
    }
    {
        // Hermes 2 Pro family (also covers Hermes 3 and Qwen 2.5): <tool_call> XML-ish tags.
        const common_chat_template tmpl(
            read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
        std::vector<std::string> end_tokens{ "<|im_end|>" };
        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(tmpl, inputs_tools).format);
        assert_equals(
            COMMON_CHAT_FORMAT_HERMES_2_PRO,
            common_chat_params_init(
                common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"),
                                     "<s>", "</s>"),
                inputs_tools)
                .format);
        assert_equals(
            COMMON_CHAT_FORMAT_HERMES_2_PRO,
            common_chat_params_init(
                common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"),
                inputs_tools)
                .format);
        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "<tool_call>\n"
                      "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
                      "</tool_call>");
        test_template(tmpl, end_tokens, python_tool_call_message, tools,
                      "<tool_call>\n"
                      "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
                      "</tool_call>");
    }
    {
        // Llama 3.1: plain JSON calls, plus <|python_tag|> for builtin tools.
        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
                      common_chat_params_init(tmpl, inputs_tools_builtin).format);
        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
                      common_chat_params_init(
                          common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"),
                                               "<s>", "</s>"),
                          inputs_tools_builtin)
                          .format);
        // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
                      "<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
        test_template(tmpl, end_tokens, python_tool_call_message, tools,
                      "<|python_tag|>python.call(code=\"print('hey')\")");
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
    }
    {
        // Llama 3.2: JSON calls only (no builtin-tool path).
        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
    }
    {
        // Functionary v3.1: <function=name>{...}</function> format.
        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
                      common_chat_params_init(tmpl, inputs_tools).format);
        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "<function=special_function>{\"arg1\": 1}</function>");
    }
    {
        // Functionary v3.2: "name\n{args}" format; "all" channel for plain content.
        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format);
        test_template(tmpl, end_tokens, text_message, {},
                      "all\n"
                      "Hello, world!",
                      /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "special_function\n"
                      "{\"arg1\": 1}");
    }
    {
        // FireFunction v2: " functools[...]" JSON array format.
        const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "<|eot_id|>" };
        assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
    }
    {
        // DeepSeek R1: tool-call markers use non-ASCII "▁" separators.
        const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"),
                                        "<s>", "</s>");
        std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
                      "```json\n"
                      "{\"arg1\": 1}\n"
                      "```<|tool▁call▁end|>");
    }
}
  520. int main(int argc, char ** argv) {
  521. #ifndef _WIN32
  522. if (argc > 1) {
  523. common_chat_inputs inputs;
  524. inputs.messages = {
  525. { { "role", "user" }, { "content", "Hey" } }
  526. };
  527. inputs.tools = json::array({ special_function_tool });
  528. std::cout << "| Template | Format |\n";
  529. std::cout << "|----------|--------|\n";
  530. for (int i = 1; i < argc; i++) {
  531. std::string path = argv[i];
  532. if (path.rfind(".jinja") != path.size() - 6) {
  533. std::cerr << "Skipping non-jinja file: " << path << std::endl;
  534. continue;
  535. }
  536. common_chat_template tmpl(read_file(path), "", "");
  537. auto parts = string_split(path, "/");
  538. auto name = parts[parts.size() - 1];
  539. std::cout << "| " << name << " | " << common_chat_format_name(common_chat_params_init(tmpl, inputs).format)
  540. << " |\n";
  541. }
  542. } else
  543. #endif
  544. {
  545. test_template_output_parsers();
  546. std::cout << "\n[chat] All tests passed!" << std::endl;
  547. }
  548. return 0;
  549. }