test-chat-peg-parser.cpp 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <set>
#include <string>
#include <tuple>
#include <vector>

#include "chat-parser.h"
#include "chat-peg-parser.h"
#include "chat.h"
#include "common.h"
#include "json-schema-to-grammar.h"
#include "peg-parser.h"
#include "peg-parser/testing.h"
#include "peg-parser/simple-tokenize.h"
#include "nlohmann/json.hpp"

using json = nlohmann::ordered_json;
  14. static json create_tools();
  15. static void test_example_native(testing & t);
  16. static void test_example_qwen3_coder(testing & t);
  17. static void test_command7_parser_compare(testing & t);
  18. int main(int argc, char *argv[]) {
  19. testing t(std::cout);
  20. if (argc >= 2) {
  21. t.set_filter(argv[1]);
  22. }
  23. const char * verbose = getenv("LLAMA_TEST_VERBOSE");
  24. if (verbose) {
  25. t.verbose = std::string(verbose) == "1";
  26. }
  27. t.test("native", test_example_native);
  28. t.test("qwen3 coder", test_example_qwen3_coder);
  29. t.test("comparison", test_command7_parser_compare);
  30. return t.summary();
  31. }
  32. static json create_tools() {
  33. json tools = json::array();
  34. json tool_weather = {
  35. {"type", "function"},
  36. {"function", {
  37. {"name", "get_current_weather"},
  38. {"description", "Get the current weather in a given location"},
  39. {"parameters", {
  40. {"type", "object"},
  41. {"properties", {
  42. {"location", {
  43. {"type", "string"},
  44. {"description", "The city and state, e.g. San Francisco, CA"}
  45. }},
  46. {"unit", {
  47. {"type", "string"},
  48. {"enum", {"celsius", "fahrenheit"}},
  49. {"description", "The temperature unit to use. Infer this from the users location."}
  50. }}
  51. }},
  52. {"required", {"location", "unit"}},
  53. }},
  54. }}
  55. };
  56. tools.push_back(tool_weather);
  57. json tool_forecast = {
  58. {"type", "function"},
  59. {"function", {
  60. {"name", "get_forecast"},
  61. {"description", "Get the weather forecast for a given location"},
  62. {"parameters", {
  63. {"type", "object"},
  64. {"properties", {
  65. {"location", {
  66. {"type", "string"},
  67. {"description", "The city and state, e.g. San Francisco, CA"}
  68. }},
  69. {"unit", {
  70. {"type", "string"},
  71. {"enum", {"celsius", "fahrenheit"}},
  72. {"description", "The temperature unit to use. Infer this from the users location."}
  73. }},
  74. {"days", {
  75. {"type", "integer"},
  76. {"description", "Number of days to forecast (1-10)"},
  77. {"minimum", 1},
  78. {"maximum", 10}
  79. }}
  80. }},
  81. {"required", {"location", "unit"}},
  82. }},
  83. }}
  84. };
  85. tools.push_back(tool_forecast);
  86. json tool_search = {
  87. {"type", "function"},
  88. {"function", {
  89. {"name", "search_knowledge_base"},
  90. {"description", "Search the internal technical documentation knowledge base."},
  91. {"parameters", {
  92. {"type", "object"},
  93. {"properties", {
  94. {"query", {
  95. {"type", "string"},
  96. {"description", "The search query string."}
  97. }},
  98. {"max_results", {
  99. {"type", "integer"},
  100. {"description", "The maximum number of results to return."},
  101. {"default", 5}
  102. }},
  103. {"category", {
  104. {"type", "string"},
  105. {"enum", {"api", "troubleshooting", "billing", "general"}},
  106. {"description", "Filter search by specific category."}
  107. }}
  108. }},
  109. {"required", {"query", "category"}},
  110. {"additionalProperties", false}
  111. }},
  112. {"strict", true}
  113. }}
  114. };
  115. tools.push_back(tool_search);
  116. return tools;
  117. }
  118. struct tool_argument {
  119. std::string name;
  120. std::string type;
  121. bool is_required;
  122. json schema;
  123. };
  124. struct tool_definition {
  125. std::string name;
  126. std::vector<tool_argument> arguments;
  127. json schema;
  128. };
  129. // Test fictitious model output that emits arguments as JSON.
  130. static void test_example_native(testing & t) {
  131. struct test_case {
  132. // Parameters
  133. std::string name;
  134. json tools;
  135. common_chat_tool_choice tool_choice;
  136. common_reasoning_format reasoning_format;
  137. json json_schema;
  138. bool parallel_tool_calls;
  139. bool thinking_forced_open;
  140. std::string input;
  141. // Expect
  142. std::string expect_reasoning;
  143. std::string expect_content;
  144. std::vector<common_chat_tool_call> expect_tool_calls;
  145. };
  146. auto build_parser = [](const test_case & tc) {
  147. return build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
  148. auto reasoning_in_content = (tc.reasoning_format == COMMON_REASONING_FORMAT_NONE);
  149. auto reasoning = p.eps();
  150. if (tc.thinking_forced_open) {
  151. // If thinking is forced open, expect a closing tag
  152. reasoning = p.reasoning(p.until("</think>")) + "</think>" + p.space();
  153. } else {
  154. // Otherwise, optionally accept thinking wrapped in tags
  155. reasoning = p.optional("<think>" + p.reasoning(p.until("</think>")) + "</think>" + p.space());
  156. }
  157. // tool calling parser
  158. if (tc.tools.is_array() && !tc.tools.empty()) {
  159. auto tools = p.choice();
  160. for (const auto & tool : tc.tools) {
  161. const auto & function = tool.at("function");
  162. std::string name = function.at("name");
  163. const auto & schema = function.at("parameters");
  164. auto tool_name = p.json_member("name", "\"" + p.tool_name(p.literal(name)) + "\"");
  165. auto tool_args = p.json_member("arguments", p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)));
  166. tools |= p.rule("tool-" + name, p.tool_open(p.literal("{")) << tool_name << "," << tool_args << "}");
  167. };
  168. auto parallel_calls = p.eps();
  169. if (tc.parallel_tool_calls) {
  170. parallel_calls = p.zero_or_more("," << tools);
  171. }
  172. auto tool_call = p.trigger_rule("tool-call",
  173. p.sequence({
  174. p.literal("<tool_call>["),
  175. tools,
  176. parallel_calls,
  177. p.literal("]</tool_call>")
  178. })
  179. );
  180. return p.sequence({
  181. (reasoning_in_content ? p.eps() : reasoning),
  182. p.content(p.until("<tool_call>")),
  183. p.optional(p.space() + tool_call),
  184. p.space(),
  185. p.end()
  186. });
  187. }
  188. // response_format parser
  189. if (tc.json_schema.is_object() && !tc.json_schema.empty()) {
  190. return p.sequence({
  191. (reasoning_in_content ? p.eps() : reasoning),
  192. p.content(p.schema(p.json(), "response-output", tc.json_schema)),
  193. p.space(),
  194. p.end()
  195. });
  196. }
  197. // Content-only parser
  198. return p.sequence({
  199. (reasoning_in_content ? p.eps() : reasoning),
  200. p.content(p.rest()),
  201. p.end()
  202. });
  203. });
  204. };
  205. std::vector<test_case> test_cases = std::vector<test_case>{
  206. {
  207. /* .name = */ "content with thinking_forced_open = false",
  208. /* .tools = */ {},
  209. /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
  210. /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
  211. /* .json_schema = */ {},
  212. /* .parallel_tool_calls = */ false,
  213. /* .thinking_forced_open = */ false,
  214. /* .input = */ (
  215. "<think>The user said hello, I must say hello back</think>\nHello"
  216. ),
  217. /* .expect_reasoning = */ "The user said hello, I must say hello back",
  218. /* .expect_content = */ "Hello",
  219. /* .expect_tool_calls = */ {},
  220. },
  221. {
  222. /* .name = */ "content with thinking_forced_open = false and no reasoning",
  223. /* .tools = */ {},
  224. /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
  225. /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
  226. /* .json_schema = */ {},
  227. /* .parallel_tool_calls = */ false,
  228. /* .thinking_forced_open = */ false,
  229. /* .input = */ (
  230. "Hello"
  231. ),
  232. /* .expect_reasoning = */ "",
  233. /* .expect_content = */ "Hello",
  234. /* .expect_tool_calls = */ {},
  235. },
  236. {
  237. /* .name = */ "content with thinking_forced_open = false and reasoning_format = none",
  238. /* .tools = */ {},
  239. /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
  240. /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
  241. /* .json_schema = */ {},
  242. /* .parallel_tool_calls = */ false,
  243. /* .thinking_forced_open = */ true,
  244. /* .input = */ (
  245. "<think>The user said hello, I must say hello back</think>\nHello"
  246. ),
  247. /* .expect_reasoning = */ "",
  248. /* .expect_content = */ "<think>The user said hello, I must say hello back</think>\nHello",
  249. /* .expect_tool_calls = */ {},
  250. },
  251. {
  252. /* .name = */ "content with thinking_forced_open = true",
  253. /* .tools = */ {},
  254. /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
  255. /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
  256. /* .json_schema = */ {},
  257. /* .parallel_tool_calls = */ false,
  258. /* .thinking_forced_open = */ true,
  259. /* .input = */ (
  260. "The user said hello, I must say hello back</think>\nHello"
  261. ),
  262. /* .expect_reasoning = */ "The user said hello, I must say hello back",
  263. /* .expect_content = */ "Hello",
  264. /* .expect_tool_calls = */ {},
  265. },
  266. {
  267. /* .name = */ "content with thinking_forced_open = true and reasoning_format = none",
  268. /* .tools = */ {},
  269. /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
  270. /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
  271. /* .json_schema = */ {},
  272. /* .parallel_tool_calls = */ false,
  273. /* .thinking_forced_open = */ true,
  274. /* .input = */ (
  275. "The user said hello, I must say hello back</think>\nHello"
  276. ),
  277. /* .expect_reasoning = */ "",
  278. /* .expect_content = */ "The user said hello, I must say hello back</think>\nHello",
  279. /* .expect_tool_calls = */ {},
  280. },
  281. {
  282. /* .name = */ "tools with tool_choice = auto and no parallel_tool_calls",
  283. /* .tools = */ create_tools(),
  284. /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO,
  285. /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
  286. /* .json_schema = */ {},
  287. /* .parallel_tool_calls = */ false,
  288. /* .thinking_forced_open = */ true,
  289. /* .input = */ (
  290. "I must get the weather in New York</think>\n"
  291. "<tool_call>["
  292. R"({"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}})"
  293. "]</tool_call>"
  294. ),
  295. /* .expect_reasoning = */ "I must get the weather in New York",
  296. /* .expect_content = */ "",
  297. /* .expect_tool_calls = */ {{
  298. /* .name = */ "get_current_weather",
  299. /* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit"})",
  300. /* .id = */ "",
  301. }},
  302. },
  303. {
  304. /* .name = */ "tools with tool_choice = auto and parallel_tool_calls",
  305. /* .tools = */ create_tools(),
  306. /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO,
  307. /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
  308. /* .json_schema = */ {},
  309. /* .parallel_tool_calls = */ true,
  310. /* .thinking_forced_open = */ true,
  311. /* .input = */ (
  312. "I must get the weather in New York and San Francisco and a 3 day forecast of each.</think>\nLet me search that for you."
  313. "<tool_call>["
  314. R"({"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}})"
  315. ", "
  316. R"({"name": "get_current_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}})"
  317. ", "
  318. R"({"name": "get_forecast", "arguments": {"location": "New York City, NY", "unit": "fahrenheit", "days": 3}})"
  319. ", "
  320. R"({"name": "get_forecast", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3}})"
  321. "]</tool_call>"
  322. ),
  323. /* .expect_reasoning = */ "I must get the weather in New York and San Francisco and a 3 day forecast of each.",
  324. /* .expect_content = */ "Let me search that for you.",
  325. /* .expect_tool_calls = */ {{
  326. /* .name = */ "get_current_weather",
  327. /* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit"})",
  328. /* .id = */ "",
  329. }, {
  330. /* .name = */ "get_current_weather",
  331. /* .arguments = */ R"({"location": "San Francisco, CA", "unit": "fahrenheit"})",
  332. /* .id = */ "",
  333. }, {
  334. /* .name = */ "get_forecast",
  335. /* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit", "days": 3})",
  336. /* .id = */ "",
  337. }, {
  338. /* .name = */ "get_forecast",
  339. /* .arguments = */ R"({"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3})",
  340. /* .id = */ "",
  341. }},
  342. },
  343. {
  344. /* .name = */ "response_format with thinking_forced_open = true",
  345. /* .tools = */ {},
  346. /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
  347. /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
  348. /* .json_schema = */ {
  349. {"type", "object"},
  350. {"properties", {
  351. {"invoice_number", {{"type", "string"}}},
  352. {"amount", {{"type", "number"}}},
  353. {"due_date", {{"type", "string"}}}
  354. }},
  355. {"required", {"invoice_number", "amount", "due_date"}}
  356. },
  357. /* .parallel_tool_calls = */ false,
  358. /* .thinking_forced_open = */ true,
  359. /* .input = */ (
  360. "I must produce the invoice in the requested format</think>\n"
  361. R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})"
  362. ),
  363. /* .expect_reasoning = */ "I must produce the invoice in the requested format",
  364. /* .expect_content = */ R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})",
  365. /* .expect_tool_calls = */ {},
  366. },
  367. };
  368. for (const auto & tc : test_cases) {
  369. t.test(tc.name, [&](testing & t) {
  370. auto parser = build_parser(tc);
  371. auto lazy = !tc.tools.empty() && tc.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
  372. auto grammar = build_grammar([&](const common_grammar_builder & builder) {
  373. for (auto const & def : tc.tools) {
  374. auto function = def.at("function");
  375. auto parameters = function.at("parameters");
  376. builder.resolve_refs(parameters);
  377. };
  378. parser.build_grammar(builder, lazy);
  379. });
  380. t.log("Grammar:");
  381. for (auto const & line : string_split(grammar, "\n")) {
  382. t.log(line);
  383. }
  384. common_peg_parse_context ctx(tc.input, false);
  385. auto result = parser.parse(ctx);
  386. t.assert_true("success", result.success());
  387. common_chat_msg msg;
  388. auto mapper = common_chat_peg_native_mapper(msg);
  389. mapper.from_ast(ctx.ast, result);
  390. t.assert_equal("content equal", tc.expect_content, msg.content);
  391. t.assert_equal("reasoning equal", tc.expect_reasoning, msg.reasoning_content);
  392. t.assert_equal("number of tool calls", tc.expect_tool_calls.size(), msg.tool_calls.size());
  393. for (auto i = 0u; i < std::min(tc.expect_tool_calls.size(), msg.tool_calls.size()); i++) {
  394. t.assert_equal("tool name", tc.expect_tool_calls[i].name, msg.tool_calls[i].name);
  395. t.assert_equal("tool args", tc.expect_tool_calls[i].arguments, msg.tool_calls[i].arguments);
  396. }
  397. });
  398. }
  399. }
  400. static void test_example_qwen3_coder(testing & t) {
  401. auto tools = create_tools();
  402. auto parser = build_chat_peg_constructed_parser([&](common_chat_peg_constructed_builder & p) {
  403. auto content = p.rule("content", p.content(p.until("<tool_call>")));
  404. std::vector<common_peg_parser> tool_parsers;
  405. for (auto const & def : tools) {
  406. auto function = def.at("function");
  407. std::string name = function.at("name");
  408. auto parameters = function.at("parameters");
  409. auto properties = parameters.at("properties");
  410. std::set<std::string> required_properties;
  411. if (function.contains("required")) {
  412. function.at("required").get_to(required_properties);
  413. }
  414. std::vector<common_peg_parser> arg_parsers;
  415. for (const auto & [param_name, param_schema] : properties.items()) {
  416. bool is_required = required_properties.find(param_name) != required_properties.end();
  417. auto type = param_schema.value("type", "object");
  418. auto arg = p.tool_arg(p.sequence({
  419. p.tool_arg_open("<parameter=" + p.tool_arg_name(p.literal(param_name)) + ">"),
  420. (type == "string" ?
  421. p.tool_arg_string_value(
  422. p.schema(
  423. p.until_one_of({
  424. "</parameter>\n<parameter=",
  425. "</parameter>\n</function>"
  426. }),
  427. "tool-" + name + "-arg-" + param_name + "-schema",
  428. param_schema,
  429. true
  430. )
  431. ) : p.tool_arg_json_value(
  432. p.schema(
  433. p.json(),
  434. "tool-" + name + "-arg-" + param_name + "-schema",
  435. param_schema
  436. )
  437. )
  438. ),
  439. p.tool_arg_close(
  440. "</parameter>\n" +
  441. p.peek(p.literal("<parameter=") | p.literal("</function>"))
  442. )
  443. }));
  444. arg_parsers.push_back(is_required ?
  445. p.rule("tool-" + name + "-arg-" + param_name, arg) :
  446. p.optional(p.rule("tool-" + name + "-arg-" + param_name, arg)));
  447. }
  448. tool_parsers.push_back(p.rule("tool-" + name,
  449. p.tool_open("<function=" + p.tool_name(p.literal(name)) + ">")
  450. << p.sequence(arg_parsers)
  451. << p.tool_close(p.literal("</function>"))
  452. ));
  453. };
  454. auto tool_call = p.trigger_rule("tool-call",
  455. "<tool_call>"
  456. << p.choice(tool_parsers)
  457. << "</tool_call>"
  458. );
  459. return content + p.zero_or_more(p.space() + tool_call) + p.end();
  460. });
  461. auto grammar = build_grammar([&](const common_grammar_builder & builder) {
  462. for (auto const & def : tools) {
  463. auto function = def.at("function");
  464. auto parameters = function.at("parameters");
  465. builder.resolve_refs(parameters);
  466. };
  467. parser.build_grammar(builder);
  468. });
  469. t.log("Grammar:");
  470. for (auto const & line : string_split(grammar, "\n")) {
  471. t.log(line);
  472. }
  473. t.test("incremental parsing", [&](testing &t) {
  474. std::string input =
  475. "Let me search the knowledge base for cat pictures."
  476. "<tool_call>\n"
  477. "<function=search_knowledge_base>\n"
  478. "<parameter=query>cat pictures</parameter>\n"
  479. "<parameter=category>general</parameter>\n"
  480. "</function>\n"
  481. "</tool_call>";
  482. std::vector<std::string> tokens = simple_tokenize(input);
  483. common_chat_msg prev;
  484. for (auto it = tokens.begin(); it != tokens.end(); it++) {
  485. std::string in = std::accumulate(tokens.begin(), it + 1, std::string());
  486. common_peg_parse_context ctx(in, it + 1 < tokens.end());
  487. auto result = parser.parse(ctx);
  488. if (!t.assert_equal("not fail", false, result.fail())) {
  489. t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end));
  490. }
  491. common_chat_msg msg;
  492. auto mapper = common_chat_peg_constructed_mapper(msg);
  493. mapper.from_ast(ctx.ast, result);
  494. //t.log("Input: " + input);
  495. t.log("===========================================");
  496. t.log("Iteration " + std::to_string(in.size()));
  497. t.log("Reasoning: " + msg.reasoning_content);
  498. t.log("Content : " + msg.content);
  499. for (const auto & tc : msg.tool_calls) {
  500. t.log("Tool name: " + tc.name);
  501. t.log("Tool args: " + tc.arguments);
  502. }
  503. try {
  504. // This shouldn't emit any runtime errors
  505. auto diffs = common_chat_msg_diff::compute_diffs(prev, msg);
  506. } catch(const std::exception & e) {
  507. t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end));
  508. t.assert_true(std::string("failed with ") + e.what(), false);
  509. }
  510. prev = msg;
  511. }
  512. });
  513. }
  514. void test_command7_parser_compare(testing & t) {
  515. auto parser = build_chat_peg_native_parser([](common_chat_peg_native_builder & p) {
  516. auto thinking = p.reasoning_block(
  517. "<|START_THINKING|>" << p.reasoning(p.until("<|END_THINKING|>")) << "<|END_THINKING|>");
  518. auto response = "<|START_RESPONSE|>" << p.content(p.until("<|END_RESPONSE|>")) << "<|END_RESPONSE|>";
  519. auto tool_call_id = p.atomic("\"tool_call_id\"" << (":" << ("\"" + p.tool_id(p.json_string_content()) + "\"")));
  520. auto tool_call_name = p.atomic("\"tool_name\"" << (":" << ("\"" + p.tool_name(p.json_string_content()) + "\"")));
  521. auto tool_call_args = "\"parameters\"" << (":" << p.tool_args(p.json()));
  522. auto tool_call_fields = p.rule("tool-call-fields", tool_call_id | tool_call_name | tool_call_args);
  523. auto tool_call = p.rule("tool-call", p.tool(
  524. p.tool_open(p.literal("{"))
  525. << tool_call_fields
  526. << p.zero_or_more( p.literal(",") << tool_call_fields)
  527. << p.tool_close(p.literal("}"))
  528. ));
  529. auto tool_calls = p.rule("tool-calls",
  530. "<|START_ACTION|>"
  531. << ("[" << tool_call << p.zero_or_more(p.literal(",") << tool_call) << "]")
  532. << "<|END_ACTION|>");
  533. return p.optional(thinking) << (tool_calls | response) + p.end();
  534. });
  535. auto test_current = [&](const common_peg_arena & p, const std::string & input, bool is_partial, bool print_results) {
  536. common_peg_parse_context ctx(input, is_partial);
  537. auto result = p.parse(ctx);
  538. common_chat_msg msg;
  539. auto mapper = common_chat_peg_native_mapper(msg);
  540. mapper.from_ast(ctx.ast, result);
  541. if (print_results) {
  542. std::cout << "== Parsed (new) ==\n";
  543. std::cout << "=== Reasoning ===\n";
  544. std::cout << msg.reasoning_content << "\n";
  545. std::cout << "\n\n=== Content ===\n";
  546. std::cout << msg.content << "\n";
  547. std::cout << "\n\n=== Tool Calls ===\n";
  548. for (const auto & tc : msg.tool_calls) {
  549. std::cout << "id: " << tc.id << "\n";
  550. std::cout << "name: " << tc.name << "\n";
  551. std::cout << "args: " << tc.arguments << "\n";
  552. }
  553. }
  554. };
  555. auto test_legacy = [&](const std::string & input, bool need_more_input, bool print_results) {
  556. // Original common_chat_combinator_parser taken from chat.cpp
  557. common_chat_msg_parser builder(
  558. input,
  559. /* .is_partial = */ need_more_input,
  560. {
  561. /* .format = */ COMMON_CHAT_FORMAT_GENERIC,
  562. /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
  563. /* .reasoning_in_content = */ false,
  564. /* .thinking_forced_open = */ false,
  565. }
  566. );
  567. builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>");
  568. static const common_regex start_action_regex("<\\|START_ACTION\\|>");
  569. static const common_regex end_action_regex("<\\|END_ACTION\\|>");
  570. static const common_regex start_response_regex("<\\|START_RESPONSE\\|>");
  571. static const common_regex end_response_regex("<\\|END_RESPONSE\\|>");
  572. if (auto res = builder.try_find_regex(start_action_regex)) {
  573. // If we didn't extract thoughts, prelude includes them.
  574. auto tool_calls = builder.consume_json_with_dumped_args({ { "parameters" } });
  575. for (const auto & tool_call : tool_calls.value) {
  576. std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
  577. std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : "";
  578. std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : "";
  579. if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) {
  580. throw common_chat_msg_partial_exception("incomplete tool call");
  581. }
  582. }
  583. if (tool_calls.is_partial) {
  584. throw common_chat_msg_partial_exception("incomplete tool call");
  585. }
  586. builder.consume_regex(end_action_regex);
  587. } else if (auto res = builder.try_find_regex(start_response_regex)) {
  588. if (!builder.try_find_regex(end_response_regex)) {
  589. builder.add_content(builder.consume_rest());
  590. throw common_chat_msg_partial_exception(end_response_regex.str());
  591. }
  592. } else {
  593. builder.add_content(builder.consume_rest());
  594. }
  595. if (print_results) {
  596. std::cout << "== Parsed (legacy) ==\n";
  597. std::cout << "=== Reasoning ===\n";
  598. std::cout << builder.result().reasoning_content << "\n";
  599. std::cout << "\n\n=== Content ===\n";
  600. std::cout << builder.result().content << "\n";
  601. std::cout << "\n\n=== Tool Calls ===\n";
  602. for (const auto & tc : builder.result().tool_calls) {
  603. std::cout << "id: " << tc.id << "\n";
  604. std::cout << "name: " << tc.name << "\n";
  605. std::cout << "args: " << tc.arguments << "\n";
  606. }
  607. }
  608. };
  609. std::string reasoning = "To plan an effective trip to Japan that includes both historical sites and modern attractions within a "
  610. "budget of $4000 for a two-week stay, we need to:\n\n"
  611. "1. Identify key historical sites and modern attractions in Japan.\n"
  612. "2. Find affordable accommodation options that provide a balance between comfort and cost.\n"
  613. "3. Determine the best modes of transportation for getting around Japan.\n"
  614. "4. Create a day-by-day itinerary that ensures the user gets to see a variety of attractions without "
  615. "overspending.\n"
  616. "5. Provide a detailed cost breakdown that includes accommodation, transportation, meals, and entry fees "
  617. "to attractions.";
  618. std::vector<std::tuple<std::string, std::string, nlohmann::json>> tool_calls = {{
  619. "call_0",
  620. "plan_trip",
  621. nlohmann::json::parse(R"({
  622. "destination": "Japan",
  623. "duration": 14,
  624. "budget": 4000,
  625. "interests": ["historical sites", "modern attractions"],
  626. "accommodation_preferences": "affordable",
  627. "transportation_preferences": "efficient",
  628. "meal_preferences": "local cuisine"
  629. })")
  630. }};
  631. std::vector<std::string> tokens;
  632. // Build tokens
  633. if (!reasoning.empty()) {
  634. auto tokenized = simple_tokenize(reasoning);
  635. tokens.emplace_back("<|START_THINKING|>");
  636. tokens.insert(tokens.end(), tokenized.begin(), tokenized.end());
  637. tokens.emplace_back("<|END_THINKING|>");
  638. }
  639. if (!tool_calls.empty()) {
  640. tokens.emplace_back("<|START_ACTION|>");
  641. auto json = nlohmann::json::array();
  642. for (const auto & tc : tool_calls) {
  643. auto tc_json = nlohmann::json::object();
  644. tc_json["tool_call_id"] = std::get<0>(tc);
  645. tc_json["tool_name"] = std::get<1>(tc);
  646. tc_json["parameters"] = std::get<2>(tc);
  647. json.push_back(tc_json);
  648. }
  649. auto tokenized = simple_tokenize(json.dump(-1, ' ', true));
  650. tokens.insert(tokens.end(), tokenized.begin(), tokenized.end());
  651. tokens.emplace_back("<|END_ACTION|>");
  652. }
  653. std::string input = std::accumulate(tokens.begin(), tokens.end(), std::string());
  654. // Run tests
  655. t.test("legacy_parse", [&](testing & /* t */) {
  656. test_legacy(input, false, false);
  657. });
  658. t.test("current_parse", [&](testing & /* t */) {
  659. test_current(parser, input, false, false);
  660. });
  661. // Run benchmarks
  662. t.bench("legacy_parse_benchmark complete", [&]() {
  663. test_legacy(input, false, false);
  664. });
  665. t.bench("legacy_parse_benchmark incremental", [&]() {
  666. std::string in;
  667. for (auto i = 0u; i < tokens.size(); i++) {
  668. in += tokens[i];
  669. try {
  670. test_legacy(in, i + 1 < tokens.size(), false);
  671. } catch (common_chat_msg_partial_exception & /* e */) {
  672. // Do nothing, this is expected
  673. }
  674. }
  675. }, 20);
  676. t.bench("current_parse_benchmark complete", [&]() {
  677. test_current(parser, input, false, false);
  678. }, 100);
  679. t.bench("current_parse_benchmark incremental", [&]() {
  680. std::string in;
  681. for (auto i = 0u; i < tokens.size(); i++) {
  682. in += tokens[i];
  683. test_current(parser, in, i + 1 < tokens.size(), false);
  684. }
  685. }, 20);
  686. }