// chat.cpp
#include "chat.hpp"
#include "chat-template.hpp"
#include "json-schema-to-grammar.h"
#include "log.h"
#include "minja.hpp"

std::string common_chat_format_name(common_chat_format format) {
    switch (format) {
        case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only";
        case COMMON_CHAT_FORMAT_GENERIC: return "Generic";
        case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo";
        case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
        case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
        case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
        case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
        case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
        case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
        case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
        case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
        default:
            throw std::runtime_error("Unknown chat format");
    }
}

const common_grammar_options grammar_options {
    /* .dotall = */ false,
    /* .compact_spaces = */ false,
    // /* .compact_spaces = */ true,
};
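
// Attempts to parse a single JSON value starting at `it`. A SAX "error locator" pass finds where the
// first complete value ends (or where parsing breaks), then that substring is re-parsed with the DOM
// parser. On success, `it` is advanced past the consumed value and the result is stored in `out`.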
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
    // https://json.nlohmann.me/features/parsing/sax_interface/
    struct json_error_locator : public nlohmann::json_sax<json> {
        std::size_t position;
        bool found_error;

        json_error_locator() : position(0), found_error(false) {}

        bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
            this->position = position - 1;
            this->found_error = true;
            return false;
        }
        bool null() override { return true; }
        bool boolean(bool) override { return true; }
        bool number_integer(number_integer_t) override { return true; }
        bool number_unsigned(number_unsigned_t) override { return true; }
        bool number_float(number_float_t, const string_t &) override { return true; }
        bool string(string_t &) override { return true; }
        bool binary(binary_t &) override { return true; }
        bool start_object(std::size_t) override { return true; }
        bool key(string_t &) override { return true; }
        bool end_object() override { return true; }
        bool start_array(std::size_t) override { return true; }
        bool end_array() override { return true; }
    };
    json_error_locator err_loc;
    json::sax_parse(it, end, &err_loc);

    std::string::const_iterator tentative_end;
    if (err_loc.found_error) {
        tentative_end = it + err_loc.position;
    } else {
        tentative_end = end;
    }
    std::string json_sub {it, tentative_end};
    try {
        out = json::parse(json_sub);
        it = tentative_end;
        return true;
    } catch (const std::exception &) {
        return false;
    }
}

/**
 * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
 * Aggregates the prefix, suffix and in-between text into the content.
 */
static common_chat_msg parse_json_tool_calls(
    const std::string& input,
    const std::optional<std::regex> & trigger_opt,
    const std::regex & function_regex,
    const std::regex & close_regex) {
    std::smatch match;

    common_chat_msg result;
    result.role = "assistant";

    auto end = input.end();
    auto it = input.begin();

    if (trigger_opt) {
        if (!std::regex_search(it, end, match, *trigger_opt)) {
            result.content = input;
            return result;
        }
        result.content = match.prefix().str();
        it = match.suffix().first;
    }

    while (it != end) {
        std::sregex_iterator rend;
        std::sregex_iterator rit(it, end, function_regex);
        if (rit == rend) {
            fprintf(stderr, "No more tool calls found\n");
            result.content += std::string(it, end);
            break;
        }
        auto name = rit->str(1);
        result.content += std::string(it, rit->prefix().second);
        it = rit->suffix().first;

        json arguments;
        if (!parse_json(it, end, arguments)) {
            throw std::runtime_error("Failed to parse json tool call arguments");
        }
        if (!std::regex_search(it, end, match, close_regex)) {
            throw std::runtime_error("Malformed input, missing closing pattern");
        }
        it = match.suffix().first;
        result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
    }
    return result;
}
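
// Splits the input at the first occurrence of `prefix`: everything before it becomes content, and
// everything after it is parsed as a JSON array of {"name", "arguments", "id"?} tool calls.
// `rstrip_prefix` keeps that many trailing characters of the prefix as part of the JSON (e.g. a leading '[').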
static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
    auto content_end = input.find(prefix);
    size_t tc_start = std::string::npos;

    common_chat_msg result;
    result.role = "assistant";

    const auto process_tool_calls = [&](const json & tool_calls) {
        for (const auto & tool_call : tool_calls) {
            const auto & arguments = tool_call["arguments"];
            result.tool_calls.push_back({
                tool_call["name"],
                arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
                tool_call.contains("id") ? tool_call["id"] : "",
            });
        }
    };
    if (content_end == std::string::npos) {
        result.content = input;
    } else {
        tc_start = content_end + prefix.size() - rstrip_prefix;
        result.content = input.substr(0, content_end);
        auto tool_calls = json::parse(input.substr(tc_start));
        process_tool_calls(tool_calls);
    }
    return result;
}
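
// Iterates over an OpenAI-style `tools` array and invokes `fn` for each entry of type "function",
// logging and skipping anything else.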
static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
    for (const auto & tool : tools) {
        if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
            LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str());
            continue;
        }
        fn(tool);
    }
}
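
// Generic handler (the fallback when no template-specific handler matches): constrains the output to
// a JSON object carrying either a `tool_call`/`tool_calls` field built from the declared tool schemas
// or a `response` field, and injects a system message asking the model to answer in that JSON format.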
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
    common_chat_params data;

    auto tool_call_schemas = json::array();
    foreach_function(inputs.tools, [&](const json & tool) {
        const auto & function = tool["function"];
        auto tool_schema = json {
            {"type", "object"},
            {"properties", {
                {"name", {
                    {"type", "string"},
                    {"const", function["name"]},
                }},
                {"arguments", function["parameters"]},
            }},
            {"required", json::array({"name", "arguments"})},
        };
        if (function.contains("description")) {
            tool_schema["description"] = function["description"];
        }
        if (inputs.parallel_tool_calls) {
            tool_schema["properties"]["id"] = {
                {"type", "string"},
                {"minLength", 4},
            };
            tool_schema["required"].push_back("id");
        }
        tool_call_schemas.emplace_back(tool_schema);
    });
    const auto tool_call =
        inputs.parallel_tool_calls
            ? json {
                {"type", "object"},
                {"properties", {
                    {"tool_calls", {
                        {"type", "array"},
                        {"items", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
                            {"anyOf", tool_call_schemas},
                        }},
                        {"minItems", 1},
                    }},
                }},
                {"required", json::array({"tool_calls"})},
            }
            : json {
                {"type", "object"},
                {"properties", {
                    {"tool_call", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
                        {"anyOf", tool_call_schemas},
                    }},
                }},
                {"required", json::array({"tool_call"})},
            };
    const auto schema =
        inputs.tool_choice != "required"
            ? json {
                {"anyOf", json::array({
                    tool_call,
                    {
                        {"type", "object"},
                        {"properties", {
                            {"response", inputs.json_schema.is_null()
                                ? json {{"type", "string"}}
                                : inputs.json_schema
                            },
                        }},
                        {"required", json::array({"response"})},
                    },
                })}
            }
            : tool_call;
    data.grammar_lazy = false;
    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        builder.add_schema("root", schema);
    }, grammar_options);

    auto tweaked_messages = common_chat_template::add_system(
        inputs.messages,
        "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request");

    data.prompt = tmpl.apply(tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
    data.format = COMMON_CHAT_FORMAT_GENERIC;
    return data;
}

static common_chat_msg common_chat_parse_generic(const std::string & input) {
    json data = json::parse(input);
    common_chat_msg result;
    result.role = "assistant";
    if (data.contains("tool_calls")) {
        for (const auto & tool_call : data["tool_calls"]) {
            result.tool_calls.push_back({
                tool_call["name"],
                tool_call["arguments"].dump(),
                tool_call.contains("id") ? tool_call["id"] : "",
            });
        }
    } else if (data.contains("tool_call")) {
        result.tool_calls.push_back({
            data["tool_call"]["name"],
            data["tool_call"]["arguments"].dump(),
            /* id= */ "",
        });
    } else if (data.contains("response")) {
        const auto & response = data["response"];
        result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
    }
    return result;
}
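
// Mistral Nemo: tool calls are emitted as a "[TOOL_CALLS]" prefix followed by a JSON array of
// {name, arguments, id} objects, where `id` is a 9-character alphanumeric string.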
static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
    common_chat_params data;
    data.grammar_lazy = inputs.tool_choice != "required";
    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        auto schemas = json::array();
        foreach_function(inputs.tools, [&](const json & tool) {
            const auto & function = tool["function"];
            schemas.push_back({
                {"type", "object"},
                {"properties", {
                    // Important note: the model is probably trained to take a JSON stringified arguments value.
                    // It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
                    {"name", {
                        {"type", "string"},
                        {"const", function["name"]},
                    }},
                    {"arguments", function["parameters"]},
                    {"id", {
                        {"type", "string"},
                        // Nemo's template expects a 9-character alphanumeric ID.
                        {"pattern", "^[a-zA-Z0-9]{9}$"},
                    }},
                }},
                {"required", json::array({"name", "arguments", "id"})},
            });
        });
        auto schema = json {
            {"type", "array"},
            {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
            {"minItems", 1},
        };
        if (!inputs.parallel_tool_calls) {
            schema["maxItems"] = 1;
        }
        builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
    }, grammar_options);
    data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
    data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
    return data;
}

static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input) {
    return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
}
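
// Command R7B: tool calls appear between <|START_ACTION|> and <|END_ACTION|> as a JSON array of
// {tool_call_id, tool_name, parameters}; thinking and plain responses use their own START/END tokens.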
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
    common_chat_params data;
    data.grammar_lazy = inputs.tool_choice != "required";
    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        auto schemas = json::array();
        foreach_function(inputs.tools, [&](const json & tool) {
            const auto & function = tool["function"];
            schemas.push_back({
                {"type", "object"},
                {"properties", {
                    {"tool_call_id", {
                        {"type", "string"},
                        // Command-R's template expects an integer string.
                        {"pattern", "^[0-9]{1,10}$"},
                    }},
                    {"tool_name", {
                        {"type", "string"},
                        {"const", function["name"]},
                    }},
                    {"parameters", function["parameters"]},
                }},
                {"required", json::array({"tool_call_id", "tool_name", "parameters"})},
            });
        });
        auto schema = json {
            {"type", "array"},
            {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
            {"minItems", 1},
        };
        if (!inputs.parallel_tool_calls) {
            schema["maxItems"] = 1;
        }
        builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
    }, grammar_options);
    data.grammar_triggers.push_back({"<|START_ACTION|>", /* .at_start = */ false});
    data.preserved_tokens = {
        "<|START_RESPONSE|>",
        "<|END_RESPONSE|>",
        "<|START_THINKING|>",
        "<|END_THINKING|>",
        "<|END_ACTION|>",
    };
    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
    data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
    return data;
}

static common_chat_msg common_chat_parse_command_r7b(const std::string & input) {
    static std::regex response_regex("<\\|START_RESPONSE\\|>(.*?)<\\|END_RESPONSE\\|>");
    static std::regex thought_action_regex("<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|><\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
    std::smatch match;

    common_chat_msg result;
    result.role = "assistant";
    if (std::regex_match(input, match, response_regex)) {
        result.content = match[1].str();
    } else if (std::regex_match(input, match, thought_action_regex)) {
        result.tool_plan = match[1].str();
        auto actions_str = match[2].str();
        auto actions = json::parse(actions_str);
        for (const auto & action : actions) {
            result.tool_calls.push_back({
                /* .name = */ action["tool_name"],
                /* .arguments = */ action["parameters"].dump(),
                /* .id = */ action["tool_call_id"],
            });
        }
    } else {
        LOG_ERR("Failed to parse command_r output");
        result.content = input;
    }
    return result;
}
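
// Validates that a builtin tool's JSON schema is an object declaring exactly the expected properties,
// all marked as required; throws otherwise.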
static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
    if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
        throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
    }
    const auto & parameters_properties = parameters.at("properties");
    const auto & parameters_required = parameters.at("required");
    for (const auto & prop : expected_properties) {
        if (!parameters_properties.contains(prop)) {
            throw std::runtime_error("Parameters of tool " + name + " are missing property: " + prop);
        }
        if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
            throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop);
        }
    }
    if (parameters_properties.size() != expected_properties.size()) {
        throw std::runtime_error("Parameters of tool " + name + " must only have these properties: " + string_join(expected_properties, ", "));
    }
}
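
// Llama 3.x: tool calls are JSON objects like {"name": ..., "parameters": ...}. When the template
// supports them, Meta's builtin tools (wolfram_alpha, brave_search/web_search, python/code_interpreter)
// may instead be called with the <|python_tag|>name.call(arg=value) syntax.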
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) {
    auto builtin_tools = json::array();
    common_chat_params data;
    data.grammar_lazy = inputs.tool_choice != "required";
    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        std::vector<std::string> tool_rules;

        auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
            if (name == "wolfram_alpha") {
                // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
                expect_tool_parameters(name, parameters, {"query"});
            } else if (name == "web_search" || name == "brave_search") {
                // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
                expect_tool_parameters(name, parameters, {"query"});
            } else if (name == "python" || name == "code_interpreter") {
                // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
                expect_tool_parameters(name, parameters, {"code"});
            } else {
                return false;
            }

            std::vector<std::string> kvs;
            for (const auto & [key, value] : parameters.at("properties").items()) {
                kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value));
            }

            tool_rules.push_back(
                builder.add_rule(
                    name + "-call",
                    "\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
            builtin_tools.push_back(name);

            return true;
        };

        foreach_function(inputs.tools, [&](const json & tool) {
            const auto & function = tool["function"];
            std::string name = function["name"];
            auto parameters = function["parameters"];
            builder.resolve_refs(parameters);

            // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
            if (allow_python_tag_builtin_tools) {
                handle_builtin_tool(name, parameters);
            }
            tool_rules.push_back(
                builder.add_rule(
                    name + "-call",
                    "\"{\" space "
                    "( \"\\\"type\\\":\" space \"\\\"function\\\",\" space )? "
                    "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
                    builder.add_schema(name + "-args", parameters) +
                    " \"}\""));
            data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
        });
        data.grammar_triggers.push_back({"{\"name\":", /* .at_start = */ true});
        data.grammar_triggers.push_back({"{\n  \"name\":", /* .at_start = */ true});
        data.grammar_triggers.push_back({"{\n    \"name\":", /* .at_start = */ true});
        data.grammar_triggers.push_back({"{\"type\": \"function\"", /* .at_start = */ true});
        data.grammar_triggers.push_back({"{\n  \"type\": \"function\"", /* .at_start = */ true});
        data.grammar_triggers.push_back({"{\n    \"type\": \"function\"", /* .at_start = */ true});
        if (!builtin_tools.empty()) {
            data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
        }
        builder.add_rule("root", string_join(tool_rules, " | "));
    }, grammar_options);
    data.additional_stops.push_back("<|eom_id|>");
    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
        {"tools_in_user_message", false},
        {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
    });
    data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
        ? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
        : COMMON_CHAT_FORMAT_LLAMA_3_X;
    return data;
}

static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
    // TODO: tighten & simplify the parser, don't accept leading text context.
    static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
    static std::regex close_regex("\\}");
    static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");

    if (with_builtin_tools) {
        std::smatch match;
        if (std::regex_match(input, match, builtin_call_regex)) {
            auto name = match[1].str();
            auto raw_args = match[2].str();
            // TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
            auto it_eq = raw_args.find('=');
            auto arg_name = raw_args.substr(0, it_eq);
            auto arg_value_str = raw_args.substr(it_eq + 1);
            auto arg_value = json::parse(arg_value_str);

            return {
                /* .role = */ "assistant",
                /* .content = */ match.prefix().str(),
                /* .tool_calls = */ {
                    {
                        /* .name = */ match[1],
                        /* .arguments = */ (json {
                            {arg_name, arg_value},
                        }).dump(),
                        /* .id = */ "",
                    },
                },
            };
        }
    }
    return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
}
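
// DeepSeek R1: tool calls are wrapped in a <|tool▁calls▁begin|> block; each call is written as
// <|tool▁call▁begin|>function<|tool▁sep|>NAME followed by a ```json fenced arguments object and
// closed with <|tool▁call▁end|>.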
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
    common_chat_params data;
    data.grammar_lazy = inputs.tool_choice != "required";
    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        std::vector<std::string> tool_rules;
        foreach_function(inputs.tools, [&](const json & tool) {
            const auto & function = tool["function"];
            std::string name = function["name"];
            auto parameters = function["parameters"];
            auto args_rule = builder.add_schema(name + "-args", parameters);
            tool_rules.push_back(builder.add_rule(name + "-call",
                "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
        });
        data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
        data.preserved_tokens = {
            "<|tool▁sep|>",
            "<|tool▁call▁end|>",
        };
        builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
    }, grammar_options);
    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
    data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
    return data;
}

static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) {
    static std::regex trigger_regex("<|tool▁calls▁begin|>");
    static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
    static std::regex close_regex("```<|tool▁call▁end|>");
    return parse_json_tool_calls(input, trigger_regex, function_regex, close_regex);
}

static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
    fprintf(stderr, "%s\n", __func__);
    common_chat_params data;
    data.prompt = tmpl.apply(inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
        {"datetime", "Jan 29 2025 13:00:00 GMT"},
        {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
    }, /* adjust_inputs= */ false);
    if (!inputs.tools.is_null() && !inputs.tools.empty()) {
        data.grammar_lazy = inputs.tool_choice != "required";
        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
            auto schemas = json::array();
            foreach_function(inputs.tools, [&](const json & tool) {
                const auto & function = tool["function"];
                schemas.push_back({
                    {"type", "object"},
                    {"properties", {
                        {"name", {
                            {"type", "string"},
                            {"const", function["name"]},
                        }},
                        {"arguments", function["parameters"]},
                    }},
                    {"required", json::array({"name", "arguments", "id"})},
                });
            });
            auto schema = json {
                {"type", "array"},
                {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
                {"minItems", 1},
            };
            if (!inputs.parallel_tool_calls) {
                schema["maxItems"] = 1;
            }
            builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
        }, grammar_options);
        data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
        data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2;
    } else {
        data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
    }
    return data;
}

static common_chat_msg common_chat_parse_firefunction_v2(const std::string & input) {
    return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
}
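
// Functionary v3.2: optional free-form content after ">>>all\n", then one or more tool calls written
// as "fn_name\n{args}" (first call) or ">>>fn_name\n{args}" (subsequent calls); the tool names
// themselves double as grammar trigger words.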
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
    // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
    // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
    common_chat_params data;
    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
    data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
    if (!inputs.tools.is_null() && !inputs.tools.empty()) {
        data.grammar_lazy = inputs.tool_choice != "required";
        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
            std::vector<std::string> first_tool_rules;
            std::vector<std::string> subsequent_tool_rules;
            foreach_function(inputs.tools, [&](const json & tool) {
                const auto & function = tool["function"];
                std::string name = function["name"];
                auto parameters = function["parameters"];
                auto args_rule = builder.add_schema(name + "-args", parameters);
                first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
                subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
                data.grammar_triggers.push_back({name, /* .at_start = */ true});
                data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
            });
            auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
            if (inputs.parallel_tool_calls) {
                auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
                builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
            } else {
                builder.add_rule("root", first_rule);
            }
        }, grammar_options);
    }
    return data;
}
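
// Advances `it` past `expected` if (and only if) the input matches it exactly at the current position.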
static bool consume(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
    auto expected_it = expected.begin();
    auto tmp_it = it;
    while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
        ++tmp_it;
        ++expected_it;
    }
    if (expected_it == expected.end()) {
        it = tmp_it;
        return true;
    }
    return false;
}

static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) {
    static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
    static std::regex close_regex(R"($|(?=>>>))");

    std::string content;
    auto it = input.begin();
    const auto end = input.end();

    if (consume(it, end, "all\n")) {
        std::smatch match;
        if (std::regex_search(it, end, match, function_regex)) {
            auto fun_it = match.prefix().second;
            content = std::string(it, fun_it);
            it = fun_it;
        } else {
            common_chat_msg res;
            res.role = "assistant";
            res.content = std::string(it, end);
            return res;
        }
    }
    // TODO: tighten & simplify.
    try {
        auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex);
        res.content = content + res.content;
        return res;
    } catch (const std::exception & e) {
        LOG_ERR("Failed to parse functionary v3.2 input: %s\n", e.what());
        common_chat_msg res;
        res.role = "assistant";
        res.content = input;
        return res;
    }
}
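
// Functionary v3.1 (Llama 3.1 template): tool calls use <function=NAME>{args}</function>, with a raw
// <|python_tag|> passthrough kept for a python/ipython tool that takes a single string argument.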
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
    // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
    common_chat_params data;
    json tools = inputs.tools.is_null() ? inputs.tools : json::array();
    std::string python_code_argument_name;
    auto has_raw_python = false;

    data.grammar_lazy = inputs.tool_choice != "required";
    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        std::vector<std::string> tool_rules;
        foreach_function(inputs.tools, [&](const json & tool) {
            const auto & function = tool["function"];
            const auto & parameters = function["parameters"];
            std::string name = function["name"];
            if (name == "python" || name == "ipython") {
                if (!parameters.contains("type")) {
                    throw std::runtime_error("Missing type in python tool");
                }
                has_raw_python = true;
                auto type = parameters.at("type");
                if (type == "object") {
                    auto properties = parameters.at("properties");
                    for (auto it = properties.begin(); it != properties.end(); ++it) {
                        if (it.value().at("type") == "string") {
                            if (!python_code_argument_name.empty()) {
                                throw std::runtime_error("Multiple string arguments found in python tool");
                            }
                            python_code_argument_name = it.key();
                        }
                    }
                    if (python_code_argument_name.empty()) {
                        throw std::runtime_error("No string argument found in python tool");
                    }
                } else if (type != "string") {
                    throw std::runtime_error("Invalid type in python tool: " + type.dump());
                }
            }
            tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
        });
        if (has_raw_python) {
            tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
            data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
        }
        auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
        builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
        data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
    }, grammar_options);

    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
    // TODO: if (has_raw_python)
    data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
    return data;
}

static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
    // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
    static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
    std::smatch match;
    if (std::regex_search(input, match, python_tag_regex)) {
        auto code = match[1].str();
        return {
            /* .role = */ "assistant",
            /* .content = */ match.prefix().str(),
            /* .tool_calls = */ {
                {
                    /* .name = */ "python",
                    /* .arguments = */ (json {{"code", code}}).dump(),
                    /* .id = */ "",
                },
            }
        };
    }
    static std::regex function_regex(R"(<function=(\w+)>)");
    static std::regex close_regex(R"(</function>)");
    // TODO: tighten & simplify.
    return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
}

static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
    common_chat_params data;
    // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
    data.grammar_lazy = inputs.tool_choice != "required";
    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        std::vector<std::string> tool_rules;
        foreach_function(inputs.tools, [&](const json & tool) {
            const auto & function = tool["function"];
            std::string name = function["name"];
            auto parameters = function["parameters"];
            builder.resolve_refs(parameters);
            tool_rules.push_back(builder.add_schema(name + "-call", {
                {"type", "object"},
                {"properties", json {
                    {"name", json {{"const", name}}},
                    {"arguments", parameters},
                }},
                {"required", json::array({"name", "arguments"})},
            }));
        });
        auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
        builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
        data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
        data.preserved_tokens = { "</tool_call>" };
    }, grammar_options);

    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
    data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
    return data;
}
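
// Parses Hermes 2 Pro output: text before the first <tool_call> becomes content, then each
// <tool_call>...</tool_call> block is parsed as JSON; on any error the whole input is returned as content.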
static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) {
    try {
        std::regex start_pattern(R"([\n\s]*<tool_call>)");
        std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
        std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");

        auto end = input.end();
        std::sregex_iterator rend;
        std::sregex_iterator rit(input.begin(), end, start_pattern);
        if (rit == rend) {
            return {
                /* .role = */ "assistant",
                /* .content = */ input,
                /* .tool_calls = */ {},
            };
        }

        common_chat_msg result;
        result.role = "assistant";
        result.content = rit->prefix();

        auto it = rit->suffix().first;
        while (it != end) {
            json call;
            if (!parse_json(it, end, call)) {
                throw std::runtime_error("Failed to parse json tool call");
            }
            const auto & arguments = call["arguments"];
            result.tool_calls.push_back({
                call["name"],
                arguments.dump(),
                // arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
                /* id= */ "",
            });
            rit = {it, end, middle_pattern};
            if (rit != rend) {
                it = rit->suffix().first;
            } else {
                rit = {it, end, end_pattern};
                if (rit == rend) {
                    throw std::runtime_error("Malformed input, missing </tool_call>");
                }
                break;
            }
        }
        return result;
    } catch (const std::exception & e) {
        return {
            /* .role = */ "assistant",
            /* .content = */ input,
            /* .tool_calls = */ {},
        };
    }
}

static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
    common_chat_params data;
    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
    data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
    data.grammar_lazy = false;
    if (!inputs.json_schema.is_null()) {
        if (!inputs.grammar.empty()) {
            throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
        }
        data.grammar = json_schema_to_grammar(inputs.json_schema);
    } else {
        data.grammar = inputs.grammar;
    }
    return data;
}
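
// Selects a chat handler by sniffing distinctive markers in the Jinja template source, then builds
// the prompt and, when tools are in play, a matching grammar plus its lazy trigger words.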
common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
    auto has_tools = !inputs.tools.is_null() && inputs.tool_choice != "none";
    LOG_DBG("[%s] has_tools=%s\n", __func__, has_tools ? "true" : "false");

    if (has_tools && !inputs.grammar.empty()) {
        throw std::runtime_error("Cannot specify grammar with tools");
    }

    const auto & src = tmpl.source();
    if (src.find(">>>all") != std::string::npos) {
        // Functionary prepends "all\n" to plain content outputs, so we use this parser even when no tools are present.
        return common_chat_params_init_functionary_v3_2(tmpl, inputs);
    }
    if (src.find(" functools[") != std::string::npos) {
        // Firefunction v2 requires datetime and functions in the context, even w/o tools.
        return common_chat_params_init_firefunction_v2(tmpl, inputs);
    }
    if (!has_tools) {
        return common_chat_params_init_without_tools(tmpl, inputs);
    }
    if (src.find("<tool_call>") != std::string::npos) {
        return common_chat_params_init_hermes_2_pro(tmpl, inputs);
    }
    if (src.find("<|start_header_id|>") != std::string::npos
        && src.find("<function=") != std::string::npos) {
        return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs);
    }
    if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
        auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
        return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
    }
    if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
        return common_chat_params_init_deepseek_r1(tmpl, inputs);
    }
    if (src.find("[TOOL_CALLS]") != std::string::npos) {
        return common_chat_params_init_mistral_nemo(tmpl, inputs);
    }
    if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos) {
        return common_chat_params_init_command_r7b(tmpl, inputs);
    }
    return common_chat_params_init_generic(tmpl, inputs);
}

static common_chat_msg common_chat_parse_content_only(const std::string & input) {
    return {
        /* .role = */ "assistant",
        /* .content = */ input,
        /* .tool_calls = */ {},
    };
}
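
// Dispatches model output parsing to the handler matching the format chosen at prompt-build time.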
common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {
    switch (format) {
        case COMMON_CHAT_FORMAT_CONTENT_ONLY:
            return common_chat_parse_content_only(input);
        case COMMON_CHAT_FORMAT_GENERIC:
            return common_chat_parse_generic(input);
        case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
            return common_chat_parse_mistral_nemo(input);
        case COMMON_CHAT_FORMAT_LLAMA_3_X:
            return common_chat_parse_llama_3_1(input);
        case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
            return common_chat_parse_llama_3_1(input, /* with_builtin_tools= */ true);
        case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
            return common_chat_parse_deepseek_r1(input);
        case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
            return common_chat_parse_functionary_v3_2(input);
        case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
            return common_chat_parse_functionary_v3_1_llama_3_1(input);
        case COMMON_CHAT_FORMAT_HERMES_2_PRO:
            return common_chat_parse_hermes_2_pro(input);
        case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
            return common_chat_parse_firefunction_v2(input);
        case COMMON_CHAT_FORMAT_COMMAND_R7B:
            return common_chat_parse_command_r7b(input);
        default:
            throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
    }
}