test-chat.cpp
// Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
//
// Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
// e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
//
//   cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
//
#include <fstream>
#include <iostream>
#include <json.hpp>
#include <string>

#include "chat-template.hpp"
#include "chat.hpp"
#include "llama-grammar.h"
#include "unicode.h"

using json = nlohmann::ordered_json;
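
// Builds a common_chat_msg from an OpenAI-style JSON message: copies the
// content if present, maps Command R7B's "tool_plan" (or a plain
// "reasoning_content") onto reasoning_content, and normalizes tool-call
// arguments to a JSON string whether they arrive as a string or an object.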
static common_chat_msg msg_from_json(const json & message) {
    common_chat_msg ret;
    ret.role = "assistant";
    if (message.contains("content") && !message.at("content").is_null()) {
        ret.content = message.at("content");
    }
    if (message.contains("tool_plan")) {
        ret.reasoning_content = message.at("tool_plan");
    }
    if (message.contains("reasoning_content")) {
        ret.reasoning_content = message.at("reasoning_content");
    }
    auto has_tool_calls = message.contains("tool_calls");
    if (has_tool_calls) {
        for (const auto & tc : message.at("tool_calls")) {
            const auto & arguments = tc.at("function").at("arguments");
            ret.tool_calls.push_back({
                tc.at("function").at("name").get<std::string>(),
                arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
                tc.contains("id") ? tc.at("id").get<std::string>() : "",
            });
        }
    }
    return ret;
}
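
// Minimal assertion helper: logs expected vs. actual to stderr and throws
// to fail the enclosing test.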
template <class T> static void assert_equals(const T & expected, const T & actual) {
    if (expected != actual) {
        std::cerr << "Expected: " << expected << std::endl;
        std::cerr << "Actual: " << actual << std::endl;
        std::cerr << std::flush;
        throw std::runtime_error("Test failed");
    }
}
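
// Reads a whole file into a string. Falls back to "../<path>" so the test
// binary can be run from a build subdirectory as well as the repo root.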
static std::string read_file(const std::string & path) {
    std::cerr << "# Reading: " << path << std::endl << std::flush;
    std::ifstream fs(path, std::ios_base::binary);
    if (!fs.is_open()) {
        fs = std::ifstream("../" + path, std::ios_base::binary);
        if (!fs.is_open()) {
            throw std::runtime_error("Failed to open file: " + path);
        }
    }
    fs.seekg(0, std::ios_base::end);
    auto size = fs.tellg();
    fs.seekg(0);
    std::string out;
    out.resize(static_cast<size_t>(size));
    fs.read(&out[0], static_cast<std::streamsize>(size));
    return out;
}
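
// Compiles a GBNF grammar string, starting from its "root" rule.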
static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) {
    return std::unique_ptr<llama_grammar>(
        llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0));
}

// TODO: extract to common helper (copied from test-grammar-integration.cpp)
static bool match_string(const std::string & input, llama_grammar * grammar) {
    const auto cpts = unicode_cpts_from_utf8(input);

    auto & stacks_cur = llama_grammar_get_stacks(grammar);

    for (const auto & cpt : cpts) {
        llama_grammar_accept(grammar, cpt);

        if (stacks_cur.empty()) {
            // no stacks means that the grammar failed to match at this point
            return false;
        }
    }

    for (const auto & stack : stacks_cur) {
        if (stack.empty()) {
            // An empty stack means that the grammar has been completed
            return true;
        }
    }

    return false;
}
// Dumps `{"a": 1}` as `"{\"a\": 1}"`, unlike nlohmann::json::dump which would dump it as `"{\"a\":1}"`.
static std::string dump(const json & j) {
    return minja::Value(j).dump(-1, /* to_json= */ true);
}

static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
    assert_equals(expected.role, actual.role);
    assert_equals(expected.content, actual.content);
    assert_equals(expected.reasoning_content, actual.reasoning_content);
    assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
    for (size_t i = 0; i < expected.tool_calls.size(); i++) {
        const auto & expected_tool_call = expected.tool_calls[i];
        const auto & actual_tool_call   = actual.tool_calls[i];
        assert_equals(expected_tool_call.name, actual_tool_call.name);
        assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments)));
        assert_equals(expected_tool_call.id, actual_tool_call.id);
    }
}
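
// Tool definitions shared by the tests below. For reference, a call that
// satisfies special_function's schema looks like:
//
//   {"name": "special_function", "arguments": {"arg1": 1}}
//
// which is the payload the per-template expected deltas encode in their
// various syntaxes.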
const auto special_function_tool = json::parse(R"({
  "type": "function",
  "function": {
    "name": "special_function",
    "description": "I'm special",
    "parameters": {
      "type": "object",
      "properties": {
        "arg1": {
          "type": "integer",
          "description": "The arg."
        }
      },
      "required": ["arg1"]
    }
  }
})");
const auto python_tool = json::parse(R"({
  "type": "function",
  "function": {
    "name": "python",
    "description": "an ipython interpreter",
    "parameters": {
      "type": "object",
      "properties": {
        "code": {
          "type": "string",
          "description": "Python code to execute."
        }
      },
      "required": ["code"]
    }
  }
})");
const auto code_interpreter_tool = json::parse(R"({
  "type": "function",
  "function": {
    "name": "code_interpreter",
    "description": "an ipython interpreter",
    "parameters": {
      "type": "object",
      "properties": {
        "code": {
          "type": "string",
          "description": "Python code to execute."
        }
      },
      "required": ["code"]
    }
  }
})");

const json tools           = { special_function_tool, python_tool };
const json llama_3_1_tools = { special_function_tool, code_interpreter_tool };

struct delta_data {
    std::string        delta;
    common_chat_params params;
};
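
// Renders the template twice: once with only the user message and
// add_generation_prompt=true, then with the assistant test message appended
// and add_generation_prompt=false. Returns the suffix the assistant message
// contributes to the prompt, plus the params (format, grammar, triggers)
// computed for the full conversation.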
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
                             const json & user_message, const json & delta_message, const json & tools,
                             const json & tool_choice,
                             bool think = false) {
    common_chat_inputs inputs;
    inputs.parallel_tool_calls = true;
    inputs.messages            = json::array();
    inputs.messages.push_back(user_message);
    inputs.tools             = tools;
    inputs.tool_choice       = tool_choice;
    inputs.extract_reasoning = think;
    auto params_prefix = common_chat_params_init(tmpl, inputs);

    inputs.messages.push_back(delta_message);
    inputs.add_generation_prompt = false;
    auto params_full = common_chat_params_init(tmpl, inputs);

    std::string prefix = params_prefix.prompt;
    std::string full   = params_full.prompt;

    if (full == prefix) {
        throw std::runtime_error("Full message is the same as the prefix");
    }

    size_t common_prefix_length = 0;
    for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
        if (prefix[i] != full[i]) {
            break;
        }
        if (prefix[i] == '<') {
            // DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
            // but it removes thinking tags for past messages.
            // The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, so we avoid consuming the leading <.
            continue;
        }
        common_prefix_length = i + 1;
    }
    auto delta = full.substr(common_prefix_length);

    // Strip end tokens
    for (const auto & end_token : end_tokens) {
        // rfind to find the last occurrence
        auto pos = delta.rfind(end_token);
        if (pos != std::string::npos) {
            delta = delta.substr(0, pos);
            break;
        }
    }
    return { delta, params_full };
}
/*
    Applies the template to one user message with add_generation_prompt=true, then to the same conversation
    plus the test message with add_generation_prompt=false, takes the diff, strips any end tokens, and
    parses the result with the grammar, checking that the parsed message matches test_message.
*/
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
                          const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
                          bool expect_grammar_triggered = true,
                          bool test_grammar_if_triggered = true,
                          bool think = false) {
    common_chat_msg expected_msg = msg_from_json(test_message);

    auto user_message = json{
        { "role",    "user" },
        { "content", "Hello, world!" }
    };

    for (const auto & tool_choice : json({ "auto", "required" })) {
        auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice, think);
        if (!expected_delta.empty()) {
            assert_equals(expected_delta, data.delta);
        }

        if (expect_grammar_triggered) {
            const auto msg = common_chat_parse(data.delta, data.params.format);
            assert_msg_equals(expected_msg, msg);
        }

        if (!expected_msg.tool_calls.empty()) {
            GGML_ASSERT(!data.params.grammar.empty());
        }
        if (!data.params.grammar.empty()) {
            auto grammar = build_grammar(data.params.grammar);
            if (!grammar) {
                throw std::runtime_error("Failed to build grammar");
            }
            auto earliest_trigger_pos = std::string::npos;
            auto constrained = data.delta;
            for (const auto & trigger : data.params.grammar_triggers) {
                auto pos = constrained.find(trigger.word);
                if (pos == std::string::npos) {
                    continue;
                }
                if (pos > 0 && trigger.at_start) {
                    fprintf(stderr, "Trigger %s not at start of message, skipping:\n\n%s\n\n", trigger.word.c_str(), constrained.c_str());
                    continue;
                }
                if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
                    earliest_trigger_pos = pos;
                }
            }
            auto grammar_triggered = false;
            if (earliest_trigger_pos != std::string::npos) {
                constrained = constrained.substr(earliest_trigger_pos);
                grammar_triggered = true;
            }
            if (data.params.grammar_lazy) {
                assert_equals(expect_grammar_triggered, grammar_triggered);
            }
            if (grammar_triggered && test_grammar_if_triggered && !match_string(constrained, grammar.get())) {
                throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
                                         "\n\nGrammar: " + data.params.grammar);
            }
        }
    }
}
static void test_template_output_parsers() {
    json message_user {
        { "role",    "user" },
        { "content", "Hey there!" },
    };
    json message_assist {
        { "role",    "assistant" },
        { "content", "Hello, world!\nWhat's up?" },
    };
    json message_assist_thoughts_unparsed_think {
        { "role",    "assistant" },
        { "content", "<think>I'm thinking</think>Hello, world!\nWhat's up?" },
    };
    json message_assist_thoughts_unparsed_r7b {
        { "role",    "assistant" },
        { "content", "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?" },
    };
    json message_assist_thoughts {
        { "role",              "assistant" },
        { "content",           "Hello, world!\nWhat's up?" },
        { "reasoning_content", "I'm thinking" },
    };
    json tool_calls = json::array({{
        { "type", "function" },
        { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
    }});
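
    // Assistant tool-call fixtures. The variants below wrap the same
    // special_function call with ids, call indices, or reasoning content.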
    json message_assist_call {
        { "role",       "assistant" },
        { "content",    {} },
        { "tool_calls", tool_calls }
    };
    json message_assist_call_thoughts = {
        { "role",              "assistant" },
        { "content",           nullptr },
        { "reasoning_content", "I'm\nthinking" },
        { "tool_calls", {
            {
                { "type", "function" },
                { "function", {
                    { "name", "special_function" },
                    { "arguments", "{\"arg1\": 1}" },
                }},
            },
        }},
    };
    json message_assist_call_thoughts_unparsed = {
        { "role",    "assistant" },
        { "content", "<think>I'm\nthinking</think>" },
        { "tool_calls", {
            {
                { "type", "function" },
                { "function", {
                    { "name", "special_function" },
                    { "arguments", "{\"arg1\": 1}" },
                }},
            },
        }},
    };
    json message_assist_call_id {
        { "role",       "assistant" },
        { "content",    {} },
        { "tool_calls", {
            {
                { "type", "function" },
                { "function", {
                    { "name", "special_function" },
                    { "arguments", "{\"arg1\": 1}" },
                }},
                { "id", "123456789" },
            },
        }},
    };
    json message_assist_call_idx {
        { "role",       "assistant" },
        { "content",    {} },
        { "tool_calls", {
            {
                { "type", "function" },
                { "function", {
                    { "name", "special_function" },
                    { "arguments", "{\"arg1\": 1}" },
                }},
                // Index of the tool call in the tool_calls array
                { "id", "0" },
            },
        }},
    };
    json message_assist_call_tool_plan_idx = message_assist_call_idx;
    message_assist_call_tool_plan_idx["tool_plan"] = "I'm thinking";
    auto python_message_assist_call = json{
        { "role",    "assistant" },
        { "content", {} },
        { "tool_calls", json{ {
            { "type", "function" },
            { "function",
              {
                  { "name", "python" },
                  { "arguments",
                    {
                        { "code", "print('hey')" },
                    } },
              } },
        } } }
    };
    auto code_interpreter_message_assist_call = json{
        { "role",    "assistant" },
        { "content", {} },
        { "tool_calls", json{ {
            { "type", "function" },
            { "function",
              {
                  { "name", "code_interpreter" },
                  { "arguments",
                    {
                        { "code", "print('hey')" },
                    } },
              } },
        } } }
    };
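
    // Shared template inputs: the same user message with different tool sets
    // and with reasoning extraction toggled on or off.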
    common_chat_inputs inputs_no_tools;
    inputs_no_tools.messages          = json::array({message_user});
    inputs_no_tools.extract_reasoning = false;

    common_chat_inputs inputs_no_tools_think;
    inputs_no_tools_think.messages          = json::array({message_user});
    inputs_no_tools_think.extract_reasoning = true;

    common_chat_inputs inputs_tools;
    inputs_tools.messages          = json::array({message_user});
    inputs_tools.tools             = json::array({special_function_tool});
    inputs_tools.extract_reasoning = false;

    common_chat_inputs inputs_tools_think;
    inputs_tools_think.messages          = json::array({message_user});
    inputs_tools_think.tools             = json::array({special_function_tool});
    inputs_tools_think.extract_reasoning = true;

    common_chat_inputs inputs_tools_builtin;
    inputs_tools_builtin.messages          = json::array({message_user});
    inputs_tools_builtin.tools             = json::array({python_tool});
    inputs_tools_builtin.extract_reasoning = false;
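
    // Each scope below exercises one template family: format detection,
    // content-only parsing, and tool-call delta round-trips.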
    {
        // Not supported yet
        const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "<s>", "</s>");
        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
    }
    {
        const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "<s>", "</s>");
        std::vector<std::string>   end_tokens{ "<|END_OF_TURN_TOKEN|>" };

        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B,                   common_chat_params_init(tmpl, inputs_no_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B,                   common_chat_params_init(tmpl, inputs_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format);

        assert_msg_equals(msg_from_json(message_assist),
                          common_chat_parse(
                              "Hello, world!\nWhat's up?",
                              COMMON_CHAT_FORMAT_COMMAND_R7B));
        assert_msg_equals(msg_from_json(message_assist),
                          common_chat_parse(
                              "Hello, world!\nWhat's up?<|END_RESPONSE|>",
                              COMMON_CHAT_FORMAT_COMMAND_R7B));
        assert_msg_equals(msg_from_json(message_assist),
                          common_chat_parse(
                              "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
                              COMMON_CHAT_FORMAT_COMMAND_R7B));
        assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b),
                          common_chat_parse(
                              "<|START_THINKING|>I'm thinking<|END_THINKING|>"
                              "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
                              COMMON_CHAT_FORMAT_COMMAND_R7B));
        assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b),
                          common_chat_parse(
                              "<|START_THINKING|>I'm thinking<|END_THINKING|>"
                              "Hello, world!\nWhat's up?<|END_RESPONSE|>",
                              COMMON_CHAT_FORMAT_COMMAND_R7B));
        assert_msg_equals(msg_from_json(message_assist_thoughts),
                          common_chat_parse(
                              "<|START_THINKING|>I'm thinking<|END_THINKING|>"
                              "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
                              COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING));

        test_template(tmpl, end_tokens, message_assist_call_idx, tools,
                      "<|START_THINKING|><|END_THINKING|>"
                      "<|START_ACTION|>[\n"
                      "    {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
                      "]<|END_ACTION|>");
        test_template(tmpl, end_tokens, message_assist_call_tool_plan_idx, tools,
                      "<|START_THINKING|>I'm thinking<|END_THINKING|>"
                      "<|START_ACTION|>[\n"
                      "    {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
                      "]<|END_ACTION|>",
                      /* expect_grammar_triggered= */ true,
                      /* test_grammar_if_triggered= */ true,
                      /* think= */ true);
        test_template(tmpl, end_tokens, message_assist, tools,
                      "<|START_RESPONSE|>Hello, world!\n"
                      "What's up?<|END_RESPONSE|>",
                      /* expect_grammar_triggered= */ false);
    }
    {
        const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
        std::vector<std::string>   end_tokens{ "<end_of_turn>" };

        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_GENERIC,      common_chat_params_init(tmpl, inputs_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_GENERIC,
                      common_chat_params_init(
                          common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"),
                                               "<s>", "</s>"),
                          inputs_tools)
                          .format);

        // The generic tool-call format doesn't generate / parse content-only messages symmetrically.
        assert_msg_equals(msg_from_json(message_assist),
                          common_chat_parse("{\n"
                                            "  \"response\": \"Hello, world!\\nWhat's up?\"\n"
                                            "}",
                                            common_chat_params_init(tmpl, inputs_tools).format));
        test_template(tmpl, end_tokens, message_assist_call_id, tools,
                      "{\n"
                      "  \"tool_calls\": [\n"
                      "    {\n"
                      "      \"name\": \"special_function\",\n"
                      "      \"arguments\": {\n"
                      "        \"arg1\": 1\n"
                      "      },\n"
                      "      \"id\": \"123456789\"\n"
                      "    }\n"
                      "  ]\n"
                      "}");
    }
    {
        const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string>   end_tokens{ "</s>" };

        assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_template(
            tmpl, end_tokens, message_assist_call_id, tools,
            "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
    }
    {
        const common_chat_template tmpl(
            read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
        std::vector<std::string> end_tokens{ "<|im_end|>" };

        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(tmpl, inputs_tools).format);
        assert_equals(
            COMMON_CHAT_FORMAT_HERMES_2_PRO,
            common_chat_params_init(
                common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"),
                                     "<s>", "</s>"),
                inputs_tools)
                .format);
        assert_equals(
            COMMON_CHAT_FORMAT_HERMES_2_PRO,
            common_chat_params_init(
                common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"),
                inputs_tools)
                .format);

        test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, message_assist_call, tools,
                      "<tool_call>\n"
                      "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
                      "</tool_call>");
        test_template(tmpl, end_tokens, python_message_assist_call, tools,
                      "<tool_call>\n"
                      "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
                      "</tool_call>");
    }
    {
        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string>   end_tokens{ "<|eom_id|>", "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
                      common_chat_params_init(tmpl, inputs_tools_builtin).format);
        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
                      common_chat_params_init(
                          common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"),
                                               "<s>", "</s>"),
                          inputs_tools_builtin)
                          .format);

        // test_template(tmpl, end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, code_interpreter_message_assist_call, llama_3_1_tools,
                      "<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
        test_template(tmpl, end_tokens, python_message_assist_call, tools,
                      "<|python_tag|>python.call(code=\"print('hey')\")");
        test_template(tmpl, end_tokens, message_assist_call, tools,
                      "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
    }
    {
        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string>   end_tokens{ "<|eom_id|>", "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, message_assist_call, tools,
                      "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
    }
    {
        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string>   end_tokens{ "<|eom_id|>", "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
                      common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, message_assist_call, tools,
                      "<function=special_function>{\"arg1\": 1}</function>");
    }
    {
        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string>   end_tokens{ "<|eom_id|>", "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, message_assist, {},
                      "all\n"
                      "Hello, world!\n"
                      "What's up?",
                      /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, message_assist_call, tools,
                      "special_function\n"
                      "{\"arg1\": 1}");
    }
    {
        const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string>   end_tokens{ "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, message_assist_call, tools,
                      " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
    }
    {
        // Original DeepSeek R1 template. Leaves <|tool▁calls▁begin|> and others unclosed. Our logic fixes the prompt.
        const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"),
                                        "<s>", "</s>");
        std::vector<std::string>   end_tokens{ "<|end▁of▁sentence|>" };

        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,                   common_chat_params_init(tmpl, inputs_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format);

        test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think),
                          common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
                                            COMMON_CHAT_FORMAT_DEEPSEEK_R1));
        assert_msg_equals(msg_from_json(message_assist_thoughts),
                          common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
                                            COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
        assert_msg_equals(msg_from_json(message_assist_thoughts),
                          // Latest template update (as of 20250209) adds a trailing <think>\n if add_generation_prompt is true.
                          common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
                                            COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
        // test_template(tmpl, end_tokens, message_assist_call, tools,
        //               "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
        //               "```json\n"
        //               "{\"arg1\": 1}\n"
        //               // Look what's not here: <|tool▁calls▁end|> (also missing the <|end▁of▁sentence|>, but that is removed lazily by the test's delta logic)
        //               "```<|tool▁call▁end|>",
        //               /* expect_grammar_triggered= */ true,
        //               /* test_grammar_if_triggered= */ false);
    }
    {
        // Replacement DeepSeek R1 template. Makes the Distill Qwen 7B/32B models happy to call tools and all.
        const common_chat_template tmpl(read_file("models/templates/llama-cpp-deepseek-r1.jinja"),
                                        "<s>", "</s>");
        std::vector<std::string>   end_tokens{ "<|end▁of▁sentence|>" };

        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,                   common_chat_params_init(tmpl, inputs_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format);

        test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think),
                          common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
                                            COMMON_CHAT_FORMAT_DEEPSEEK_R1));
        assert_msg_equals(msg_from_json(message_assist_thoughts),
                          common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
                                            COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));

        assert_msg_equals(msg_from_json(message_assist_call_thoughts_unparsed),
                          common_chat_parse(
                              "<think>I'm\nthinking</think>\n\n"
                              "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
                              "```json\n"
                              "{\"arg1\": 1}\n"
                              "```<|tool▁call▁end|><|tool▁calls▁end|>",
                              COMMON_CHAT_FORMAT_DEEPSEEK_R1));
        assert_msg_equals(msg_from_json(message_assist_call_thoughts),
                          common_chat_parse(
                              "<think>I'm\nthinking</think>\n\n"
                              "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
                              "```json\n"
                              "{\"arg1\": 1}\n"
                              "```<|tool▁call▁end|><|tool▁calls▁end|>",
                              COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
        test_template(tmpl, end_tokens, message_assist_call, tools,
                      "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
                      "```json\n"
                      "{\"arg1\": 1}\n"
                      "```<|tool▁call▁end|><|tool▁calls▁end|>");
    }
}
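
// With template paths as arguments (non-Windows), prints a Markdown table
// mapping each .jinja file to its detected chat format; with no arguments,
// runs the test suite.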
int main(int argc, char ** argv) {
#ifndef _WIN32
    if (argc > 1) {
        common_chat_inputs inputs;
        inputs.messages = {
            { { "role", "user" }, { "content", "Hey" } }
        };
        inputs.tools = json::array({ special_function_tool });

        std::cout << "| Template | Format |\n";
        std::cout << "|----------|--------|\n";

        for (int i = 1; i < argc; i++) {
            try {
                std::string path = argv[i];
                if (path.rfind(".jinja") != path.size() - 6) {
                    std::cerr << "Skipping non-jinja file: " << path << std::endl;
                    continue;
                }
                common_chat_template tmpl(read_file(path), "", "");
                auto parts  = string_split(path, "/");
                auto name   = parts[parts.size() - 1];
                auto format = common_chat_format_name(common_chat_params_init(tmpl, inputs).format);
                std::cout << "| " << name << " | " << format << " |\n";
            } catch (const std::exception & e) {
                std::cerr << "Failed to process " << argv[i] << ": " << e.what() << std::endl;
            }
        }
    } else
#endif
    {
        test_template_output_parsers();
        std::cout << "\n[chat] All tests passed!" << std::endl;
    }
    return 0;
}