// chat.cpp — chat-format detection, prompt rendering and tool-call parsing helpers.
  1. #include "chat.hpp"
  2. #include "chat-template.hpp"
  3. #include "json-schema-to-grammar.h"
  4. #include "log.h"
  5. #include "minja.hpp"
  6. std::string common_chat_format_name(common_chat_format format) {
  7. switch (format) {
  8. case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only";
  9. case COMMON_CHAT_FORMAT_GENERIC: return "Generic";
  10. case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo";
  11. case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
  12. case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
  13. case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
  14. case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING: return "DeepSeek R1 (extract reasoning)";
  15. case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
  16. case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
  17. case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
  18. case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
  19. case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
  20. case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)";
  21. default:
  22. throw std::runtime_error("Unknown chat format");
  23. }
  24. }
// Shared options for every grammar built in this file.
const common_grammar_options grammar_options {
    /* .dotall = */ false,
    /* .compact_spaces = */ false,
    // /* .compact_spaces = */ true,
};
// Parses a single JSON value starting at `it`, tolerating trailing text after it.
// A SAX pass first locates where the value ends (or where the first parse error
// occurs); only that prefix is then re-parsed with the DOM parser into `out`.
// On success advances `it` past the consumed JSON and returns true; on failure
// returns false and leaves `it` untouched.
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
    // // https://json.nlohmann.me/features/parsing/sax_interface/
    struct json_error_locator : public nlohmann::json_sax<json> {
        std::size_t position;   // 0-based offset of the first parse error in the input
        bool found_error;
        json_error_locator() : position(0), found_error(false) {}
        bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
            // nlohmann reports 1-based positions; keep the 0-based offset.
            this->position = position - 1;
            this->found_error = true;
            return false; // stop at the first error
        }
        // All value callbacks simply accept the token — only the error location matters.
        bool null() override { return true; }
        bool boolean(bool) override { return true; }
        bool number_integer(number_integer_t) override { return true; }
        bool number_unsigned(number_unsigned_t) override { return true; }
        bool number_float(number_float_t, const string_t &) override { return true; }
        bool string(string_t &) override { return true; }
        bool binary(binary_t &) override { return true; }
        bool start_object(std::size_t) override { return true; }
        bool key(string_t &) override { return true; }
        bool end_object() override { return true; }
        bool start_array(std::size_t) override { return true; }
        bool end_array() override { return true; }
    };
    json_error_locator err_loc;
    json::sax_parse(it, end, &err_loc);
    // Candidate end of the JSON value: the error offset, or the whole input.
    std::string::const_iterator temptative_end;
    if (err_loc.found_error) {
        temptative_end = it + err_loc.position;
    } else {
        temptative_end = end;
    }
    std::string json_sub {it, temptative_end};
    try {
        out = json::parse(json_sub);
        it = temptative_end;
        return true;
    } catch (const std::exception &) {
        return false;
    }
}
/**
 * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
 * Aggregates the prefix, suffix and in-between text into the content.
 */
