// Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
//
// Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
// e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
//
//   cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
//
#include <fstream>
#include <iostream>
#include <json.hpp>
#include <string>

#include "chat.h"
#include "llama-grammar.h"
#include "unicode.h"

using json = nlohmann::ordered_json;
template <class T> static void assert_equals(const T & expected, const T & actual) {
    if (expected != actual) {
        std::cerr << "Expected: " << expected << std::endl;
        std::cerr << "Actual: " << actual << std::endl;
        std::cerr << std::flush;
        throw std::runtime_error("Test failed");
    }
}
static std::string read_file(const std::string & path) {
    std::cerr << "# Reading: " << path << '\n' << std::flush;
    std::ifstream fs(path, std::ios_base::binary);
    if (!fs.is_open()) {
        fs = std::ifstream("../" + path, std::ios_base::binary);
        if (!fs.is_open()) {
            throw std::runtime_error("Failed to open file: " + path);
        }
    }
    fs.seekg(0, std::ios_base::end);
    auto size = fs.tellg();
    fs.seekg(0);
    std::string out;
    out.resize(static_cast<size_t>(size));
    fs.read(out.data(), static_cast<std::streamsize>(size));
    return out;
}
static common_chat_templates_ptr read_templates(const std::string & path) {
    return common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, read_file(path)));
}

static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) {
    return std::unique_ptr<llama_grammar>(
        llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0));
}
// TODO: extract to common helper (copied from test-grammar-integration.cpp)
static bool match_string(const std::string & input, llama_grammar * grammar) {
    const auto cpts = unicode_cpts_from_utf8(input);

    auto & stacks_cur = llama_grammar_get_stacks(grammar);

    for (const auto & cpt : cpts) {
        llama_grammar_accept(grammar, cpt);

        if (stacks_cur.empty()) {
            // no stacks means that the grammar failed to match at this point
            return false;
        }
    }

    if (std::any_of(stacks_cur.begin(), stacks_cur.end(), [](const auto & stack) { return stack.empty(); })) {
        // An empty stack means that the grammar has been completed
        return true;
    }

    return false;
}
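
// Usage sketch (illustrative, not part of the test matrix): a trivial GBNF grammar
// that accepts exactly the string "hi". Running out of stacks mid-input means the
// match failed; a stack that empties out means the root rule completed.
//
//   auto grammar = build_grammar("root ::= \"hi\"");
//   match_string("hi", grammar.get()); // true: the root rule completed
//   match_string("ho", grammar.get()); // false: no stack can accept 'o'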
static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
    assert_equals(expected.role, actual.role);
    assert_equals(expected.content, actual.content);
    assert_equals(expected.content_parts.size(), actual.content_parts.size());
    for (size_t i = 0; i < expected.content_parts.size(); i++) {
        const auto & expected_part = expected.content_parts[i];
        const auto & actual_part   = actual.content_parts[i];
        assert_equals(expected_part.type, actual_part.type);
        assert_equals(expected_part.text, actual_part.text);
    }
    assert_equals(expected.reasoning_content, actual.reasoning_content);
    assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
    for (size_t i = 0; i < expected.tool_calls.size(); i++) {
        const auto & expected_tool_call = expected.tool_calls[i];
        const auto & actual_tool_call   = actual.tool_calls[i];
        assert_equals(expected_tool_call.name, actual_tool_call.name);
        assert_equals(json::parse(expected_tool_call.arguments).dump(), json::parse(actual_tool_call.arguments).dump());
        assert_equals(expected_tool_call.id, actual_tool_call.id);
    }
}
common_chat_tool special_function_tool {
    /* .name = */ "special_function",
    /* .description = */ "I'm special",
    /* .parameters = */ R"({
        "type": "object",
        "properties": {
            "arg1": {
                "type": "integer",
                "description": "The arg."
            }
        },
        "required": ["arg1"]
    })",
};
common_chat_tool python_tool {
    /* .name = */ "python",
    /* .description = */ "an ipython interpreter",
    /* .parameters = */ R"({
        "type": "object",
        "properties": {
            "code": {
                "type": "string",
                "description": "Python code to execute."
            }
        },
        "required": ["code"]
    })",
};
common_chat_tool code_interpreter_tool {
    /* .name = */ "code_interpreter",
    /* .description = */ "an ipython interpreter",
    /* .parameters = */ R"({
        "type": "object",
        "properties": {
            "code": {
                "type": "string",
                "description": "Python code to execute."
            }
        },
        "required": ["code"]
    })",
};

std::vector<common_chat_tool> tools           { special_function_tool, python_tool };
std::vector<common_chat_tool> llama_3_1_tools { special_function_tool, code_interpreter_tool };
struct delta_data {
    std::string        delta;
    common_chat_params params;
};

static delta_data init_delta(const struct common_chat_templates * tmpls, const std::vector<std::string> & end_tokens,
                             const common_chat_msg & user_message,
                             const common_chat_msg & delta_message,
                             const std::vector<common_chat_tool> & tools,
                             const common_chat_tool_choice & tool_choice,
                             bool think = false) {
    common_chat_templates_inputs inputs;
    inputs.parallel_tool_calls = true;
    inputs.messages.push_back(user_message);
    inputs.tools = tools;
    inputs.tool_choice = tool_choice;
    inputs.extract_reasoning = think;
    auto params_prefix = common_chat_templates_apply(tmpls, inputs);

    inputs.messages.push_back(delta_message);
    inputs.add_generation_prompt = false;
    auto params_full = common_chat_templates_apply(tmpls, inputs);

    std::string prefix = params_prefix.prompt;
    std::string full   = params_full.prompt;

    if (full == prefix) {
        throw std::runtime_error("Full message is the same as the prefix");
    }

    size_t common_prefix_length = 0;
    for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
        if (prefix[i] != full[i]) {
            break;
        }
        if (prefix[i] == '<') {
            // DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
            // but it removes thinking tags for past messages.
            // The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, so we avoid consuming the leading <.
            continue;
        }
        common_prefix_length = i + 1;
    }
    auto delta = full.substr(common_prefix_length);

    // Strip end tokens
    for (const auto & end_token : end_tokens) {
        // rfind to find the last occurrence
        auto pos = delta.rfind(end_token);
        if (pos != std::string::npos) {
            delta = delta.substr(0, pos);
            break;
        }
    }
    return { delta, params_full };
}
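
// Worked example of the prefix-diff logic above (illustrative): if the prefix prompt
// renders as "...<|im_start|>assistant\n" and the full prompt renders as
// "...<|im_start|>assistant\nHello!<|im_end|>", the delta is "Hello!<|im_end|>" and
// the "<|im_end|>" end token is then stripped off. The '<' special case keeps a
// trailing "<think>" out of the common prefix so that templates like DeepSeek R1,
// which drop thinking tags from past turns, still diff cleanly.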
/*
    Applies the template to 1 user message with add_generation_prompt=true, then again with
    the test message appended and add_generation_prompt=false, computes the diff, strips any
    end tokens, and parses the result with the grammar, checking that the parsed message
    matches the test_message.
*/
static void test_templates(const struct common_chat_templates * tmpls, const std::vector<std::string> & end_tokens,
                           const common_chat_msg & test_message,
                           const std::vector<common_chat_tool> & tools = {},
                           const std::string & expected_delta = "",
                           bool expect_grammar_triggered = true,
                           bool test_grammar_if_triggered = true,
                           bool think = false) {
    common_chat_msg user_message;
    user_message.role = "user";
    user_message.content = "Hello, world!";

    for (const auto & tool_choice : std::vector<common_chat_tool_choice> {COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED}) {
        auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice, think);
        if (!expected_delta.empty()) {
            assert_equals(expected_delta, data.delta);
        }

        if (expect_grammar_triggered) {
            const auto msg = common_chat_parse(data.delta, data.params.format);
            assert_msg_equals(test_message, msg);
        }

        if (!test_message.tool_calls.empty()) {
            GGML_ASSERT(!data.params.grammar.empty());
        }

        if (!data.params.grammar.empty()) {
            auto grammar = build_grammar(data.params.grammar);
            if (!grammar) {
                throw std::runtime_error("Failed to build grammar");
            }

            auto earliest_trigger_pos = std::string::npos;
            auto constrained = data.delta;
            for (const auto & trigger : data.params.grammar_triggers) {
                auto pos = constrained.find(trigger.word);
                if (pos == std::string::npos) {
                    continue;
                }
                if (pos > 0 && trigger.at_start) {
                    fprintf(stderr, "Trigger %s not at start of message, skipping:\n\n%s\n\n", trigger.word.c_str(), constrained.c_str());
                    continue;
                }
                if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
                    earliest_trigger_pos = pos;
                }
            }
            auto grammar_triggered = false;
            if (earliest_trigger_pos != std::string::npos) {
                constrained = constrained.substr(earliest_trigger_pos);
                grammar_triggered = true;
            }
            if (data.params.grammar_lazy) {
                assert_equals(expect_grammar_triggered, grammar_triggered);
            }
            if (grammar_triggered && test_grammar_if_triggered && !match_string(constrained, grammar.get())) {
                throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
                                         "\n\nGrammar: " + data.params.grammar);
            }
        }
    }
}
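
// Note on lazy grammars (illustrative): some formats only emit tool calls behind a
// trigger word (e.g. Hermes 2 Pro's "<tool_call>"); sampling stays unconstrained
// until the trigger appears, and the grammar only applies from that point on. The
// loop above mirrors this by slicing `constrained` at the earliest trigger position.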
const common_chat_msg message_user {
    "user",
    "Hey there!",
    /* .content_parts = */ {},
    /* .tool_calls = */ {},
    /* .reasoning_content = */ "",
    /* .tool_name = */ "",
    /* .tool_call_id = */ "",
};
const common_chat_msg message_user_parts {
    "user",
    /* .content = */ "",
    /* .content_parts = */ {
        { "text", "Hey" },
        { "text", "there" },
    },
    /* .tool_calls = */ {},
    /* .reasoning_content = */ "",
    /* .tool_name = */ "",
    /* .tool_call_id = */ "",
};
const common_chat_msg message_assist {
    "assistant",
    "Hello, world!\nWhat's up?",
    /* .content_parts = */ {},
    /* .tool_calls = */ {},
    /* .reasoning_content = */ "",
    /* .tool_name = */ "",
    /* .tool_call_id = */ "",
};
const common_chat_msg message_assist_thoughts_unparsed_think {
    "assistant",
    "<think>I'm thinking</think>Hello, world!\nWhat's up?",
    /* .content_parts = */ {},
    /* .tool_calls = */ {},
    /* .reasoning_content = */ "",
    /* .tool_name = */ "",
    /* .tool_call_id = */ "",
};
const common_chat_msg message_assist_thoughts_unparsed_r7b {
    "assistant",
    "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?",
    /* .content_parts = */ {},
    /* .tool_calls = */ {},
    /* .reasoning_content = */ "",
    /* .tool_name = */ "",
    /* .tool_call_id = */ "",
};
const common_chat_msg message_assist_thoughts {
    "assistant",
    "Hello, world!\nWhat's up?",
    /* .content_parts = */ {},
    /* .tool_calls = */ {},
    /* .reasoning_content = */ "I'm thinking",
    /* .tool_name = */ "",
    /* .tool_call_id = */ "",
};
const std::vector<common_chat_tool_call> tool_calls {
    { "special_function", "{\"arg1\": 1}", /* .id = */ "" },
};
const std::vector<common_chat_tool_call> tool_calls_idx {
    { "special_function", "{\"arg1\": 1}", /* .id = */ "0" },
};
const std::vector<common_chat_tool_call> tool_calls_id {
    { "special_function", "{\"arg1\": 1}", /* .id = */ "123456789" },
};
const common_chat_msg message_assist_call {
    "assistant",
    "",
    /* .content_parts = */ {},
    tool_calls,
    /* .reasoning_content = */ "",
    /* .tool_name = */ "",
    /* .tool_call_id = */ "",
};
const common_chat_msg message_assist_call_thoughts = {
    "assistant",
    /* .content = */ "",
    /* .content_parts = */ {},
    tool_calls,
    /* .reasoning_content = */ "I'm\nthinking",
    /* .tool_name = */ "",
    /* .tool_call_id = */ "",
};
const common_chat_msg message_assist_call_thoughts_unparsed = {
    "assistant",
    /* .content = */ "<think>I'm\nthinking</think>",
    /* .content_parts = */ {},
    tool_calls,
    /* .reasoning_content = */ "",
    /* .tool_name = */ "",
    /* .tool_call_id = */ "",
};
const common_chat_msg message_assist_call_id {
    "assistant",
    "",
    /* .content_parts = */ {},
    tool_calls_id,
    /* .reasoning_content = */ "",
    /* .tool_name = */ "",
    /* .tool_call_id = */ "",
};
const common_chat_msg message_assist_call_idx {
    "assistant",
    "",
    /* .content_parts = */ {},
    tool_calls_idx,
    /* .reasoning_content = */ "",
    /* .tool_name = */ "",
    /* .tool_call_id = */ "",
};
const common_chat_msg message_assist_call_python {
    "assistant",
    "",
    /* .content_parts = */ {},
    { { "python", "{\"code\": \"print('hey')\"}", /* .id = */ "" } },
    /* .reasoning_content = */ "",
    /* .tool_name = */ "",
    /* .tool_call_id = */ "",
};
const common_chat_msg message_assist_call_code_interpreter {
    "assistant",
    "",
    /* .content_parts = */ {},
    { { "code_interpreter", "{\"code\": \"print('hey')\"}", /* .id = */ "" } },
    /* .reasoning_content = */ "",
    /* .tool_name = */ "",
    /* .tool_call_id = */ "",
};
static void test_msgs_oaicompat_json_conversion() {
    std::vector<common_chat_msg> msgs{
        message_user,
        message_user_parts,
        message_assist_call,
        message_assist_call_thoughts,
        message_assist_call_thoughts_unparsed,
        message_assist_call_id,
        message_assist_call_idx,
        message_assist_call_python,
        message_assist_call_code_interpreter,
    };
    for (const auto & msg : msgs) {
        auto oai_json = common_chat_msgs_to_json_oaicompat<json>({msg});
        auto msgs2 = common_chat_msgs_parse_oaicompat(oai_json);
        assert_equals((size_t) 1, msgs2.size());
        auto msg2 = msgs2[0];
        assert_msg_equals(msg, msg2);
    }
    assert_equals(
        std::string(
            "[\n"
            "  {\n"
            "    \"role\": \"user\",\n"
            "    \"content\": [\n"
            "      {\n"
            "        \"type\": \"text\",\n"
            "        \"text\": \"Hey\"\n"
            "      },\n"
            "      {\n"
            "        \"type\": \"text\",\n"
            "        \"text\": \"there\"\n"
            "      }\n"
            "    ]\n"
            "  }\n"
            "]"
        ),
        common_chat_msgs_to_json_oaicompat<json>({message_user_parts}).dump(2));
    assert_equals(
        std::string(
            "[\n"
            "  {\n"
            "    \"role\": \"assistant\",\n"
            "    \"content\": null,\n"
            "    \"tool_calls\": [\n"
            "      {\n"
            "        \"type\": \"function\",\n"
            "        \"function\": {\n"
            "          \"name\": \"python\",\n"
            "          \"arguments\": \"{\\\"code\\\": \\\"print('hey')\\\"}\"\n"
            "        }\n"
            "      }\n"
            "    ]\n"
            "  }\n"
            "]"
        ),
        common_chat_msgs_to_json_oaicompat<json>({message_assist_call_python}).dump(2));
}
static void test_tools_oaicompat_json_conversion() {
    std::vector<common_chat_tool> tools{
        special_function_tool,
        python_tool,
        code_interpreter_tool,
    };
    for (const auto & tool : tools) {
        auto oai_json = common_chat_tools_to_json_oaicompat<json>({tool});
        auto tools2 = common_chat_tools_parse_oaicompat(oai_json);
        assert_equals((size_t) 1, tools2.size());
        auto tool2 = tools2[0];
        assert_equals(tool.name, tool2.name);
        assert_equals(tool.description, tool2.description);
        assert_equals(json::parse(tool.parameters).dump(2), json::parse(tool2.parameters).dump(2));
    }
    assert_equals(
        std::string(
            "[\n"
            "  {\n"
            "    \"type\": \"function\",\n"
            "    \"function\": {\n"
            "      \"name\": \"special_function\",\n"
            "      \"description\": \"I'm special\",\n"
            "      \"parameters\": {\n"
            "        \"type\": \"object\",\n"
            "        \"properties\": {\n"
            "          \"arg1\": {\n"
            "            \"type\": \"integer\",\n"
            "            \"description\": \"The arg.\"\n"
            "          }\n"
            "        },\n"
            "        \"required\": [\n"
            "          \"arg1\"\n"
            "        ]\n"
            "      }\n"
            "    }\n"
            "  }\n"
            "]"
        ),
        common_chat_tools_to_json_oaicompat<json>({special_function_tool}).dump(2));
}
static void test_template_output_parsers() {
    common_chat_templates_inputs inputs_no_tools;
    inputs_no_tools.messages = {message_user};
    inputs_no_tools.extract_reasoning = false;

    common_chat_templates_inputs inputs_no_tools_think;
    inputs_no_tools_think.messages = {message_user};
    inputs_no_tools_think.extract_reasoning = true;

    common_chat_templates_inputs inputs_tools;
    inputs_tools.messages = {message_user};
    inputs_tools.tools = {special_function_tool};
    inputs_tools.extract_reasoning = false;

    common_chat_templates_inputs inputs_tools_think;
    inputs_tools_think.messages = {message_user};
    inputs_tools_think.tools = {special_function_tool};
    inputs_tools_think.extract_reasoning = true;

    common_chat_templates_inputs inputs_tools_builtin;
    inputs_tools_builtin.messages = {message_user};
    inputs_tools_builtin.tools = {python_tool};
    inputs_tools_builtin.extract_reasoning = false;
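
    // Each block below loads one template family, asserts the chat format it is
    // detected as, and round-trips representative messages through template
    // application + parsing (and, where applicable, the generated grammar).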
    {
        // Not supported yet
        auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja");
        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
    }
    {
        auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja");
        std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" };

        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);

        assert_msg_equals(message_assist,
            common_chat_parse(
                "Hello, world!\nWhat's up?",
                COMMON_CHAT_FORMAT_COMMAND_R7B));
        assert_msg_equals(message_assist,
            common_chat_parse(
                "Hello, world!\nWhat's up?<|END_RESPONSE|>",
                COMMON_CHAT_FORMAT_COMMAND_R7B));
        assert_msg_equals(message_assist,
            common_chat_parse(
                "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
                COMMON_CHAT_FORMAT_COMMAND_R7B));
        assert_msg_equals(message_assist_thoughts_unparsed_r7b,
            common_chat_parse(
                "<|START_THINKING|>I'm thinking<|END_THINKING|>"
                "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
                COMMON_CHAT_FORMAT_COMMAND_R7B));
        assert_msg_equals(message_assist_thoughts_unparsed_r7b,
            common_chat_parse(
                "<|START_THINKING|>I'm thinking<|END_THINKING|>"
                "Hello, world!\nWhat's up?<|END_RESPONSE|>",
                COMMON_CHAT_FORMAT_COMMAND_R7B));
        assert_msg_equals(message_assist_thoughts,
            common_chat_parse(
                "<|START_THINKING|>I'm thinking<|END_THINKING|>"
                "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
                COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING));

        test_templates(tmpls.get(), end_tokens, message_assist_call_idx, tools,
                       "<|START_THINKING|><|END_THINKING|>"
                       "<|START_ACTION|>[\n"
                       "    {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
                       "]<|END_ACTION|>");
        test_templates(tmpls.get(), end_tokens, message_assist, tools,
                       "<|START_RESPONSE|>Hello, world!\n"
                       "What's up?<|END_RESPONSE|>",
                       /* expect_grammar_triggered= */ false);
    }
    {
        auto tmpls = read_templates("models/templates/google-gemma-2-2b-it.jinja");
        std::vector<std::string> end_tokens{ "<end_of_turn>" };

        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_GENERIC,
            common_chat_templates_apply(
                read_templates("models/templates/microsoft-Phi-3.5-mini-instruct.jinja").get(),
                inputs_tools)
                .format);

        // The generic tool call handler doesn't generate / parse content-only messages symmetrically.
        assert_msg_equals(message_assist,
            common_chat_parse("{\n"
                              "  \"response\": \"Hello, world!\\nWhat's up?\"\n"
                              "}",
                              common_chat_templates_apply(tmpls.get(), inputs_tools).format));
        test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools,
                       "{\n"
                       "  \"tool_calls\": [\n"
                       "    {\n"
                       "      \"name\": \"special_function\",\n"
                       "      \"arguments\": {\n"
                       "        \"arg1\": 1\n"
                       "      },\n"
                       "      \"id\": \"123456789\"\n"
                       "    }\n"
                       "  ]\n"
                       "}");
    }
    {
        auto tmpls = read_templates("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja");
        std::vector<std::string> end_tokens{ "</s>" };

        assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);

        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_templates(
            tmpls.get(), end_tokens, message_assist_call_id, tools,
            "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
    }
    {
        auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja");
        std::vector<std::string> end_tokens{ "<|im_end|>" };

        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
        assert_equals(
            COMMON_CHAT_FORMAT_HERMES_2_PRO,
            common_chat_templates_apply(
                read_templates("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja").get(),
                inputs_tools)
                .format);
        assert_equals(
            COMMON_CHAT_FORMAT_HERMES_2_PRO,
            common_chat_templates_apply(
                read_templates("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja").get(),
                inputs_tools)
                .format);

        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "<tool_call>\n"
                       "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
                       "</tool_call>");
        test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools,
                       "<tool_call>\n"
                       "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
                       "</tool_call>");
    }
    {
        auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja");
        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
                      common_chat_templates_apply(tmpls.get(), inputs_tools_builtin).format);
        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
                      common_chat_templates_apply(
                          read_templates("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja").get(),
                          inputs_tools_builtin)
                          .format);

        // test_templates(tmpls.get(), end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false);
        test_templates(tmpls.get(), end_tokens, message_assist_call_code_interpreter, llama_3_1_tools,
                       "<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
        test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools,
                       "<|python_tag|>python.call(code=\"print('hey')\")");
        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
    }
    {
        auto tmpls = read_templates("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja");
        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);

        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
    }
    {
        auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.1.jinja");
        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
                      common_chat_templates_apply(tmpls.get(), inputs_tools).format);

        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "<function=special_function>{\"arg1\": 1}</function>");
    }
    {
        auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.2.jinja");
        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);

        test_templates(tmpls.get(), end_tokens, message_assist, {},
                       "all\n"
                       "Hello, world!\n"
                       "What's up?",
                       /* expect_grammar_triggered= */ false);
        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "special_function\n"
                       "{\"arg1\": 1}");
    }
    {
        auto tmpls = read_templates("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja");
        std::vector<std::string> end_tokens{ "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);

        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
    }
    {
        // Original DeepSeek R1 template. Leaves <|tool▁calls▁begin|> and others unclosed. Our logic fixes the prompt.
        auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja");
        std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };

        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);

        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        assert_msg_equals(message_assist_thoughts_unparsed_think,
            common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
                              COMMON_CHAT_FORMAT_DEEPSEEK_R1));
        assert_msg_equals(message_assist_thoughts,
            common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
                              COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
        assert_msg_equals(message_assist_thoughts,
            // Latest template update (as of 20250209) adds a trailing <think>\n if add_generation_prompt is true.
            common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
                              COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));

        // test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
        //                "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
        //                "```json\n"
        //                "{\"arg1\": 1}\n"
        //                // Look what's not here: <|tool▁calls▁end|> (also missing the <|end▁of▁sentence|>, but that is removed lazily by the test's delta logic)
        //                "```<|tool▁call▁end|>",
        //                /* expect_grammar_triggered= */ true,
        //                /* test_grammar_if_triggered= */ false);
    }
    {
        // Replacement DeepSeek R1 template. Makes the Distill Qwen 7B/32B models happy to call tools and all.
        auto tmpls = read_templates("models/templates/llama-cpp-deepseek-r1.jinja");
        std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };

        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);

        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        assert_msg_equals(message_assist_thoughts_unparsed_think,
            common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
                              COMMON_CHAT_FORMAT_DEEPSEEK_R1));
        assert_msg_equals(message_assist_thoughts,
            common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
                              COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));

        assert_msg_equals(message_assist_call_thoughts_unparsed,
            common_chat_parse(
                "<think>I'm\nthinking</think>\n\n"
                "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
                "```json\n"
                "{\"arg1\": 1}\n"
                "```<|tool▁call▁end|><|tool▁calls▁end|>",
                COMMON_CHAT_FORMAT_DEEPSEEK_R1));
        assert_msg_equals(message_assist_call_thoughts,
            common_chat_parse(
                "<think>I'm\nthinking</think>\n\n"
                "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
                "```json\n"
                "{\"arg1\": 1}\n"
                "```<|tool▁call▁end|><|tool▁calls▁end|>",
                COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
                       "```json\n"
                       "{\"arg1\": 1}\n"
                       "```<|tool▁call▁end|><|tool▁calls▁end|>");
    }
}
int main(int argc, char ** argv) {
    try {
#ifndef _WIN32
        if (argc > 1) {
            common_chat_templates_inputs inputs;
            common_chat_msg msg;
            msg.role = "user";
            msg.content = "Hey";
            inputs.messages = {msg};
            inputs.tools = { special_function_tool };

            std::cout << "| Template | Format |\n";
            std::cout << "|----------|--------|\n";
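
            // Example output row (illustrative; the format name comes from
            // common_chat_format_name, e.g.):
            //   | Qwen-Qwen2.5-7B-Instruct.jinja | Hermes 2 Pro |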
            for (int i = 1; i < argc; i++) {
                try {
                    std::string path = argv[i];
                    if (path.rfind(".jinja") != path.size() - 6) {
                        std::cerr << "Skipping non-jinja file: " << path << '\n';
                        continue;
                    }
                    auto tmpls = read_templates(path);
                    auto parts = string_split(path, "/");
                    auto name = parts[parts.size() - 1];
                    auto format = common_chat_format_name(common_chat_templates_apply(tmpls.get(), inputs).format);
                    std::cout << "| " << name << " | " << format << " |\n";
                } catch (const std::exception & e) {
                    std::cerr << "Failed to process " << argv[i] << ": " << e.what() << '\n';
                }
            }
        } else
#endif
        {
            test_msgs_oaicompat_json_conversion();
            test_tools_oaicompat_json_conversion();
            test_template_output_parsers();
            std::cout << "\n[chat] All tests passed!" << '\n';
        }
        return 0;
    } catch (const std::exception & e) {
        std::cerr << "Error: " << e.what() << '\n';
        return 1;
    }
}