chat-parser-xml-toolcall.cpp 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861
#include "chat.h"
#include "chat-parser.h"
#include "common.h"
#include "json-partial.h"
#include "json-schema-to-grammar.h"
#include "log.h"
#include "regex-partial.h"

#include <algorithm>
#include <cctype>
#include <cstdio>
#include <optional>
#include <regex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

using json = nlohmann::ordered_json;
  9. class xml_toolcall_syntax_exception : public std::runtime_error {
  10. public:
  11. xml_toolcall_syntax_exception(const std::string & message) : std::runtime_error(message) {}
  12. };
  13. template<typename T>
  14. inline void sort_uniq(std::vector<T> &vec) {
  15. std::sort(vec.begin(), vec.end());
  16. vec.erase(std::unique(vec.begin(), vec.end()), vec.end());
  17. }
  18. template<typename T>
  19. inline bool all_space(const T &str) {
  20. return std::all_of(str.begin(), str.end(), [](unsigned char ch) { return std::isspace(ch); });
  21. }
  22. static size_t utf8_truncate_safe(const std::string_view s) {
  23. size_t len = s.size();
  24. if (len == 0) return 0;
  25. size_t i = len;
  26. for (size_t back = 0; back < 4 && i > 0; ++back) {
  27. --i;
  28. unsigned char c = s[i];
  29. if ((c & 0x80) == 0) {
  30. return len;
  31. } else if ((c & 0xC0) == 0xC0) {
  32. size_t expected_len = 0;
  33. if ((c & 0xE0) == 0xC0) expected_len = 2;
  34. else if ((c & 0xF0) == 0xE0) expected_len = 3;
  35. else if ((c & 0xF8) == 0xF0) expected_len = 4;
  36. else return i;
  37. if (len - i >= expected_len) {
  38. return len;
  39. } else {
  40. return i;
  41. }
  42. }
  43. }
  44. return len - std::min(len, size_t(3));
  45. }
  46. inline void utf8_truncate_safe_resize(std::string &s) {
  47. s.resize(utf8_truncate_safe(s));
  48. }
  49. inline std::string_view utf8_truncate_safe_view(const std::string_view s) {
  50. return s.substr(0, utf8_truncate_safe(s));
  51. }
  52. static std::optional<common_chat_msg_parser::find_regex_result> try_find_2_literal_splited_by_spaces(common_chat_msg_parser & builder, const std::string & literal1, const std::string & literal2) {
  53. if (literal1.size() == 0) return builder.try_find_literal(literal2);
  54. const auto saved_pos = builder.pos();
  55. while (auto res = builder.try_find_literal(literal1)) {
  56. builder.consume_spaces();
  57. const auto match_len = std::min(literal2.size(), builder.input().size() - builder.pos());
  58. if (builder.input().compare(builder.pos(), match_len, literal2, 0, match_len) == 0) {
  59. if (res->prelude.size() != res->groups[0].begin - saved_pos) {
  60. res->prelude = builder.str({saved_pos, res->groups[0].begin});
  61. }
  62. builder.move_to(builder.pos() + match_len);
  63. res->groups[0].end = builder.pos();
  64. GGML_ASSERT(res->groups[0].begin != res->groups[0].end);
  65. return res;
  66. }
  67. builder.move_to(res->groups[0].begin + 1);
  68. }
  69. builder.move_to(saved_pos);
  70. return std::nullopt;
  71. }
  72. /**
  73. * make a GBNF that accept any strings except those containing any of the forbidden strings.
  74. */
  75. std::string make_gbnf_excluding(std::vector<std::string> forbids) {
  76. constexpr auto charclass_escape = [](unsigned char c) -> std::string {
  77. if (c == '\\' || c == ']' || c == '^' || c == '-') {
  78. std::string s = "\\";
  79. s.push_back((char)c);
  80. return s;
  81. }
  82. if (isprint(c)) {
  83. return std::string(1, (char)c);
  84. }
  85. char buf[16];
  86. snprintf(buf, 15, "\\x%02X", c);
  87. return std::string(buf);
  88. };
  89. constexpr auto build_expr = [charclass_escape](auto self, const std::vector<std::string>& forbids, int l, int r, int depth) -> std::string {
  90. std::vector<std::pair<unsigned char, std::pair<int,int>>> children;
  91. int i = l;
  92. while (i < r) {
  93. const std::string &s = forbids[i];
  94. if ((int)s.size() == depth) {
  95. ++i;
  96. continue;
  97. }
  98. unsigned char c = (unsigned char)s[depth];
  99. int j = i;
  100. while (j < r && (int)forbids[j].size() > depth &&
  101. (unsigned char)forbids[j][depth] == c) {
  102. ++j;
  103. }
  104. children.push_back({c, {i, j}});
  105. i = j;
  106. }
  107. std::vector<std::string> alts;
  108. if (!children.empty()) {
  109. std::string cls;
  110. for (auto &ch : children) cls += charclass_escape(ch.first);
  111. alts.push_back(std::string("[^") + cls + "]");
  112. }
  113. for (auto &ch : children) {
  114. std::string childExpr = self(self, forbids, ch.second.first, ch.second.second, depth+1);
  115. if (!childExpr.empty()) {
  116. std::string quoted_ch = "\"";
  117. if (ch.first == '\\') quoted_ch += "\\\\";
  118. else if (ch.first == '"') quoted_ch += "\\\"";
  119. else if (isprint(ch.first)) quoted_ch.push_back(ch.first);
  120. else {
  121. char buf[16];
  122. snprintf(buf, 15, "\\x%02X", ch.first);
  123. quoted_ch += buf;
  124. }
  125. quoted_ch += "\"";
  126. std::string branch = quoted_ch + std::string(" ") + childExpr;
  127. alts.push_back(branch);
  128. }
  129. }
  130. if (alts.empty()) return "";
  131. std::ostringstream oss;
  132. oss << "( ";
  133. for (size_t k = 0; k < alts.size(); ++k) {
  134. if (k) oss << " | ";
  135. oss << alts[k];
  136. }
  137. oss << " )";
  138. return oss.str();
  139. };
  140. if (forbids.empty()) return "( . )*";
  141. sort(forbids.begin(), forbids.end());
  142. std::string expr = build_expr(build_expr, forbids, 0, forbids.size(), 0);
  143. if (expr.empty()) {
  144. std::string cls;
  145. for (auto &s : forbids) if (!s.empty()) cls += charclass_escape((unsigned char)s[0]);
  146. expr = std::string("( [^") + cls + "] )";
  147. }
  148. if (forbids.size() == 1)
  149. return expr + "*";
  150. else
  151. return std::string("( ") + expr + " )*";
  152. }
  153. /**
  154. * Build grammar for xml-style tool call
  155. * form.scope_start and form.scope_end can be empty.
  156. * Requires data.format for model-specific hacks.
  157. */
  158. void build_grammar_xml_tool_call(common_chat_params & data, const json & tools, const struct xml_tool_call_format & form) {
  159. GGML_ASSERT(!form.tool_start.empty());
  160. GGML_ASSERT(!form.tool_sep.empty());
  161. GGML_ASSERT(!form.key_start.empty());
  162. GGML_ASSERT(!form.val_end.empty());
  163. GGML_ASSERT(!form.tool_end.empty());
  164. std::string key_val_sep = form.key_val_sep;
  165. if (form.key_val_sep2) {
  166. key_val_sep += "\n";
  167. key_val_sep += *form.key_val_sep2;
  168. }
  169. GGML_ASSERT(!key_val_sep.empty());
  170. if (tools.is_array() && !tools.empty()) {
  171. data.grammar = build_grammar([&](const common_grammar_builder &builder) {
  172. auto string_arg_val = form.last_val_end ?
  173. builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end, *form.last_val_end})) :
  174. builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end}));
  175. std::vector<std::string> tool_rules;
  176. for (const auto & tool : tools) {
  177. if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
  178. LOG_WRN("Skipping tool without function: %s", tool.dump(2).c_str());
  179. continue;
  180. }
  181. const auto & function = tool.at("function");
  182. if (!function.contains("name") || !function.at("name").is_string()) {
  183. LOG_WRN("Skipping invalid function (invalid name): %s", function.dump(2).c_str());
  184. continue;
  185. }
  186. if (!function.contains("parameters") || !function.at("parameters").is_object()) {
  187. LOG_WRN("Skipping invalid function (invalid parameters): %s", function.dump(2).c_str());
  188. continue;
  189. }
  190. std::string name = function.at("name");
  191. auto parameters = function.at("parameters");
  192. builder.resolve_refs(parameters);
  193. struct parameter_rule {
  194. std::string symbol_name;
  195. bool is_required;
  196. };
  197. std::vector<parameter_rule> arg_rules;
  198. if (!parameters.contains("properties") || !parameters.at("properties").is_object()) {
  199. LOG_WRN("Skipping invalid function (invalid properties): %s", function.dump(2).c_str());
  200. continue;
  201. } else {
  202. std::vector<std::string> requiredParameters;
  203. if (parameters.contains("required")) {
  204. try { parameters.at("required").get_to(requiredParameters); }
  205. catch (const std::runtime_error&) {
  206. LOG_WRN("Invalid function required parameters, ignoring: %s", function.at("required").dump(2).c_str());
  207. }
  208. }
  209. sort_uniq(requiredParameters);
  210. for (const auto & [key, value] : parameters.at("properties").items()) {
  211. std::string quoted_key = key;
  212. bool required = std::binary_search(requiredParameters.begin(), requiredParameters.end(), key);
  213. if (form.key_start.back() == '"' && key_val_sep[0] == '"') {
  214. quoted_key = gbnf_format_literal(key);
  215. quoted_key = quoted_key.substr(1, quoted_key.size() - 2);
  216. }
  217. arg_rules.push_back(parameter_rule {builder.add_rule("func-" + name + "-kv-" + key,
  218. gbnf_format_literal(form.key_start) + " " +
  219. gbnf_format_literal(quoted_key) + " " +
  220. gbnf_format_literal(key_val_sep) + " " +
  221. ((value.contains("type") && value["type"].is_string() && value["type"] == "string" && (!form.raw_argval || *form.raw_argval)) ?
  222. (form.raw_argval ?
  223. string_arg_val :
  224. "( " + string_arg_val + " | " + builder.add_schema(name + "-arg-" + key, value) + " )"
  225. ) :
  226. builder.add_schema(name + "-arg-" + key, value)
  227. )
  228. ), required});
  229. }
  230. }
  231. auto next_arg_with_sep = builder.add_rule(name + "-last-arg-end", form.last_val_end ? gbnf_format_literal(*form.last_val_end) : gbnf_format_literal(form.val_end));
  232. decltype(next_arg_with_sep) next_arg = "\"\"";
  233. for (auto i = arg_rules.size() - 1; /* i >= 0 && */ i < arg_rules.size(); --i) {
  234. std::string include_this_arg = arg_rules[i].symbol_name + " " + next_arg_with_sep;
  235. next_arg = builder.add_rule(name + "-arg-after-" + std::to_string(i), arg_rules[i].is_required ?
  236. include_this_arg : "( " + include_this_arg + " ) | " + next_arg
  237. );
  238. include_this_arg = gbnf_format_literal(form.val_end) + " " + include_this_arg;
  239. next_arg_with_sep = builder.add_rule(name + "-arg-after-" + std::to_string(i) + "-with-sep", arg_rules[i].is_required ?
  240. include_this_arg : "( " + include_this_arg + " ) | " + next_arg_with_sep
  241. );
  242. }
  243. std::string quoted_name = name;
  244. if (form.tool_start.back() == '"' && form.tool_sep[0] == '"') {
  245. quoted_name = gbnf_format_literal(name);
  246. quoted_name = quoted_name.substr(1, quoted_name.size() - 2);
  247. }
  248. quoted_name = gbnf_format_literal(quoted_name);
  249. // Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
  250. if (data.format == COMMON_CHAT_FORMAT_KIMI_K2) {
  251. quoted_name = "\"functions.\" " + quoted_name + " \":\" [0-9]+";
  252. }
  253. tool_rules.push_back(builder.add_rule(name + "-call",
  254. gbnf_format_literal(form.tool_start) + " " +
  255. quoted_name + " " +
  256. gbnf_format_literal(form.tool_sep) + " " +
  257. next_arg
  258. ));
  259. }
  260. auto tool_call_once = builder.add_rule("root-tool-call-once", string_join(tool_rules, " | "));
  261. auto tool_call_more = builder.add_rule("root-tool-call-more", gbnf_format_literal(form.tool_end) + " " + tool_call_once);
  262. auto call_end = builder.add_rule("root-call-end", form.last_tool_end ? gbnf_format_literal(*form.last_tool_end) : gbnf_format_literal(form.tool_end));
  263. auto tool_call_multiple_with_end = builder.add_rule("root-tool-call-multiple-with-end", tool_call_once + " " + tool_call_more + "* " + call_end);
  264. builder.add_rule("root",
  265. (form.scope_start.empty() ? "" : gbnf_format_literal(form.scope_start) + " ") +
  266. tool_call_multiple_with_end + "?" +
  267. (form.scope_end.empty() ? "" : " " + gbnf_format_literal(form.scope_end))
  268. );
  269. });
  270. // grammar trigger for tool call
  271. data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, form.scope_start + form.tool_start });
  272. }
  273. }
  274. /**
  275. * Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
  276. * Throws xml_toolcall_syntax_exception if there is invalid syntax and cannot recover the original status for common_chat_msg_parser.
  277. * form.scope_start, form.tool_sep and form.scope_end can be empty.
  278. */
  279. inline bool parse_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form) {
  280. GGML_ASSERT(!form.tool_start.empty());
  281. GGML_ASSERT(!form.key_start.empty());
  282. GGML_ASSERT(!form.key_val_sep.empty());
  283. GGML_ASSERT(!form.val_end.empty());
  284. GGML_ASSERT(!form.tool_end.empty());
  285. // Helper to choose return false or throw error
  286. constexpr auto return_error = [](common_chat_msg_parser & builder, auto &start_pos, const bool &recovery) {
  287. LOG_DBG("Failed to parse XML-Style tool call at position: %s\n", gbnf_format_literal(builder.consume_rest().substr(0, 20)).c_str());
  288. if (recovery) {
  289. builder.move_to(start_pos);
  290. return false;
  291. } else throw xml_toolcall_syntax_exception("Tool call parsing failed with unrecoverable errors. Try using a grammar to constrain the model’s output.");
  292. };
  293. // Drop substring from needle to end from a JSON
  294. constexpr auto partial_json = [](std::string &json_str, std::string_view needle = "XML_TOOL_CALL_PARTIAL_FLAG") {
  295. auto pos = json_str.rfind(needle);
  296. if (pos == std::string::npos) {
  297. return false;
  298. }
  299. for (auto i = pos + needle.size(); i < json_str.size(); ++i) {
  300. unsigned char ch = static_cast<unsigned char>(json_str[i]);
  301. if (ch != '\'' && ch != '"' && ch != '}' && ch != ':' && !std::isspace(ch)) {
  302. return false;
  303. }
  304. }
  305. if (pos != 0 && json_str[pos - 1] == '"') {
  306. --pos;
  307. }
  308. json_str.resize(pos);
  309. return true;
  310. };
  311. // Helper to generate a partial argument JSON
  312. constexpr auto gen_partial_json = [partial_json](auto set_partial_arg, auto &arguments, auto &builder, auto &function_name) {
  313. auto rest = builder.consume_rest();
  314. utf8_truncate_safe_resize(rest);
  315. set_partial_arg(rest, "XML_TOOL_CALL_PARTIAL_FLAG");
  316. auto tool_str = arguments.dump();
  317. if (partial_json(tool_str)) {
  318. if (builder.add_tool_call(function_name, "", tool_str)) {
  319. return;
  320. }
  321. }
  322. LOG_DBG("Failed to parse partial XML-Style tool call, fallback to non-partial: %s\n", tool_str.c_str());
  323. };
  324. // Helper to find a close (because there may be form.last_val_end or form.last_tool_end)
  325. constexpr auto try_find_close = [](
  326. common_chat_msg_parser & builder,
  327. const std::string & end,
  328. const std::optional<std::string> & alt_end,
  329. const std::string & end_next,
  330. const std::optional<std::string> & alt_end_next
  331. ) {
  332. auto saved_pos = builder.pos();
  333. auto tc = builder.try_find_literal(end);
  334. auto val_end_size = end.size();
  335. if (alt_end) {
  336. auto pos_1 = builder.pos();
  337. builder.move_to(saved_pos);
  338. auto tc2 = try_find_2_literal_splited_by_spaces(builder, *alt_end, end_next);
  339. if (alt_end_next) {
  340. builder.move_to(saved_pos);
  341. auto tc3 = try_find_2_literal_splited_by_spaces(builder, *alt_end, *alt_end_next);
  342. if (tc3 && (!tc2 || tc2->prelude.size() > tc3->prelude.size())) {
  343. tc2 = tc3;
  344. }
  345. }
  346. if (tc2 && (!tc || tc->prelude.size() > tc2->prelude.size())) {
  347. tc = tc2;
  348. tc->groups[0].end = std::min(builder.input().size(), tc->groups[0].begin + alt_end->size());
  349. builder.move_to(tc->groups[0].end);
  350. val_end_size = alt_end->size();
  351. } else {
  352. builder.move_to(pos_1);
  353. }
  354. }
  355. return std::make_pair(val_end_size, tc);
  356. };
  357. // Helper to find a val_end or last_val_end, returns matched pattern size
  358. const auto try_find_val_end = [try_find_close, &builder, &form]() {
  359. return try_find_close(builder, form.val_end, form.last_val_end, form.tool_end, form.last_tool_end);
  360. };
  361. // Helper to find a tool_end or last_tool_end, returns matched pattern size
  362. const auto try_find_tool_end = [try_find_close, &builder, &form]() {
  363. return try_find_close(builder, form.tool_end, form.last_tool_end, form.scope_end, std::nullopt);
  364. };
  365. bool recovery = true;
  366. const auto start_pos = builder.pos();
  367. if (!all_space(form.scope_start)) {
  368. if (auto tc = builder.try_find_literal(form.scope_start)) {
  369. if (all_space(tc->prelude)) {
  370. if (form.scope_start.size() != tc->groups[0].end - tc->groups[0].begin)
  371. throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.scope_start));
  372. } else {
  373. builder.move_to(start_pos);
  374. return false;
  375. }
  376. } else return false;
  377. }
  378. while (auto tc = builder.try_find_literal(form.tool_start)) {
  379. if (!all_space(tc->prelude)) {
  380. LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
  381. gbnf_format_literal(form.tool_start).c_str(),
  382. gbnf_format_literal(tc->prelude).c_str()
  383. );
  384. builder.move_to(tc->groups[0].begin - tc->prelude.size());
  385. break;
  386. }
  387. // Find tool name
  388. auto func_name = builder.try_find_literal(all_space(form.tool_sep) ? form.key_start : form.tool_sep);
  389. if (!func_name) {
  390. auto [sz, tc] = try_find_tool_end();
  391. func_name = tc;
  392. }
  393. if (!func_name) {
  394. // Partial tool name not supported
  395. throw common_chat_msg_partial_exception("incomplete tool_call");
  396. }
  397. // If the model generate multiple tool call and the first tool call has no argument
  398. if (func_name->prelude.find(form.tool_end) != std::string::npos || (form.last_tool_end ? func_name->prelude.find(*form.last_tool_end) != std::string::npos : false)) {
  399. builder.move_to(func_name->groups[0].begin - func_name->prelude.size());
  400. auto [sz, tc] = try_find_tool_end();
  401. func_name = tc;
  402. }
  403. // Parse tool name
  404. builder.move_to(all_space(form.tool_sep) ? func_name->groups[0].begin : func_name->groups[0].end);
  405. std::string function_name = string_strip(func_name->prelude);
  406. // Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
  407. if (builder.syntax().format == COMMON_CHAT_FORMAT_KIMI_K2) {
  408. if (string_starts_with(function_name, "functions.")) {
  409. static const std::regex re(":\\d+$");
  410. if (std::regex_search(function_name, re)) {
  411. function_name = function_name.substr(10, function_name.rfind(":") - 10);
  412. }
  413. }
  414. }
  415. // Argument JSON
  416. json arguments = json::object();
  417. // Helper to generate a partial argument JSON
  418. const auto gen_partial_args = [&](auto set_partial_arg) {
  419. gen_partial_json(set_partial_arg, arguments, builder, function_name);
  420. };
  421. // Parse all arg_key/arg_value pairs
  422. while (auto tc = builder.try_find_literal(form.key_start)) {
  423. if (!all_space(tc->prelude)) {
  424. LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
  425. gbnf_format_literal(form.key_start).c_str(),
  426. gbnf_format_literal(tc->prelude).c_str()
  427. );
  428. builder.move_to(tc->groups[0].begin - tc->prelude.size());
  429. break;
  430. }
  431. if (tc->groups[0].end - tc->groups[0].begin != form.key_start.size()) {
  432. auto tool_call_arg = arguments.dump();
  433. if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
  434. tool_call_arg.resize(tool_call_arg.size() - 1);
  435. }
  436. builder.add_tool_call(function_name, "", tool_call_arg);
  437. throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_start));
  438. }
  439. // Parse arg_key
  440. auto key_res = builder.try_find_literal(form.key_val_sep);
  441. if (!key_res) {
  442. gen_partial_args([&](auto &rest, auto &needle) {arguments[rest + needle] = "";});
  443. throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.key_val_sep) + " after " + gbnf_format_literal(form.key_start));
  444. }
  445. if (key_res->groups[0].end - key_res->groups[0].begin != form.key_val_sep.size()) {
  446. gen_partial_args([&](auto &, auto &needle) {arguments[key_res->prelude + needle] = "";});
  447. throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_val_sep));
  448. }
  449. auto &key = key_res->prelude;
  450. recovery = false;
  451. // Parse arg_value
  452. if (form.key_val_sep2) {
  453. if (auto tc = builder.try_find_literal(*form.key_val_sep2)) {
  454. if (!all_space(tc->prelude)) {
  455. LOG_DBG("Failed to parse XML-Style tool call: Unexcepted %s between %s and %s\n",
  456. gbnf_format_literal(tc->prelude).c_str(),
  457. gbnf_format_literal(form.key_val_sep).c_str(),
  458. gbnf_format_literal(*form.key_val_sep2).c_str()
  459. );
  460. return return_error(builder, start_pos, false);
  461. }
  462. if (tc->groups[0].end - tc->groups[0].begin != form.key_val_sep2->size()) {
  463. gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
  464. throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(*form.key_val_sep2));
  465. }
  466. } else {
  467. gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
  468. throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(*form.key_val_sep2) + " after " + gbnf_format_literal(form.key_val_sep));
  469. }
  470. }
  471. auto val_start = builder.pos();
  472. // Test if arg_val is a partial JSON
  473. std::optional<common_json> value_json = std::nullopt;
  474. if (!form.raw_argval || !*form.raw_argval) {
  475. try { value_json = builder.try_consume_json(); }
  476. catch (const std::runtime_error&) { builder.move_to(val_start); }
  477. // TODO: Delete this when json_partial adds top-level support for null/true/false
  478. if (builder.pos() == val_start) {
  479. const static std::regex number_regex(R"([0-9-][0-9]*(\.\d*)?([eE][+-]?\d*)?)");
  480. builder.consume_spaces();
  481. std::string_view sv = utf8_truncate_safe_view(builder.input());
  482. sv.remove_prefix(builder.pos());
  483. std::string rest = "a";
  484. if (sv.size() < 6) rest = sv;
  485. if (string_starts_with("null", rest) || string_starts_with("true", rest) || string_starts_with("false", rest) || std::regex_match(sv.begin(), sv.end(), number_regex)) {
  486. value_json = {123, {"123", "123"}};
  487. builder.consume_rest();
  488. } else {
  489. builder.move_to(val_start);
  490. }
  491. }
  492. }
  493. // If it is a JSON and followed by </arg_value>, parse as json
  494. // cannot support streaming because it may be a plain text starting with JSON
  495. if (value_json) {
  496. auto json_end = builder.pos();
  497. builder.consume_spaces();
  498. if (builder.pos() == builder.input().size()) {
  499. if (form.raw_argval && !*form.raw_argval && (value_json->json.is_string() || value_json->json.is_object() || value_json->json.is_array())) {
  500. arguments[key] = value_json->json;
  501. auto json_str = arguments.dump();
  502. if (!value_json->healing_marker.json_dump_marker.empty()) {
  503. GGML_ASSERT(std::string::npos != json_str.rfind(value_json->healing_marker.json_dump_marker));
  504. json_str.resize(json_str.rfind(value_json->healing_marker.json_dump_marker));
  505. } else {
  506. GGML_ASSERT(json_str.back() == '}');
  507. json_str.resize(json_str.size() - 1);
  508. }
  509. builder.add_tool_call(function_name, "", json_str);
  510. } else {
  511. gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
  512. }
  513. LOG_DBG("Possible JSON arg_value: %s\n", value_json->json.dump().c_str());
  514. throw common_chat_msg_partial_exception("JSON arg_value detected. Waiting for more tokens for validations.");
  515. }
  516. builder.move_to(json_end);
  517. auto [val_end_size, tc] = try_find_val_end();
  518. if (tc && all_space(tc->prelude) && value_json->healing_marker.marker.empty()) {
  519. if (tc->groups[0].end - tc->groups[0].begin != val_end_size) {
  520. gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
  521. LOG_DBG("Possible terminated JSON arg_value: %s\n", value_json->json.dump().c_str());
  522. throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.val_end) + (form.last_val_end ? gbnf_format_literal(*form.last_val_end) : ""));
  523. } else arguments[key] = value_json->json;
  524. } else builder.move_to(val_start);
  525. }
  526. // If not, parse as plain text
  527. if (val_start == builder.pos()) {
  528. if (auto [val_end_size, value_plain] = try_find_val_end(); value_plain) {
  529. auto &value_str = value_plain->prelude;
  530. if (form.trim_raw_argval) value_str = string_strip(value_str);
  531. if (value_plain->groups[0].end - value_plain->groups[0].begin != val_end_size) {
  532. gen_partial_args([&](auto &, auto &needle) {arguments[key] = value_str + needle;});
  533. throw common_chat_msg_partial_exception(
  534. "Expected " + gbnf_format_literal(form.val_end) +
  535. " after " + gbnf_format_literal(form.key_val_sep) +
  536. (form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
  537. );
  538. }
  539. arguments[key] = value_str;
  540. } else {
  541. if (form.trim_raw_argval) {
  542. gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = string_strip(rest) + needle;});
  543. } else {
  544. gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = rest + needle;});
  545. }
  546. throw common_chat_msg_partial_exception(
  547. "Expected " + gbnf_format_literal(form.val_end) +
  548. " after " + gbnf_format_literal(form.key_val_sep) +
  549. (form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
  550. );
  551. }
  552. }
  553. }
  554. // Consume closing tag
  555. if (auto [tool_end_size, tc] = try_find_tool_end(); tc) {
  556. if (!all_space(tc->prelude)) {
  557. LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
  558. gbnf_format_literal(form.tool_end).c_str(),
  559. gbnf_format_literal(tc->prelude).c_str()
  560. );
  561. return return_error(builder, start_pos, recovery);
  562. }
  563. if (tc->groups[0].end - tc->groups[0].begin == tool_end_size) {
  564. // Add the parsed tool call
  565. if (!builder.add_tool_call(function_name, "", arguments.dump())) {
  566. throw common_chat_msg_partial_exception("Failed to add XML-Style tool call");
  567. }
  568. recovery = false;
  569. continue;
  570. }
  571. }
  572. auto tool_call_arg = arguments.dump();
  573. if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
  574. tool_call_arg.resize(tool_call_arg.size() - 1);
  575. }
  576. builder.add_tool_call(function_name, "", tool_call_arg);
  577. throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.tool_end) + " after " + gbnf_format_literal(form.val_end));
  578. }
  579. if (auto tc = builder.try_find_literal(form.scope_end)) {
  580. if (!all_space(tc->prelude)) {
  581. LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
  582. gbnf_format_literal(form.scope_end).c_str(),
  583. gbnf_format_literal(tc->prelude).c_str()
  584. );
  585. return return_error(builder, start_pos, recovery);
  586. }
  587. } else {
  588. if (all_space(form.scope_end)) return true;
  589. builder.consume_spaces();
  590. if (builder.pos() == builder.input().size())
  591. throw common_chat_msg_partial_exception("incomplete tool calls");
  592. LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
  593. gbnf_format_literal(form.scope_end).c_str(),
  594. gbnf_format_literal(builder.consume_rest()).c_str()
  595. );
  596. return return_error(builder, start_pos, recovery);
  597. }
  598. return true;
  599. }
  600. /**
  601. * Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
  602. * May cause std::runtime_error if there is invalid syntax because partial valid tool call is already sent out to client.
  603. * form.scope_start, form.tool_sep and form.scope_end can be empty.
  604. */
  605. bool common_chat_msg_parser::try_consume_xml_tool_calls(const struct xml_tool_call_format & form) {
  606. auto pos = pos_;
  607. auto tsize = result_.tool_calls.size();
  608. try { return parse_xml_tool_calls(*this, form); }
  609. catch (const xml_toolcall_syntax_exception&) {}
  610. move_to(pos);
  611. result_.tool_calls.resize(tsize);
  612. return false;
  613. }
  614. /**
  615. * Parse content uses reasoning and XML-Style tool call
  616. * TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
  617. */
  618. inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>") {
  619. constexpr auto rstrip = [](std::string &s) {
  620. s.resize(std::distance(s.begin(), std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base()));
  621. };
  622. // Erase substring from l to r, along with additional spaces nearby
  623. constexpr auto erase_spaces = [](auto &str, size_t l, size_t r) {
  624. while (/* l > -1 && */ --l < str.size() && std::isspace(static_cast<unsigned char>(str[l])));
  625. ++l;
  626. while (++r < str.size() && std::isspace(static_cast<unsigned char>(str[r])));
  627. if (l < r) str[l] = '\n';
  628. if (l + 1 < r) str[l + 1] = '\n';
  629. if (l != 0) l += 2;
  630. str.erase(l, r - l);
  631. return l;
  632. };
  633. constexpr auto trim_suffix = [](std::string &content, std::initializer_list<std::string_view> list) {
  634. auto best_match = content.size();
  635. for (auto pattern: list) {
  636. if (pattern.size() == 0) continue;
  637. for (auto match_idx = content.size() - std::min(pattern.size(), content.size()); content.size() > match_idx; match_idx++) {
  638. auto match_len = content.size() - match_idx;
  639. if (content.compare(match_idx, match_len, pattern.data(), match_len) == 0 && best_match > match_idx) {
  640. best_match = match_idx;
  641. }
  642. }
  643. }
  644. if (content.size() > best_match) {
  645. content.erase(best_match);
  646. }
  647. };
  648. const auto trim_potential_partial_word = [&start_think, &end_think, &form, trim_suffix](std::string &content) {
  649. return trim_suffix(content, {
  650. start_think, end_think, form.scope_start, form.tool_start, form.tool_sep, form.key_start,
  651. form.key_val_sep, form.key_val_sep2 ? form.key_val_sep2->c_str() : "",
  652. form.val_end, form.last_val_end ? form.last_val_end->c_str() : "",
  653. form.tool_end, form.last_tool_end ? form.last_tool_end->c_str() : "",
  654. form.scope_end
  655. });
  656. };
  657. // Trim leading spaces without affecting keyword matching
  658. static const common_regex spaces_regex("\\s*");
  659. {
  660. auto tc = builder.consume_regex(spaces_regex);
  661. auto spaces = builder.str(tc.groups[0]);
  662. auto s1 = spaces.size();
  663. trim_potential_partial_word(spaces);
  664. auto s2 = spaces.size();
  665. builder.move_to(builder.pos() - (s1 - s2));
  666. }
  667. // Parse content
  668. bool reasoning_unclosed = builder.syntax().thinking_forced_open;
  669. std::string unclosed_reasoning_content("");
  670. for (;;) {
  671. auto tc = try_find_2_literal_splited_by_spaces(builder, form.scope_start, form.tool_start);
  672. std::string content;
  673. std::string tool_call_start;
  674. if (tc) {
  675. content = std::move(tc->prelude);
  676. tool_call_start = builder.str(tc->groups[0]);
  677. LOG_DBG("Matched tool start: %s\n", gbnf_format_literal(tool_call_start).c_str());
  678. } else {
  679. content = builder.consume_rest();
  680. utf8_truncate_safe_resize(content);
  681. }
  682. // Handle unclosed think block
  683. if (reasoning_unclosed) {
  684. if (auto pos = content.find(end_think); pos == std::string::npos && builder.pos() != builder.input().size()) {
  685. unclosed_reasoning_content += content;
  686. if (form.allow_toolcall_in_think) {
  687. builder.move_to(tc->groups[0].begin);
  688. if (!builder.try_consume_xml_tool_calls(form)) {
  689. unclosed_reasoning_content += tool_call_start;
  690. builder.move_to(tc->groups[0].end);
  691. }
  692. } else {
  693. unclosed_reasoning_content += tool_call_start;
  694. }
  695. continue;
  696. } else {
  697. reasoning_unclosed = false;
  698. std::string reasoning_content;
  699. if (pos == std::string::npos) {
  700. reasoning_content = std::move(content);
  701. } else {
  702. reasoning_content = content.substr(0, pos);
  703. content.erase(0, pos + end_think.size());
  704. }
  705. if (builder.pos() == builder.input().size() && all_space(content)) {
  706. rstrip(reasoning_content);
  707. trim_potential_partial_word(reasoning_content);
  708. rstrip(reasoning_content);
  709. if (reasoning_content.empty()) {
  710. rstrip(unclosed_reasoning_content);
  711. trim_potential_partial_word(unclosed_reasoning_content);
  712. rstrip(unclosed_reasoning_content);
  713. if (unclosed_reasoning_content.empty()) continue;
  714. }
  715. }
  716. if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
  717. builder.add_content(start_think);
  718. builder.add_content(unclosed_reasoning_content);
  719. builder.add_content(reasoning_content);
  720. if (builder.pos() != builder.input().size() || !all_space(content))
  721. builder.add_content(end_think);
  722. } else {
  723. builder.add_reasoning_content(unclosed_reasoning_content);
  724. builder.add_reasoning_content(reasoning_content);
  725. }
  726. unclosed_reasoning_content.clear();
  727. }
  728. }
  729. // Handle multiple think block
  730. bool toolcall_in_think = false;
  731. for (auto think_start = content.find(start_think); think_start != std::string::npos; think_start = content.find(start_think, think_start)) {
  732. if (auto think_end = content.find(end_think, think_start + start_think.size()); think_end != std::string::npos) {
  733. if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
  734. auto reasoning_content = content.substr(think_start + start_think.size(), think_end - think_start - start_think.size());
  735. builder.add_reasoning_content(reasoning_content);
  736. think_start = erase_spaces(content, think_start, think_end + end_think.size() - 1);
  737. } else {
  738. think_start = think_end + end_think.size() - 1;
  739. }
  740. } else {
  741. // This <tool_call> start is in thinking block, skip this tool call
  742. auto pos = think_start + start_think.size();
  743. unclosed_reasoning_content = content.substr(pos) + tool_call_start;
  744. reasoning_unclosed = true;
  745. content.resize(think_start);
  746. toolcall_in_think = true;
  747. }
  748. }
  749. if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
  750. rstrip(content);
  751. // Handle unclosed </think> token from content: delete all </think> token
  752. if (auto pos = content.rfind(end_think); pos != std::string::npos) {
  753. while (pos != std::string::npos) {
  754. pos = erase_spaces(content, pos, pos + end_think.size() - 1);
  755. pos = content.rfind(end_think, pos);
  756. }
  757. }
  758. // Strip if needed
  759. if (content.size() > 0 && std::isspace(static_cast<unsigned char>(content[0]))) {
  760. content = string_strip(content);
  761. }
  762. }
  763. // remove potential partial suffix
  764. if (content.size() > 0 && builder.pos() == builder.input().size() && unclosed_reasoning_content.empty()) {
  765. rstrip(content);
  766. trim_potential_partial_word(content);
  767. rstrip(content);
  768. }
  769. // Add content
  770. if (content.size() != 0) {
  771. // If there are multiple content blocks
  772. if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content && builder.result().content.size() != 0) {
  773. builder.add_content("\n\n");
  774. }
  775. builder.add_content(content);
  776. }
  777. // This <tool_call> start is in thinking block, skip this tool call
  778. if (toolcall_in_think && !form.allow_toolcall_in_think) {
  779. continue;
  780. }
  781. // There is no tool call and all content is parsed
  782. if (!tc) {
  783. GGML_ASSERT(builder.pos() == builder.input().size());
  784. GGML_ASSERT(unclosed_reasoning_content.empty());
  785. GGML_ASSERT(!reasoning_unclosed);
  786. break;
  787. }
  788. builder.move_to(tc->groups[0].begin);
  789. if (builder.try_consume_xml_tool_calls(form)) {
  790. auto end_of_tool = builder.pos();
  791. builder.consume_spaces();
  792. if (builder.pos() != builder.input().size()) {
  793. builder.move_to(end_of_tool);
  794. if (!builder.result().content.empty()) {
  795. builder.add_content("\n\n");
  796. }
  797. }
  798. } else {
  799. static const common_regex next_char_regex(".");
  800. auto c = builder.str(builder.consume_regex(next_char_regex).groups[0]);
  801. rstrip(c);
  802. builder.add_content(c);
  803. }
  804. }
  805. }
  806. /**
  807. * Parse content uses reasoning and XML-Style tool call
  808. * TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
  809. */
  810. void common_chat_msg_parser::consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think, const std::string & end_think) {
  811. parse_msg_with_xml_tool_calls(*this, form, start_think, end_think);
  812. }