static common_chat_msg parse_json_tool_calls(
    const std::string& input,
    const std::optional<std::regex> & trigger_opt,
    const std::regex & function_regex,
    const std::regex & close_regex) {
    std::smatch match;
    common_chat_msg result;
    result.role = "assistant";
    auto end = input.end();
    auto it = input.begin();
    // If a trigger regex is given, everything before its first match is plain
    // content; with no match at all the whole input is plain content.
    if (trigger_opt) {
        if (!std::regex_search(it, end, match, *trigger_opt)) {
            result.content = input;
            return result;
        }
        result.content = match.prefix().str();
        it = match.suffix().first;
    }
    // Repeatedly consume "<function header> <json args> <close pattern>" groups.
    while (it != end) {
        std::sregex_iterator rend;
        std::sregex_iterator rit(it, end, function_regex);
        if (rit == rend) {
            // No further function headers: the remainder is content.
            result.content += std::string(it, end);
            break;
        }
        auto name = rit->str(1); // capture group 1 = function name
        result.content += std::string(it, rit->prefix().second);
        it = rit->suffix().first;
        json arguments;
        // The JSON arguments must immediately follow the header.
        if (!parse_json(it, end, arguments)) {
            throw std::runtime_error("Failed to parse json tool call arguments: " + input);
        }
        if (!std::regex_search(it, end, match, close_regex)) {
            throw std::runtime_error("Malformed input, missing closing pattern: " + input);
        }
        it = match.suffix().first;
        // String-valued arguments pass through as-is; objects are serialized.
        result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
    }
    if (!result.tool_calls.empty()) {
        // When tool calls were found, any interleaved text is dropped (with a warning).
        if (!string_strip(result.content).empty()) {
            LOG_WRN("Content found with tool calls: %s\n", result.content.c_str());
        }
        result.content = "";
    }
    return result;
}
  121. static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
  122. auto content_end = input.find(prefix);
  123. size_t tc_start = std::string::npos;
  124. common_chat_msg result;
  125. result.role = "assistant";
  126. const auto process_tool_calls = [&](const json & tool_calls) {
  127. for (const auto & tool_call : tool_calls) {
  128. const auto & arguments = tool_call.at("arguments");
  129. result.tool_calls.push_back({
  130. tool_call.at("name"),
  131. arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
  132. tool_call.contains("id") ? tool_call.at("id") : "",
  133. });
  134. }
  135. };
  136. if (content_end == std::string::npos) {
  137. result.content = input;
  138. } else {
  139. tc_start = content_end + prefix.size() - rstrip_prefix;
  140. result.content = input.substr(0, content_end);
  141. auto tool_calls = json::parse(input.substr(tc_start));
  142. process_tool_calls(tool_calls);
  143. }
  144. return result;
  145. }
  146. static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
  147. for (const auto & tool : tools) {
  148. if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
  149. LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str());
  150. continue;
  151. }
  152. fn(tool);
  153. }
  154. }
  155. static std::string apply(
  156. const common_chat_template & tmpl,
  157. const nlohmann::ordered_json & messages,
  158. const nlohmann::ordered_json & tools,
  159. bool add_generation_prompt,
  160. const nlohmann::ordered_json & extra_context = nlohmann::ordered_json())
  161. {
  162. minja::chat_template_inputs tmpl_inputs;
  163. tmpl_inputs.messages = messages;
  164. tmpl_inputs.tools = tools;
  165. tmpl_inputs.add_generation_prompt = add_generation_prompt;
  166. tmpl_inputs.extra_context = extra_context;
  167. // TODO: add flag to control date/time, if only for testing purposes.
  168. // tmpl_inputs.now = std::chrono::system_clock::now();
  169. minja::chat_template_options tmpl_opts;
  170. tmpl_opts.use_bos_token = false;
  171. tmpl_opts.use_eos_token = false;
  172. return tmpl.apply(tmpl_inputs, tmpl_opts);
  173. }
// Builds prompt + grammar for the "generic" tool-call format: the model is asked
// to reply in JSON, either {"tool_call"/"tool_calls": ...} or {"response": ...}.
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
    common_chat_params data;
    // One schema per declared tool: {"name": <const>, "arguments": <params>[, "id"]}.
    auto tool_call_schemas = json::array();
    foreach_function(inputs.tools, [&](const json & tool) {
        const auto & function = tool.at("function");
        auto tool_schema = json {
            {"type", "object"},
            {"properties", {
                {"name", {
                    {"type", "string"},
                    {"const", function.at("name")},
                }},
                {"arguments", function.at("parameters")},
            }},
            {"required", json::array({"name", "arguments"})},
        };
        if (function.contains("description")) {
            tool_schema["description"] = function.at("description");
        }
        if (inputs.parallel_tool_calls) {
            // Parallel calls carry an id so results can be matched back to calls.
            tool_schema.at("properties")["id"] = {
                {"type", "string"},
                {"minLength", 4},
            };
            tool_schema.at("required").push_back("id");
        }
        tool_call_schemas.emplace_back(tool_schema);
    });
    // Parallel mode wraps an array under "tool_calls"; otherwise a single "tool_call" object.
    const auto tool_call =
        inputs.parallel_tool_calls
            ? json {
                {"type", "object"},
                {"properties", {
                    {"tool_calls", {
                        {"type", "array"},
                        {"items", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
                            {"anyOf", tool_call_schemas},
                        }},
                        {"minItems", 1},
                    }},
                }},
                {"required", json::array({"tool_calls"})},
            }
            : json {
                {"type", "object"},
                {"properties", {
                    {"tool_call", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
                        {"anyOf", tool_call_schemas},
                    }},
                }},
                {"required", json::array({"tool_call"})},
            };
    // Unless tool use is mandatory, also allow a free-form {"response": ...} reply,
    // constrained by inputs.json_schema when one was provided.
    const auto schema =
        inputs.tool_choice != "required"
            ? json {
                {"anyOf", json::array({
                    tool_call,
                    {
                        {"type", "object"},
                        {"properties", {
                            {"response", inputs.json_schema.is_null()
                                ? json {{"type", "string"}}
                                : inputs.json_schema
                            },
                        }},
                        {"required", json::array({"response"})},
                    },
                })}
            }
            : tool_call;
    // The entire output is JSON, so the grammar applies from the first token.
    data.grammar_lazy = false;
    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        builder.add_schema("root", schema);
    }, grammar_options);
    // Nudge the model towards the expected JSON shape via an injected system message.
    auto tweaked_messages = common_chat_template::add_system(
        inputs.messages,
        "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request");
    data.prompt = apply(tmpl, tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
    data.format = COMMON_CHAT_FORMAT_GENERIC;
    return data;
}
  255. static common_chat_msg common_chat_parse_generic(const std::string & input) {
  256. json data = json::parse(input);
  257. common_chat_msg result;
  258. result.role = "assistant";
  259. if (data.contains("tool_calls")) {
  260. for (const auto & tool_call : data.at("tool_calls")) {
  261. result.tool_calls.push_back({
  262. tool_call.at("name"),
  263. tool_call.at("arguments").dump(),
  264. tool_call.contains("id") ? tool_call.at("id") : "",
  265. });
  266. }
  267. } else if (data.contains("tool_call")) {
  268. result.tool_calls.push_back({
  269. data.at("tool_call").at("name"),
  270. data.at("tool_call").at("arguments").dump(),
  271. /* id= */ "",
  272. });
  273. } else if (data.contains("response")) {
  274. const auto & response = data.at("response");
  275. result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
  276. }
  277. return result;
  278. }
