test-chat-parser.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
// Unit tests for common_chat_msg_parser (chat-parser.h): cursor movement and finish(),
// reasoning extraction via try_parse_reasoning(), regex/literal consumption in partial and
// non-partial modes, and healing of truncated JSON via try_consume_json_with_dumped_args().
//
// The binary takes no command-line arguments. It prints "All tests passed!" on success; a
// failed check throws std::runtime_error, which main() does not catch, so the program terminates.
//
#include <exception>
#include <functional>
#include <iostream>
#include <regex>
#include <stdexcept>
#include <string>
#include <vector>

#include <json.hpp>

#include "chat-parser.h"
#include "common.h"
#include "log.h"
#include "regex-partial.h"

using json = nlohmann::ordered_json;
  17. template <class T>
  18. static void assert_equals(const T & expected, const T & actual) {
  19. if (expected != actual) {
  20. std::cerr << "Expected: " << expected << std::endl;
  21. std::cerr << "Actual: " << actual << std::endl;
  22. std::cerr << std::flush;
  23. throw std::runtime_error("Test failed");
  24. }
  25. }
  26. static void assert_equals(const char * expected, const std::string & actual) {
  27. return assert_equals<std::string>(expected, actual);
  28. }
  29. static void assert_throws(const std::function<void()> & fn, const std::string & expected_exception_pattern = "") {
  30. try {
  31. fn();
  32. } catch (const std::exception & e) {
  33. if (expected_exception_pattern.empty()) {
  34. return;
  35. }
  36. std::regex expected_exception_regex(expected_exception_pattern);
  37. std::string actual_message = e.what();
  38. if (std::regex_search(actual_message, expected_exception_regex)) {
  39. return;
  40. }
  41. throw std::runtime_error("Exception doesn't match expected pattern: " + actual_message + " (pattern: " + expected_exception_pattern + ")");
  42. throw std::runtime_error("Exception of unexpected type: " + std::string(e.what()));
  43. }
  44. throw std::runtime_error("Exception was expected but not thrown");
  45. }
  46. static void test_reasoning() {
  47. {
  48. common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
  49. /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
  50. /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
  51. /* .reasoning_in_content = */ false,
  52. /* .thinking_forced_open = */ false,
  53. });
  54. assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
  55. assert_equals("<tnk>Cogito</tnk>Ergo sum", builder.consume_rest());
  56. }
  57. {
  58. common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
  59. /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
  60. /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
  61. /* .reasoning_in_content = */ false,
  62. /* .thinking_forced_open = */ false,
  63. });
  64. assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
  65. assert_equals(std::string("Cogito"), builder.result().reasoning_content);
  66. assert_equals("Ergo sum", builder.consume_rest());
  67. }
  68. {
  69. common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
  70. /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
  71. /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
  72. /* .reasoning_in_content = */ false,
  73. /* .thinking_forced_open = */ false,
  74. });
  75. assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
  76. assert_equals("Cogito</tnk>Ergo sum", builder.consume_rest());
  77. }
  78. {
  79. common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
  80. /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
  81. /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
  82. /* .reasoning_in_content = */ false,
  83. /* .thinking_forced_open = */ true,
  84. });
  85. assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
  86. assert_equals(std::string("Cogito"), builder.result().reasoning_content);
  87. assert_equals("Ergo sum", builder.consume_rest());
  88. }
  89. {
  90. common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
  91. /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
  92. /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
  93. /* .reasoning_in_content = */ true,
  94. /* .thinking_forced_open = */ true,
  95. });
  96. assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
  97. assert_equals("<think>Cogito</think>", builder.result().content);
  98. assert_equals("Ergo sum", builder.consume_rest());
  99. }
  100. }
  101. static void test_regex() {
  102. auto test_throws = [](const std::string & input, const std::string & regex, const std::string & expected_exception_pattern = "") {
  103. common_chat_msg_parser builder(input, /* is_partial= */ false, {});
  104. assert_throws([&]() { builder.consume_regex(common_regex(regex)); }, expected_exception_pattern);
  105. };
  106. test_throws("Hello, world!", "abc", "^abc$");
  107. test_throws("Hello, world!", "e", "^e$");
  108. {
  109. common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
  110. builder.consume_regex(common_regex("Hello"));
  111. assert_equals(", world!", builder.consume_rest());
  112. }
  113. {
  114. // When in non partial mode, we can say whether the regex was consumed or not.
  115. common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
  116. assert_equals(false, builder.try_consume_regex(common_regex("Hello, world!")).has_value());
  117. }
  118. {
  119. common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
  120. auto res = builder.try_consume_regex(common_regex("H(el)l(?:o, world!)?"));
  121. assert_equals(true, res.has_value());
  122. // Verify captures
  123. assert_equals<size_t>(2, res->groups.size());
  124. assert_equals("Hell", builder.str(res->groups[0]));
  125. assert_equals("el", builder.str(res->groups[1]));
  126. // Verify position is after the match
  127. assert_equals<size_t>(4, builder.pos());
  128. assert_equals("o,", builder.consume_rest());
  129. }
  130. {
  131. // But in partial mode, we have a partial final match / can't decide, so we throw a partial exception.
  132. common_chat_msg_parser builder("Hello,", /* is_partial= */ true, {});
  133. assert_throws([&]() {
  134. builder.try_consume_regex(common_regex("Hello, world!"));
  135. }, "^Hello, world!$");
  136. }
  137. // Now regardless of the mode, we can tell these aren't a match.
  138. for (const auto is_partial : {false, true}) {
  139. common_chat_msg_parser builder("Hello,", is_partial, {});
  140. assert_equals(false, builder.try_consume_regex(common_regex("a(b|c)(d|e)f")).has_value());
  141. }
  142. for (const auto is_partial : {false, true}) {
  143. common_chat_msg_parser builder("Hello,", is_partial, {});
  144. assert_equals(false, builder.try_consume_literal("Oh"));
  145. }
  146. }
  147. const std::vector<std::string> barely_healable_jsons = {
  148. "{",
  149. "{\"",
  150. "{\"\\",
  151. "{\"n",
  152. "{\"name\"",
  153. "{\"name\":",
  154. "{\"name\":\"",
  155. "{\"name\":\"\\",
  156. "{\"name\":\"python",
  157. "{\"name\":\"python\\",
  158. "{\",",
  159. "{\":",
  160. "{\"[",
  161. "{\"]",
  162. "{\"{",
  163. "{\"}",
  164. "{\"1",
  165. "{\"name\":\",",
  166. "{\"name\":\":",
  167. "{\"name\":\"[",
  168. "{\"name\":\"]",
  169. "{\"name\":\"{",
  170. "{\"name\":\"}",
  171. "{\"name\":\"1",
  172. };
  173. static void test(const std::string & input, bool is_partial, const std::vector<std::vector<std::string>> & args_paths, const std::vector<std::vector<std::string>> & content_paths, const std::string & expected) {
  174. common_chat_msg_parser builder(input, is_partial, {});
  175. auto js = builder.try_consume_json_with_dumped_args(args_paths, content_paths);
  176. assert_equals(true, js.has_value());
  177. assert_equals(is_partial, js->is_partial);
  178. assert_equals(expected, args_paths.size() == 1 && args_paths[0].empty() ? js->value.get<std::string>() : js->value.dump());
  179. }
  180. static void test_with_args(const std::string & input, const std::string & expected, bool parse_as_partial = true, bool is_partial = true) {
  181. common_chat_msg_parser builder(input, parse_as_partial, {});
  182. auto js = builder.try_consume_json_with_dumped_args({{"args"}}, {});
  183. assert_equals(true, js.has_value());
  184. assert_equals(is_partial, js->is_partial);
  185. assert_equals(expected, js->value.dump());
  186. }
  187. static void test_json_with_dumped_args_no_args() {
  188. // Normal JSON, nothing to heal, nothing to dump
  189. test("{\"name\": \"python\"}", false, {}, {}, "{\"name\":\"python\"}");
  190. // Full json is args
  191. test("{\"name\": \"python\"}", false, {{}}, {}, "{\"name\":\"python\"}");
  192. // If the arguments are further down, don't heal partial content.
  193. for (const auto & src : barely_healable_jsons) {
  194. test(src, true, {{"arguments"}}, {}, "{}");
  195. }
  196. // But heal content that isn't partial.
  197. test("{\"name\": \"python\"", true, {{"arguments"}}, {}, "{\"name\":\"python\"}");
  198. }
  199. static void test_json_with_dumped_args() {
  200. // Partial content.
  201. test("{\"content\": \"t", true, {}, {{"content"}}, "{\"content\":\"t\"}");
  202. test("{\"content\": \"", true, {}, {{"content"}}, "{\"content\":\"\"}");
  203. test("{\"content\": ", true, {}, {{"content"}}, "{}");
  204. // If the entire JSON is the arguments, healing it them dumping it produces the same output as the input (just reformatted).
  205. test("{\"name\": \"python", true, {{}}, {}, "{\"name\":\"python");
  206. for (const auto & src : barely_healable_jsons) {
  207. test(src, true, {{}}, {}, src);
  208. }
  209. // Full JSON w/ args
  210. for (auto parse_as_partial : {true, false}) {
  211. test_with_args(
  212. R"({"name": "python", "args": {"arg1": 1}})",
  213. R"({"name":"python","args":"{\"arg1\":1}"})",
  214. parse_as_partial,
  215. /* is_partial= */ false
  216. );
  217. }
  218. // Partial JSON w/ partial args
  219. test_with_args(
  220. R"({"foo": "bar", "args": {")",
  221. R"({"foo":"bar","args":"{\""})"
  222. );
  223. // Partial args broken in object key
  224. test_with_args(
  225. R"({"foo": "bar", "args": {"ar)",
  226. R"({"foo":"bar","args":"{\"ar"})"
  227. );
  228. // Partial args broken after object key
  229. test_with_args(
  230. R"({"foo": "bar", "args": {"arg1")",
  231. R"({"foo":"bar","args":"{\"arg1\""})"
  232. );
  233. // Partial args broken before object value
  234. test_with_args(
  235. R"({"foo": "bar", "args": {"arg1":)",
  236. R"({"foo":"bar","args":"{\"arg1\":"})"
  237. );
  238. // Partial args broken before object value (space)
  239. test_with_args(
  240. R"({"foo": "bar", "args": {"arg1": )",
  241. R"({"foo":"bar","args":"{\"arg1\":"})"
  242. );
  243. // Partial args broken in object value that may not be complete (int)
  244. test_with_args(
  245. R"({"foo": "bar", "args": {"arg1": 1)",
  246. R"({"foo":"bar","args":"{\"arg1\":"})"
  247. );
  248. // Partial args broken in object value that is complete (int)
  249. test_with_args(
  250. R"({"foo": "bar", "args": {"arg1": 1 )",
  251. R"({"foo":"bar","args":"{\"arg1\":1"})"
  252. );
  253. // Partial args broken in object value that is incomplete (string)
  254. test_with_args(
  255. R"({"foo": "bar", "args": {"arg1": ")",
  256. R"({"foo":"bar","args":"{\"arg1\":\""})"
  257. );
  258. // Partial args broken in object value that is complete (string)
  259. test_with_args(
  260. R"({"foo": "bar", "args": {"arg1": "1")",
  261. R"({"foo":"bar","args":"{\"arg1\":\"1\""})"
  262. );
  263. // Partial args broken on array opening
  264. test_with_args(
  265. R"({"foo": "bar", "args": [)",
  266. R"({"foo":"bar","args":"["})"
  267. );
  268. // Partial args broken on array value that is incomplete (int)
  269. test_with_args(
  270. R"({"foo": "bar", "args": [1)",
  271. R"({"foo":"bar","args":"["})"
  272. );
  273. // Partial args broken on array value that is complete (int)
  274. test_with_args(
  275. R"({"foo": "bar", "args": [1 )",
  276. R"({"foo":"bar","args":"[1"})"
  277. );
  278. // Partial args broken on array value that is complete (string)
  279. test_with_args(
  280. R"({"foo": "bar", "args": ["1")",
  281. R"({"foo":"bar","args":"[\"1\""})"
  282. );
  283. // Partial args broken after array value
  284. test_with_args(
  285. R"({"foo": "bar", "args": [1,)",
  286. R"({"foo":"bar","args":"[1,"})"
  287. );
  288. // Partial args broken on nested array
  289. test_with_args(
  290. R"({"foo": "bar", "args": {"arg1": [)",
  291. R"({"foo":"bar","args":"{\"arg1\":["})"
  292. );
  293. }
  294. static void test_positions() {
  295. {
  296. common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
  297. assert_equals<size_t>(0, builder.pos());
  298. assert_throws([&]() { builder.move_to(100); });
  299. assert_equals<size_t>(0, builder.pos());
  300. assert_throws([&]() { builder.move_back(1); });
  301. assert_equals<size_t>(0, builder.pos());
  302. builder.move_to(8);
  303. assert_equals<size_t>(8, builder.pos());
  304. builder.move_back(1);
  305. assert_equals<size_t>(7, builder.pos());
  306. assert_equals("world!", builder.consume_rest());
  307. builder.move_to(0);
  308. assert_equals<size_t>(0, builder.pos());
  309. assert_throws([&]() { builder.finish(); });
  310. assert_equals<size_t>(0, builder.pos());
  311. builder.move_to(builder.input().size());
  312. builder.finish();
  313. }
  314. {
  315. common_chat_msg_parser builder("Hello, world!", /* is_partial= */ true, {});
  316. builder.move_to(builder.input().size());
  317. assert_equals<size_t>(builder.input().size(), builder.pos());
  318. builder.finish();
  319. }
  320. }
  321. int main() {
  322. test_positions();
  323. test_json_with_dumped_args_no_args();
  324. test_json_with_dumped_args();
  325. test_reasoning();
  326. test_regex();
  327. std::cout << "All tests passed!\n";
  328. return 0;
  329. }