// Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
//
// Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
// e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
//
//   cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
//
#include <fstream>
#include <iostream>
#include <json.hpp>
#include <string>

#include "chat-template.hpp"
#include "chat.hpp"
#include "llama-grammar.h"
#include "unicode.h"

using json = nlohmann::ordered_json;

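// Builds an assistant common_chat_msg from an OpenAI-style JSON message, copying its content and tool calls.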
static common_chat_msg msg_from_json(const json & message) {
    common_chat_msg ret{
        "assistant",
        "",
        {},
    };
    if (message.contains("content") && !message.at("content").is_null()) {
        ret.content = message.at("content").get<std::string>();
    }
    auto has_tool_calls = message.contains("tool_calls");
    if (has_tool_calls) {
        for (const auto & tc : message.at("tool_calls")) {
            const auto & arguments = tc.at("function").at("arguments");
            ret.tool_calls.push_back({
                tc.at("function").at("name").get<std::string>(),
                arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
                tc.contains("id") ? tc.at("id").get<std::string>() : "",
            });
        }
    }
    return ret;
}

template <class T> static void assert_equals(const T & expected, const T & actual) {
    if (expected != actual) {
        std::cerr << "Expected: " << expected << std::endl;
        std::cerr << "Actual: " << actual << std::endl;
        std::cerr << std::flush;
        throw std::runtime_error("Test failed");
    }
}

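// Reads a whole file into a string, falling back to "../<path>" (e.g. when run from a build subdirectory).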
static std::string read_file(const std::string & path) {
    std::cerr << "# Reading: " << path << std::endl << std::flush;
    std::ifstream fs(path, std::ios_base::binary);
    if (!fs.is_open()) {
        fs = std::ifstream("../" + path, std::ios_base::binary);
        if (!fs.is_open()) {
            throw std::runtime_error("Failed to open file: " + path);
        }
    }
    fs.seekg(0, std::ios_base::end);
    auto size = fs.tellg();
    fs.seekg(0);
    std::string out;
    out.resize(static_cast<size_t>(size));
    fs.read(&out[0], static_cast<std::streamsize>(size));
    return out;
}

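// Builds a llama_grammar from a GBNF grammar string, using "root" as the start rule.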
static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) {
    return std::unique_ptr<llama_grammar>(
        llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0));
}

// TODO: extract to common helper (copied from test-grammar-integration.cpp)
static bool match_string(const std::string & input, llama_grammar * grammar) {
    const auto cpts = unicode_cpts_from_utf8(input);

    auto & stacks_cur = llama_grammar_get_stacks(grammar);

    for (const auto & cpt : cpts) {
        llama_grammar_accept(grammar, cpt);

        if (stacks_cur.empty()) {
            // no stacks means that the grammar failed to match at this point
            return false;
        }
    }

    for (const auto & stack : stacks_cur) {
        if (stack.empty()) {
            // An empty stack means that the grammar has been completed
            return true;
        }
    }

    return false;
}

// Dumps `{"a": 1}` as `"{\"a\": 1}"`, unlike nlohmann::json::dump which would dump it as `"{\"a\":1}"`.
static std::string dump(const json & j) {
    return minja::Value(j).dump(-1, /* to_json= */ true);
}

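// Asserts that two chat messages match: same role, content and tool calls (tool-call arguments are
// compared as parsed JSON, so formatting differences are ignored).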
static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
    assert_equals(expected.role, actual.role);
    assert_equals(expected.content, actual.content);
    assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
    for (size_t i = 0; i < expected.tool_calls.size(); i++) {
        const auto & expected_tool_call = expected.tool_calls[i];
        const auto & actual_tool_call   = actual.tool_calls[i];
        assert_equals(expected_tool_call.name, actual_tool_call.name);
        assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments)));
        assert_equals(expected_tool_call.id, actual_tool_call.id);
    }
}

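// Tool definitions (OpenAI-style function schemas) shared by the tests below.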
const auto special_function_tool = json::parse(R"({
  "type": "function",
  "function": {
    "name": "special_function",
    "description": "I'm special",
    "parameters": {
      "type": "object",
      "properties": {
        "arg1": {
          "type": "integer",
          "description": "The arg."
        }
      },
      "required": ["arg1"]
    }
  }
})");
const auto python_tool = json::parse(R"({
  "type": "function",
  "function": {
    "name": "python",
    "description": "an ipython interpreter",
    "parameters": {
      "type": "object",
      "properties": {
        "code": {
          "type": "string",
          "description": "Python code to execute."
        }
      },
      "required": ["code"]
    }
  }
})");
const auto code_interpreter_tool = json::parse(R"({
  "type": "function",
  "function": {
    "name": "code_interpreter",
    "description": "an ipython interpreter",
    "parameters": {
      "type": "object",
      "properties": {
        "code": {
          "type": "string",
          "description": "Python code to execute."
        }
      },
      "required": ["code"]
    }
  }
})");

const json tools           = { special_function_tool, python_tool };
const json llama_3_1_tools = { special_function_tool, code_interpreter_tool };

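// The "delta" is the suffix the assistant would have to generate for the test message, together with
// the grammar and chat format selected for it.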
struct delta_data {
    std::string        delta;
    std::string        grammar;
    common_chat_format format;
};

static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
                             const json & user_message, const json & delta_message, const json & tools,
                             const json & tool_choice) {
    common_chat_inputs inputs;
    inputs.parallel_tool_calls = true;
    inputs.messages            = json::array();
    inputs.messages.push_back(user_message);
    inputs.tools       = tools;
    inputs.tool_choice = tool_choice;
    auto params_prefix = common_chat_params_init(tmpl, inputs);

    inputs.messages.push_back(delta_message);
    inputs.add_generation_prompt = false;
    auto params_full             = common_chat_params_init(tmpl, inputs);

    std::string prefix = params_prefix.prompt;
    std::string full   = params_full.prompt;

    // Check full starts with prefix
    if (full.find(prefix) != 0) {
        fprintf(stderr, "Full:\n%s\n\nPrefix:\n%s\n\n", full.c_str(), prefix.c_str());
        throw std::runtime_error("Full message does not start with prefix");
    }

    if (full == prefix) {
        throw std::runtime_error("Full message is the same as the prefix");
    }

    auto delta = full.substr(prefix.size());

    // Strip end tokens
    for (const auto & end_token : end_tokens) {
        // rfind to find the last occurrence
        auto pos = delta.rfind(end_token);
        if (pos != std::string::npos) {
            delta = delta.substr(0, pos);
            break;
        }
    }
    return { delta, params_full.grammar, params_full.format };
}

/*
  Applies the template to a single user message with add_generation_prompt=true, then again with the test message
  appended and add_generation_prompt=false, takes the diff between the two prompts, strips any end tokens and
  parses the result with the detected format, checking that the parsed message matches test_message.
*/
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
                          const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
                          bool skip_grammar_test = false, bool skip_parser_test = false) {
    common_chat_msg expected_msg = msg_from_json(test_message);

    auto user_message = json{
        { "role", "user" },
        { "content", "Hello, world!" }
    };

    for (const auto & tool_choice : json({ "auto", "required" })) {
        auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice);
        if (!expected_delta.empty()) {
            assert_equals(expected_delta, data.delta);
        }

        if (!skip_parser_test) {
            const auto msg = common_chat_parse(data.delta, data.format);
            assert_msg_equals(expected_msg, msg);
        }

        if (!expected_msg.tool_calls.empty()) {
            GGML_ASSERT(!data.grammar.empty());
        }

        if (!data.grammar.empty()) {
            auto grammar = build_grammar(data.grammar);
            if (!grammar) {
                throw std::runtime_error("Failed to build grammar");
            }
            // TODO: exercise lazy grammars + triggers here, instead of skipping the test
            if (!skip_grammar_test) {
                if (!match_string(data.delta, grammar.get())) {
                    throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
                                             "\n\nGrammar: " + data.grammar);
                }
            }
        }
    }
}

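// Exercises each supported chat template family: checks the format detected by common_chat_params_init,
// then round-trips plain-text and tool-call messages through rendering, grammar matching and parsing.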
static void test_template_output_parsers() {
    auto text_message = json{
        { "role", "assistant" },
        { "content", "Hello, world!" },
    };
    auto tool_call_message = json{
        { "role", "assistant" },
        { "content", {} },
        { "tool_calls", json{ {
            { "type", "function" },
            { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
        } } }
    };
    auto tool_call_message_with_id = json::parse(tool_call_message.dump());
    tool_call_message_with_id["tool_calls"][0]["id"] = "123456789";

    auto python_tool_call_message = json{
        { "role", "assistant" },
        { "content", {} },
        { "tool_calls", json{ {
            { "type", "function" },
            { "function",
              {
                  { "name", "python" },
                  { "arguments",
                    {
                        { "code", "print('hey')" },
                    } },
              } },
        } } }
    };
    auto code_interpreter_tool_call_message = json{
        { "role", "assistant" },
        { "content", {} },
        { "tool_calls", json{ {
            { "type", "function" },
            { "function",
              {
                  { "name", "code_interpreter" },
                  { "arguments",
                    {
                        { "code", "print('hey')" },
                    } },
              } },
        } } }
    };

    common_chat_inputs inputs_no_tools;
    inputs_no_tools.messages = {
        { { "role", "user" }, { "content", "Hey" } }
    };

    common_chat_inputs inputs_tools = inputs_no_tools;
    inputs_tools.tools = json::array();
    inputs_tools.tools.push_back(special_function_tool);

    common_chat_inputs inputs_tools_builtin = inputs_no_tools;
    inputs_tools_builtin.tools = json::array();
    inputs_tools_builtin.tools.push_back(python_tool);

    {
        const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
        std::vector<std::string> end_tokens{ "<end_of_turn>" };

        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_GENERIC,
                      common_chat_params_init(
                          common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"),
                                               "<s>", "</s>"),
                          inputs_tools)
                          .format);

        // The generic tool-call format doesn't generate / parse content-only messages symmetrically.
        assert_msg_equals(msg_from_json(text_message),
                          common_chat_parse("{\n"
                                            "  \"response\": \"Hello, world!\"\n"
                                            "}",
                                            common_chat_params_init(tmpl, inputs_tools).format));
        test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
                      "{\n"
                      "  \"tool_calls\": [\n"
                      "    {\n"
                      "      \"name\": \"special_function\",\n"
                      "      \"arguments\": {\n"
                      "        \"arg1\": 1\n"
                      "      },\n"
                      "      \"id\": \"123456789\"\n"
                      "    }\n"
                      "  ]\n"
                      "}");
    }
    {
        const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "</s>" };

        assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
        test_template(
            tmpl, end_tokens, tool_call_message_with_id, tools,
            "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
            /* skip_grammar_test= */ true);
    }
    {
        const common_chat_template tmpl(
            read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
        std::vector<std::string> end_tokens{ "<|im_end|>" };

        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(tmpl, inputs_tools).format);
        assert_equals(
            COMMON_CHAT_FORMAT_HERMES_2_PRO,
            common_chat_params_init(
                common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"),
                                     "<s>", "</s>"),
                inputs_tools)
                .format);
        assert_equals(
            COMMON_CHAT_FORMAT_HERMES_2_PRO,
            common_chat_params_init(
                common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"),
                inputs_tools)
                .format);

        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "<tool_call>\n"
                      "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
                      "</tool_call>");
        test_template(tmpl, end_tokens, python_tool_call_message, tools,
                      "<tool_call>\n"
                      "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
                      "</tool_call>");
    }
    {
        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
                      common_chat_params_init(tmpl, inputs_tools_builtin).format);
        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
                      common_chat_params_init(
                          common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"),
                                               "<s>", "</s>"),
                          inputs_tools_builtin)
                          .format);

        // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
        test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
                      "<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
        test_template(tmpl, end_tokens, python_tool_call_message, tools,
                      "<|python_tag|>python.call(code=\"print('hey')\")");
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
    }
    {
        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
    }
    {
        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
                      common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "<function=special_function>{\"arg1\": 1}</function>");
    }
    {
        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, text_message, {},
                      "all\n"
                      "Hello, world!",
                      /* skip_grammar_test= */ true);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "special_function\n"
                      "{\"arg1\": 1}");
    }
    {
        const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
    }
    {
        const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"),
                                        "<s>", "</s>");
        std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };

        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
                      "```json\n"
                      "{\"arg1\": 1}\n"
                      "```<|tool▁call▁end|>");
    }
}

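// If given Jinja template files as arguments (non-Windows), prints a Markdown table of the chat format
// detected for each template; otherwise runs the built-in template tests.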
int main(int argc, char ** argv) {
#ifndef _WIN32
    if (argc > 1) {
        common_chat_inputs inputs;
        inputs.messages = {
            { { "role", "user" }, { "content", "Hey" } }
        };
        inputs.tools = json::array({ special_function_tool });

        std::cout << "| Template | Format |\n";
        std::cout << "|----------|--------|\n";
        for (int i = 1; i < argc; i++) {
            std::string path = argv[i];
            if (path.rfind(".jinja") != path.size() - 6) {
                std::cerr << "Skipping non-jinja file: " << path << std::endl;
                continue;
            }
            common_chat_template tmpl(read_file(path), "", "");
            auto parts = string_split(path, "/");
            auto name  = parts[parts.size() - 1];
            std::cout << "| " << name << " | " << common_chat_format_name(common_chat_params_init(tmpl, inputs).format)
                      << " |\n";
        }
    } else
#endif
    {
        test_template_output_parsers();
        std::cout << "\n[chat] All tests passed!" << std::endl;
    }
    return 0;
}