// Builds prompt + lazy grammar for Mistral Nemo's "[TOOL_CALLS]<json array>" format.
static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
    common_chat_params data;
    // Lazy grammar: only constrain output once the trigger below appears,
    // unless tool use is mandatory.
    data.grammar_lazy = inputs.tool_choice != "required";
    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        auto schemas = json::array();
        foreach_function(inputs.tools, [&](const json & tool) {
            const auto & function = tool.at("function");
            schemas.push_back({
                {"type", "object"},
                {"properties", {
                    // Important note: the model is probably trained to take a JSON stringified arguments value.
                    // It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
                    {"name", {
                        {"type", "string"},
                        {"const", function.at("name")},
                    }},
                    {"arguments", function.at("parameters")},
                    {"id", {
                        {"type", "string"},
                        // Nemo's template expects a 9-character alphanumeric ID.
                        {"pattern", "^[a-zA-Z0-9]{9}$"},
                    }},
                }},
                {"required", json::array({"name", "arguments", "id"})},
            });
        });
        auto schema = json {
            {"type", "array"},
            {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
            {"minItems", 1},
        };
        if (!inputs.parallel_tool_calls) {
            // Without parallel calls, exactly one tool call is allowed.
            schema["maxItems"] = 1;
        }
        builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
    }, grammar_options);
    // The tool-call array must appear at the very start of the model output.
    data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
    data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
    return data;
}
// Mistral Nemo emits "[TOOL_CALLS]" followed by a JSON array of tool calls.
static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input) {
    return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
}
// Builds prompt + lazy grammar for Command R7B's
// "<|START_ACTION|>[...]<|END_ACTION|>" tool-call format.
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
    common_chat_params data;
    // Lazy grammar: only constrain once the action trigger appears, unless tool
    // use is mandatory.
    data.grammar_lazy = inputs.tool_choice != "required";
    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        auto schemas = json::array();
        foreach_function(inputs.tools, [&](const json & tool) {
            const auto & function = tool.at("function");
            schemas.push_back({
                {"type", "object"},
                {"properties", {
                    {"tool_call_id", {
                        {"type", "string"},
                        // Command-R's template expects an integer string.
                        {"pattern", "^[0-9]{1,10}$"},
                    }},
                    {"tool_name", {
                        {"type", "string"},
                        {"const", function.at("name")},
                    }},
                    {"parameters", function.at("parameters")},
                }},
                {"required", json::array({"tool_call_id", "tool_name", "parameters"})},
            });
        });
        auto schema = json {
            {"type", "array"},
            {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
            {"minItems", 1},
        };
        if (!inputs.parallel_tool_calls) {
            // Without parallel calls, exactly one action is allowed.
            schema["maxItems"] = 1;
        }
        builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
    }, grammar_options);
    data.grammar_triggers.push_back({"<|START_ACTION|>", /* .at_start = */ false});
    // Keep these special tokens intact so the parser can recognize them later.
    data.preserved_tokens = {
        "<|START_RESPONSE|>",
        "<|END_RESPONSE|>",
        "<|START_THINKING|>",
        "<|END_THINKING|>",
        "<|END_ACTION|>",
    };
    // The template expects tool-call reasoning under "tool_plan", not
    // "reasoning_content" — rename the field on affected assistant messages.
    auto adjusted_messages = json::array();
    for (const auto & msg : inputs.messages) {
        auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
        auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();
        if (has_reasoning_content && has_tool_calls) {
            auto adjusted_message = msg;
            adjusted_message["tool_plan"] = msg.at("reasoning_content");
            adjusted_message.erase("reasoning_content");
            adjusted_messages.push_back(adjusted_message);
        } else {
            adjusted_messages.push_back(msg);
        }
    }
    data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {});
    data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING : COMMON_CHAT_FORMAT_COMMAND_R7B;
    return data;
}
// Parses Command R7B output: optional <|START_THINKING|>...<|END_THINKING|>
// reasoning followed by either an action block (tool calls) or a response block.
static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) {
    // Group 1 = whole thinking block incl. tags, group 2 = its inside, group 3 = remainder.
    static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|>)([\\s\\S\\n\\r]*)");
    static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
    // The opening response tag is optional; only the closing one is required.
    static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>");
    std::smatch match;
    common_chat_msg result;
    result.role = "assistant";
    std::string rest = input;
    if (std::regex_match(rest, match, thought_regex)) {
        if (extract_reasoning) {
            result.reasoning_content = match[2].str();
        } else if (!match[2].str().empty()) {
            // Let the unparsed thinking tags through in content only if their insides aren't empty.
            result.content = match[1].str();
        }
        rest = match[3].str();
    }
    if (std::regex_match(rest, match, action_regex)) {
        // Action block: a JSON array of {tool_call_id, tool_name, parameters}.
        auto actions_str = match[1].str();
        auto actions = json::parse(actions_str);
        for (const auto & action : actions) {
            result.tool_calls.push_back({
                /* .name = */ action.at("tool_name"),
                /* .arguments = */ action.at("parameters").dump(),
                /* .id = */ action.at("tool_call_id"),
            });
        }
    } else if (std::regex_match(rest, match, response_regex)) {
        auto response = match[1].str();
        result.content += response;
    } else {
        // Neither an action nor a response block: pass the remainder through.
        result.content += rest;
    }
    return result;
}
  417. static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
  418. if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
  419. throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
  420. }
  421. const auto & parameters_properties = parameters.at("properties");
  422. const auto & parameters_required = parameters.at("required");
  423. for (const auto & prop : expected_properties) {
  424. if (!parameters_properties.contains(prop)) {
  425. throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop);
  426. }
  427. if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
  428. throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop);
  429. }
  430. }
  431. if (parameters_properties.size() != expected_properties.size()) {
  432. throw std::runtime_error("Parameters of tool " + name + " must only have these properties:" + string_join(expected_properties, ", "));
  433. }
  434. }
// Builds prompt + lazy grammar for Llama 3.1/3.x JSON tool calls, optionally also
// accepting the "<|python_tag|>name.call(key=value)" syntax for known builtin tools.
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) {
    auto builtin_tools = json::array();
    common_chat_params data;
    data.grammar_lazy = inputs.tool_choice != "required";
    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        std::vector<std::string> tool_rules;
        // Returns true (and emits a python-tag call rule) only for tools that
        // llama-stack ships as builtins; their parameters are validated first.
        auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
            if (name == "wolfram_alpha") {
                // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
                expect_tool_parameters(name, parameters, {"query"});
            } else if (name == "web_search" || name == "brave_search") {
                // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
                expect_tool_parameters(name, parameters, {"query"});
            } else if (name == "python" || name == "code_interpreter") {
                // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
                expect_tool_parameters(name, parameters, {"code"});
            } else {
                return false;
            }
            // Builtin calls are rendered as: <|python_tag|>name.call(key=value, ...)
            std::vector<std::string> kvs;
            for (const auto & [key, value] : parameters.at("properties").items()) {
                kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value));
            }
            tool_rules.push_back(
                builder.add_rule(
                    name + "-call",
                    "\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
            builtin_tools.push_back(name);
            return true;
        };
        foreach_function(inputs.tools, [&](const json & tool) {
            const auto & function = tool.at("function");
            std::string name = function.at("name");
            auto parameters = function.at("parameters");
            builder.resolve_refs(parameters);
            // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
            if (allow_python_tag_builtin_tools) {
                handle_builtin_tool(name, parameters);
            }
            // JSON form: {"name": "<fn>", "parameters": {...}} with optional "type" field.
            tool_rules.push_back(
                builder.add_rule(
                    name + "-call",
                    "\"{\" space "
                    "( \"\\\"type\\\":\" space \"\\\"function\\\",\" space )? "
                    "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
                    builder.add_schema(name + "-args", parameters) +
                    " \"}\""));
            data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
        });
        // Generic triggers for common spacings of the JSON tool-call opening.
        // NOTE(review): the two pairs of identical-looking entries below likely
        // differed only in whitespace width originally — verify against history.
        data.grammar_triggers.push_back({"{\"name\":", /* .at_start = */ true});
        data.grammar_triggers.push_back({"{\n \"name\":", /* .at_start = */ true});
        data.grammar_triggers.push_back({"{\n \"name\":", /* .at_start = */ true});
        data.grammar_triggers.push_back({"{\"type\": \"function\"", /* .at_start = */ true});
        data.grammar_triggers.push_back({"{\n \"type\": \"function\"", /* .at_start = */ true});
        data.grammar_triggers.push_back({"{\n \"type\": \"function\"", /* .at_start = */ true});
        if (!builtin_tools.empty()) {
            data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
        }
        builder.add_rule("root", string_join(tool_rules, " | "));
    }, grammar_options);
    // Treat <|eom_id|> as an additional stop token.
    data.additional_stops.push_back("<|eom_id|>");
    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
        {"tools_in_user_message", false},
        {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
    });
    data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
        ? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
        : COMMON_CHAT_FORMAT_LLAMA_3_X;
    return data;
}
  505. static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
  506. // TODO: tighten & simplify the parser, don't accept leading text context.
  507. static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
  508. static std::regex close_regex("\\}");
  509. static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
  510. if (with_builtin_tools) {
  511. std::smatch match;
  512. if (std::regex_match(input, match, builtin_call_regex)) {
  513. auto name = match[1].str();
  514. auto raw_args = match[2].str();
  515. // TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
  516. auto it_eq = raw_args.find('=');
  517. auto arg_name = raw_args.substr(0, it_eq);
  518. auto arg_value_str = raw_args.substr(it_eq + 1);
  519. auto arg_value = json::parse(arg_value_str);
  520. return {
  521. /* .role = */ "assistant",
  522. /* .content = */ match.prefix().str(),
  523. /* .tool_calls = */ {
  524. {
  525. /* .name = */ match[1],
  526. /* .arguments = */ (json {
  527. {arg_name, arg_value},
  528. }).dump(),
  529. /* .id = */ "",
  530. },
  531. },
  532. };
  533. }
  534. }
  535. return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
  536. }
  537. static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
  538. common_chat_params data;
  539. if (inputs.tools.is_array() && !inputs.tools.empty()) {
  540. data.grammar_lazy = inputs.tool_choice != "required" && inputs.json_schema.is_null();
  541. data.grammar = build_grammar([&](const common_grammar_builder & builder) {
  542. std::vector<std::string> tool_rules;
  543. foreach_function(inputs.tools, [&](const json & tool) {
  544. const auto & function = tool.at("function");
  545. std::string name = function.at("name");
  546. auto parameters = function.at("parameters");
  547. auto args_rule = builder.add_schema(name + "-args", parameters);
  548. tool_rules.push_back(builder.add_rule(name + "-call",
  549. "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
  550. "```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
  551. });
  552. // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
  553. // so we accept common variants (then it's all constrained)
  554. builder.add_rule("root",
  555. "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" ) "
  556. "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
  557. "\"<|tool▁calls▁end|>\""
  558. " space");
  559. data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
  560. data.grammar_triggers.push_back({"<|tool_calls_begin|>", /* .at_start = */ false});
  561. data.grammar_triggers.push_back({"<|tool calls begin|>", /* .at_start = */ false});
  562. data.grammar_triggers.push_back({"<|tool\\_calls\\_begin|>", /* .at_start = */ false});
  563. data.preserved_tokens = {
  564. "<think>",
  565. "</think>",
  566. "<|tool▁sep|>",
  567. "<|tool▁calls▁end|",
  568. "<|tool▁call▁end|>",
  569. };
  570. }, grammar_options);
  571. }
  572. auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
  573. // Hacks to fix the official (broken) prompt.
  574. // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
  575. // until the official template is fixed.
  576. if (tmpl.source().find("{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}") != std::string::npos) {
  577. // Don't leave the chat dangling after tool results
  578. if (string_ends_with(prompt, "<|tool▁outputs▁end|>")) {
  579. prompt += "<|end▁of▁sentence|>";
  580. if (inputs.add_generation_prompt) {
  581. prompt += "<|Assistant|>";
  582. }
  583. }
  584. // Fix up tool call delta example added by Minja
  585. prompt = std::regex_replace(
  586. prompt,
  587. std::regex("(<|tool▁call▁end|>)[\\s\\r\\n]*(<|tool▁outputs▁begin|>|<|User|>)"),
  588. "$1<|tool▁calls▁end|><|end▁of▁sentence|>$2");
  589. }
  590. data.prompt = prompt;
  591. data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING : COMMON_CHAT_FORMAT_DEEPSEEK_R1;
  592. return data;
  593. }
  594. static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) {
  595. static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
  596. static std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
  597. static std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
  598. static std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>");
  599. common_chat_msg msg;
  600. msg.role = "assistant";
  601. std::smatch match;
  602. if (std::regex_match(input, match, reasoning_content_regex)) {
  603. std::string rest;
  604. if (extract_reasoning) {
  605. msg.reasoning_content = string_strip(match[2].str());
  606. } else {
  607. msg.content = match[1].str();
  608. }
  609. rest = match[3].str();
  610. if (std::regex_search(rest, match, tool_calls_regex)) {
  611. auto tool_calls = match[1].str();
  612. auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex);
  613. msg.tool_calls = std::move(msg2.tool_calls);
  614. } else {
  615. msg.content += std::string(rest.begin() + rest.find_first_not_of(" \r\n"), rest.end());
  616. }
  617. } else {
  618. msg.content = input;
  619. }
  620. return msg;
  621. }
  622. static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
  623. fprintf(stderr, "%s\n", __func__);
  624. common_chat_params data;
  625. data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
  626. {"datetime", "Jan 29 2025 13:00:00 GMT"},
  627. {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
  628. });
  629. if (inputs.tools.is_array() && !inputs.tools.empty()) {
  630. data.grammar_lazy = inputs.tool_choice != "required";
  631. data.grammar = build_grammar([&](const common_grammar_builder & builder) {
  632. auto schemas = json::array();
  633. foreach_function(inputs.tools, [&](const json & tool) {
  634. const auto & function = tool.at("function");
  635. schemas.push_back({
  636. {"type", "object"},
  637. {"properties", {
  638. {"name", {
  639. {"type", "string"},
  640. {"const", function.at("name")},
  641. }},
  642. {"arguments", function.at("parameters")},
  643. }},
  644. {"required", json::array({"name", "arguments", "id"})},
  645. });
  646. });
  647. auto schema = json {
  648. {"type", "array"},
  649. {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
  650. {"minItems", 1},
  651. };
  652. if (!inputs.parallel_tool_calls) {
  653. schema["maxItems"] = 1;
  654. }
  655. builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
  656. }, grammar_options);
  657. data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
  658. data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2;
  659. } else {
  660. data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
  661. }
  662. return data;
  663. }
// Parses Firefunction v2 output: a JSON array of tool calls prefixed with " functools".
// rstrip_prefix=1 drops the trailing "[" from the prefix so the array itself is parsed.
static common_chat_msg common_chat_parse_firefunction_v2(const std::string & input) {
    return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
}
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
    // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
    // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
    common_chat_params data;
    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
    data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
    if (inputs.tools.is_array() && !inputs.tools.empty()) {
        // Lazy grammar: constrain output only once a trigger fires, unless tools are required.
        data.grammar_lazy = inputs.tool_choice != "required";
        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
            // The first call is emitted as "name\n{args}", subsequent calls as ">>>name\n{args}".
            std::vector<std::string> first_tool_rules;
            std::vector<std::string> subsequent_tool_rules;
            foreach_function(inputs.tools, [&](const json & tool) {
                const auto & function = tool.at("function");
                std::string name = function.at("name");
                auto parameters = function.at("parameters");
                auto args_rule = builder.add_schema(name + "-args", parameters);
                first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
                subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
                // A bare tool name at the very start, or ">>>name" anywhere, activates the grammar.
                data.grammar_triggers.push_back({name, /* .at_start = */ true});
                data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
            });
            auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
            if (inputs.parallel_tool_calls) {
                // Allow any number of ">>>"-prefixed calls after the first one.
                auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
                builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
            } else {
                builder.add_rule("root", first_rule);
            }
        }, grammar_options);
    }
    return data;
}
  699. static bool consume(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
  700. auto expected_it = expected.begin();
  701. auto tmp_it = it;
  702. while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
  703. ++tmp_it;
  704. ++expected_it;
  705. }
  706. if (expected_it == expected.end()) {
  707. it = tmp_it;
  708. return true;
  709. }
  710. return false;
  711. }
// Parses Functionary v3.2 output: optional "all\n<content>" followed by calls of the
// form "name\n{json args}" / ">>>name\n{json args}".
static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) {
    static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
    // A call ends at end-of-input or right before the next ">>>"; the lookahead consumes nothing.
    static std::regex close_regex(R"($|(?=>>>))");
    std::string content;
    auto it = input.begin();
    const auto end = input.end();
    // Functionary v3.2 prefixes plain content with "all\n".
    if (consume(it, end, "all\n")) {
        std::smatch match;
        if (std::regex_search(it, end, match, function_regex)) {
            // Content runs up to the first function header; tool calls are parsed from there.
            auto fun_it = match.prefix().second;
            content = std::string(it, fun_it);
            it = fun_it;
        } else {
            // No tool calls: everything after "all\n" is plain content.
            common_chat_msg res;
            res.role = "assistant";
            res.content = std::string(it, end);
            return res;
        }
    }
    // TODO: tighten & simplify.
    try {
        auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex);
        res.content = content + res.content;
        return res;
    } catch (const std::exception & e) {
        // Degrade gracefully: on malformed tool-call JSON, return the raw input as content.
        LOG_ERR("Failed to parse functionary v3.2 input: %s\n", e.what());
        common_chat_msg res;
        res.role = "assistant";
        res.content = input;
        return res;
    }
}
  744. static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
  745. // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
  746. common_chat_params data;
  747. json tools = inputs.tools.is_null() ? inputs.tools : json::array();
  748. std::string python_code_argument_name;
  749. auto has_raw_python = false;
  750. data.grammar_lazy = inputs.tool_choice != "required";
  751. data.grammar = build_grammar([&](const common_grammar_builder & builder) {
  752. std::vector<std::string> tool_rules;
  753. foreach_function(inputs.tools, [&](const json & tool) {
  754. const auto & function = tool.at("function");
  755. const auto & parameters = function.at("parameters");
  756. std::string name = function.at("name");
  757. if (name == "python" || name == "ipython") {
  758. if (!parameters.contains("type")) {
  759. throw std::runtime_error("Missing type in python tool");
  760. }
  761. has_raw_python = true;
  762. auto type = parameters.at("type");
  763. if (type == "object") {
  764. auto properties = parameters.at("properties");
  765. for (auto it = properties.begin(); it != properties.end(); ++it) {
  766. if (it.value().at("type") == "string") {
  767. if (!python_code_argument_name.empty()) {
  768. throw std::runtime_error("Multiple string arguments found in python tool");
  769. }
  770. python_code_argument_name = it.key();
  771. }
  772. }
  773. if (python_code_argument_name.empty()) {
  774. throw std::runtime_error("No string argument found in python tool");
  775. }
  776. } else if (type != "string") {
  777. throw std::runtime_error("Invalid type in python tool: " + type.dump());
  778. }
  779. }
  780. tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
  781. });
  782. if (has_raw_python) {
  783. tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
  784. data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
  785. }
  786. auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
  787. builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
  788. data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
  789. }, grammar_options);
  790. data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
  791. // TODO: if (has_raw_python)
  792. data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
  793. return data;
  794. }
static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
    // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
    static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
    std::smatch match;
    if (std::regex_search(input, match, python_tag_regex)) {
        // Everything after <|python_tag|> is raw python; wrap it as a "python" tool call
        // with the code passed as the {"code": ...} argument.
        auto code = match[1].str();
        return {
            /* .role = */ "assistant",
            /* .content = */ match.prefix().str(),
            /* .tool_calls = */ {
                {
                    /* .name = */ "python",
                    /* .arguments = */ (json {{"code", code}}).dump(),
                    /* .id = */ "",
                },
            }
        };
    }
    // Regular calls use "<function=name>{json args}</function>".
    static std::regex function_regex(R"(<function=(\w+)>)");
    static std::regex close_regex(R"(</function>)");
    // TODO: tighten & simplify.
    return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
}
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
    common_chat_params data;
    // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
    // Lazy grammar: constrain output only once "<tool_call>" is seen, unless tools are required.
    data.grammar_lazy = inputs.tool_choice != "required";
    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        std::vector<std::string> tool_rules;
        foreach_function(inputs.tools, [&](const json & tool) {
            const auto & function = tool.at("function");
            std::string name = function.at("name");
            auto parameters = function.at("parameters");
            // Inline any local $refs so the schema is self-contained before grammar conversion.
            builder.resolve_refs(parameters);
            // Each call must be a {"name": <const>, "arguments": <schema>} object.
            tool_rules.push_back(builder.add_schema(name + "-call", {
                {"type", "object"},
                {"properties", json {
                    {"name", json {{"const", name}}},
                    {"arguments", parameters},
                }},
                {"required", json::array({"name", "arguments"})},
            }));
        });
        auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
        builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
        data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
        // Keep the closing tag from being split during tokenization so the grammar can match it.
        data.preserved_tokens = { "</tool_call>" };
    }, grammar_options);
    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
    data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
    return data;
}
// Parses Hermes 2 Pro / Qwen 2.5 output: optional content followed by one or more
// <tool_call>{json}</tool_call> blocks.
static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) {
    try {
        std::regex start_pattern(R"([\n\s]*<tool_call>)");
        std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
        std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
        auto end = input.end();
        std::sregex_iterator rend;
        std::sregex_iterator rit(input.begin(), end, start_pattern);
        if (rit == rend) {
            // No <tool_call> tag at all: the whole input is plain content.
            return {
                /* .role = */ "assistant",
                /* .content = */ input,
                /* .tool_calls = */ {},
            };
        }
        common_chat_msg result;
        result.role = "assistant";
        // Text before the first <tool_call> is content.
        result.content = rit->prefix();
        // Walk the alternating "{json}</tool_call><tool_call>{json}..." sequence.
        auto it = rit->suffix().first;
        while (it != end) {
            json call;
            if (!parse_json(it, end, call)) {
                throw std::runtime_error("Failed to parse json tool call");
            }
            const auto & arguments = call.at("arguments");
            result.tool_calls.push_back({
                call.at("name"),
                arguments.dump(),
                // arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
                /* id= */ "",
            });
            rit = {it, end, middle_pattern};
            if (rit != rend) {
                // Another call follows: skip past "</tool_call><tool_call>" and parse it.
                it = rit->suffix().first;
            } else {
                // Last call: it must be terminated by "</tool_call>" at end of input.
                rit = {it, end, end_pattern};
                if (rit == rend) {
                    throw std::runtime_error("Malformed input, missing </tool_call>");
                }
                break;
            }
        }
        return result;
    } catch (const std::exception & e) {
        // Degrade gracefully: on any parse failure, return the raw input as content.
        return {
            /* .role = */ "assistant",
            /* .content = */ input,
            /* .tool_calls = */ {},
        };
    }
}
  898. static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
  899. common_chat_params data;
  900. data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
  901. data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
  902. data.grammar_lazy = false;
  903. if (!inputs.json_schema.is_null()) {
  904. if (!inputs.grammar.empty()) {
  905. throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
  906. }
  907. data.grammar = json_schema_to_grammar(inputs.json_schema);
  908. } else {
  909. data.grammar = inputs.grammar;
  910. }
  911. return data;
  912. }
// Selects a template-specific handler by sniffing distinctive markers in the template
// source. NOTE: the order of the checks below is significant — more specific formats
// must be matched before the generic fallbacks.
common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
    const auto & src = tmpl.source();
    const auto & caps = tmpl.original_caps();
    if (inputs.tools.is_array()) {
        if (inputs.tool_choice != "none" && !inputs.grammar.empty()) {
            // Tool handlers generate their own grammars; a user grammar would conflict.
            throw std::runtime_error("Cannot specify grammar with tools");
        }
        if (caps.supports_tool_calls && !caps.supports_tools) {
            LOG_WRN("Template supports tool calls but does not natively describe tools. The fallback behaviour used may produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
        }
    }
    // DeepSeek R1: use handler in all cases except json schema (thinking / tools).
    if (src.find("<|tool▁calls▁begin|>") != std::string::npos && inputs.json_schema.is_null()) {
        return common_chat_params_init_deepseek_r1(tmpl, inputs);
    }
    // Command R7B: : use handler in all cases except json schema (thinking / tools).
    if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) {
        return common_chat_params_init_command_r7b(tmpl, inputs);
    }
    // Use generic handler when mixing tools + JSON schema.
    // TODO: support that mix in handlers below.
    if ((!inputs.tools.is_array() && inputs.json_schema.is_object())) {
        return common_chat_params_init_generic(tmpl, inputs);
    }
    // Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
    if (src.find(">>>all") != std::string::npos) {
        return common_chat_params_init_functionary_v3_2(tmpl, inputs);
    }
    // Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
    if (src.find(" functools[") != std::string::npos) {
        return common_chat_params_init_firefunction_v2(tmpl, inputs);
    }
    // Plain handler (no tools)
    if (inputs.tools.is_null() || inputs.tool_choice == "none") {
        return common_chat_params_init_without_tools(tmpl, inputs);
    }
    // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
    if (src.find("<tool_call>") != std::string::npos) {
        return common_chat_params_init_hermes_2_pro(tmpl, inputs);
    }
    // Functionary v3.1 (w/ tools)
    if (src.find("<|start_header_id|>") != std::string::npos
        && src.find("<function=") != std::string::npos) {
        return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs);
    }
    // Llama 3.1, 3.2, 3.3 (w/ tools)
    if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
        auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
        return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
    }
    // Mistral Nemo (w/ tools)
    if (src.find("[TOOL_CALLS]") != std::string::npos) {
        return common_chat_params_init_mistral_nemo(tmpl, inputs);
    }
    // Generic fallback
    return common_chat_params_init_generic(tmpl, inputs);
}
  970. static common_chat_msg common_chat_parse_content_only(const std::string & input) {
  971. return {
  972. /* .role = */ "assistant",
  973. /* .content = */ input,
  974. /* .tool_calls = */ {},
  975. };
  976. }
// Parses a raw model completion into a structured message (content, reasoning,
// tool calls) according to the format chosen when the prompt was constructed.
// Throws std::runtime_error for formats without a parser.
common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {
    switch (format) {
        case COMMON_CHAT_FORMAT_CONTENT_ONLY:
            return common_chat_parse_content_only(input);
        case COMMON_CHAT_FORMAT_GENERIC:
            return common_chat_parse_generic(input);
        case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
            return common_chat_parse_mistral_nemo(input);
        case COMMON_CHAT_FORMAT_LLAMA_3_X:
            return common_chat_parse_llama_3_1(input);
        case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
            return common_chat_parse_llama_3_1(input, /* with_builtin_tools= */ true);
        case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
            return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ false);
        case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING:
            return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ true);
        case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
            return common_chat_parse_functionary_v3_2(input);
        case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
            return common_chat_parse_functionary_v3_1_llama_3_1(input);
        case COMMON_CHAT_FORMAT_HERMES_2_PRO:
            return common_chat_parse_hermes_2_pro(input);
        case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
            return common_chat_parse_firefunction_v2(input);
        case COMMON_CHAT_FORMAT_COMMAND_R7B:
            return common_chat_parse_command_r7b(input, /* extract_reasoning= */ false);
        case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING:
            return common_chat_parse_command_r7b(input, /* extract_reasoning= */ true);
        default:
            throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
    }
}