1
0

caps.cpp 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. #include "value.h"
  2. #include "runtime.h"
  3. #include "caps.h"
  4. // note: the json dependency is only for defining input in a convenient way
  5. // we can remove it in the future when we figure out a better way to define inputs using jinja::value
  6. #include <nlohmann/json.hpp>
  7. #include <functional>
  8. #include <sstream>
  9. #define FILENAME "jinja-caps"
  10. using json = nlohmann::ordered_json;
  11. namespace jinja {
  12. using caps_json_fn = std::function<json()>;
  13. using caps_analyze_fn = std::function<void(bool, value &, value &)>;
  14. static void caps_try_execute(jinja::program & prog,
  15. const caps_json_fn & messages_fn,
  16. const caps_json_fn & tools_fn,
  17. const caps_analyze_fn & analyze_fn) {
  18. context ctx;
  19. ctx.is_get_stats = true;
  20. jinja::global_from_json(ctx, json{
  21. {"messages", messages_fn()},
  22. {"tools", tools_fn()},
  23. {"bos_token", ""},
  24. {"eos_token", ""},
  25. {"add_generation_prompt", true}
  26. }, true);
  27. auto messages = ctx.get_val("messages");
  28. auto tools = ctx.get_val("tools");
  29. bool success = false;
  30. try {
  31. jinja::runtime runtime(ctx);
  32. runtime.execute(prog);
  33. success = true;
  34. } catch (const std::exception & e) {
  35. JJ_DEBUG("Exception during execution: %s", e.what());
  36. // ignore exceptions during capability analysis
  37. }
  38. analyze_fn(success, messages, tools);
  39. }
  40. // for debugging only
  41. static void caps_print_stats(value & v, const std::string & path) {
  42. std::string ops;
  43. for (const auto & name : v->stats.ops) {
  44. ops += name + " ";
  45. }
  46. JJ_DEBUG("Value %s, type: %s %s, ops: %s",
  47. path.c_str(),
  48. v->type().c_str(),
  49. v->stats.used ? "(used)" : "",
  50. ops.c_str());
  51. }
  52. std::map<std::string, bool> caps::to_map() const {
  53. return {
  54. {"requires_typed_content", requires_typed_content},
  55. {"supports_tools", supports_tools},
  56. {"supports_tool_calls", supports_tool_calls},
  57. {"supports_parallel_tool_calls", supports_parallel_tool_calls},
  58. {"supports_system_role", supports_system_role},
  59. {"supports_preserve_reasoning", supports_preserve_reasoning},
  60. };
  61. }
  62. std::string caps::to_string() const {
  63. std::ostringstream ss;
  64. ss << "Caps(\n";
  65. for (const auto & [key, value] : to_map()) {
  66. ss << " " << key << "=" << (value ? "true" : "false") << "\n";
  67. }
  68. ss << ")";
  69. return ss.str();
  70. }
  71. caps caps_get(jinja::program & prog) {
  72. caps result;
  73. static const auto has_op = [](value & v, const std::string & op_name) {
  74. return v->stats.ops.find(op_name) != v->stats.ops.end();
  75. };
  76. // case: typed content requirement
  77. caps_try_execute(
  78. prog,
  79. [&]() {
  80. // messages
  81. return json::array({
  82. {
  83. {"role", "user"},
  84. {"content", "content"}
  85. }
  86. });
  87. },
  88. [&]() {
  89. // tools
  90. return json{nullptr};
  91. },
  92. [&](bool, value & messages, value &) {
  93. auto & content = messages->at(0)->at("content");
  94. caps_print_stats(content, "messages[0].content");
  95. if (has_op(content, "selectattr") || has_op(content, "array_access")) {
  96. // accessed as an array
  97. result.requires_typed_content = true;
  98. }
  99. }
  100. );
  101. // case: system prompt support
  102. caps_try_execute(
  103. prog,
  104. [&]() {
  105. // messages
  106. return json::array({
  107. {
  108. {"role", "system"},
  109. {"content", "System message"}
  110. },
  111. {
  112. {"role", "user"},
  113. {"content", "User message"}
  114. },
  115. });
  116. },
  117. [&]() {
  118. // tools
  119. return json::array();
  120. },
  121. [&](bool, value & messages, value &) {
  122. auto & content = messages->at(0)->at("content");
  123. caps_print_stats(content, "messages[0].content");
  124. if (!content->stats.used) {
  125. result.supports_system_role = false;
  126. }
  127. }
  128. );
  129. // case: tools support
  130. caps_try_execute(
  131. prog,
  132. [&]() {
  133. // messages
  134. return json::array({
  135. {
  136. {"role", "user"},
  137. {"content", "User message"},
  138. },
  139. {
  140. {"role", "assistant"},
  141. {"content", "Assistant message"},
  142. {"tool_calls", json::array({
  143. {
  144. {"id", "call1"},
  145. {"type", "function"},
  146. {"function", {
  147. {"name", "tool1"},
  148. {"arguments", {
  149. {"arg", "value"}
  150. }}
  151. }}
  152. },
  153. {
  154. {"id", "call2"},
  155. {"type", "function"},
  156. {"function", {
  157. {"name", "tool2"},
  158. {"arguments", {
  159. {"arg", "value"}
  160. }}
  161. }}
  162. }
  163. })}
  164. },
  165. {
  166. {"role", "user"},
  167. {"content", "User message"},
  168. },
  169. });
  170. },
  171. [&]() {
  172. // tools
  173. return json::array({
  174. {
  175. {"name", "tool"},
  176. {"type", "function"},
  177. {"function", {
  178. {"name", "tool"},
  179. {"description", "Tool description"},
  180. {"parameters", {
  181. {"type", "object"},
  182. {"properties", {
  183. {"arg", {
  184. {"type", "string"},
  185. {"description", "Arg description"},
  186. }},
  187. }},
  188. {"required", json::array({ "arg" })},
  189. }},
  190. }},
  191. },
  192. });
  193. },
  194. [&](bool success, value & messages, value & tools) {
  195. if (!success) {
  196. result.supports_tool_calls = false;
  197. result.supports_tools = false;
  198. return;
  199. }
  200. auto & tool_name = tools->at(0)->at("function")->at("name");
  201. caps_print_stats(tool_name, "tools[0].function.name");
  202. if (!tool_name->stats.used) {
  203. result.supports_tools = false;
  204. }
  205. auto & tool_calls = messages->at(1)->at("tool_calls");;
  206. caps_print_stats(tool_calls, "messages[1].tool_calls");
  207. if (!tool_calls->stats.used) {
  208. result.supports_tool_calls = false;
  209. }
  210. // check for second tool call usage
  211. auto & tool_call_1 = tool_calls->at(1)->at("function");
  212. caps_print_stats(tool_call_1, "messages[1].tool_calls[1].function");
  213. if (!tool_call_1->stats.used) {
  214. result.supports_parallel_tool_calls = false;
  215. }
  216. }
  217. );
  218. // case: preserve reasoning content in chat history
  219. caps_try_execute(
  220. prog,
  221. [&]() {
  222. // messages
  223. return json::array({
  224. {
  225. {"role", "user"},
  226. {"content", "User message"}
  227. },
  228. {
  229. {"role", "assistant"},
  230. {"content", "Assistant message"},
  231. {"reasoning_content", "Reasoning content"}
  232. },
  233. {
  234. {"role", "user"},
  235. {"content", "User message"}
  236. },
  237. });
  238. },
  239. [&]() {
  240. // tools
  241. return json::array();
  242. },
  243. [&](bool, value & messages, value &) {
  244. auto & content = messages->at(1)->at("reasoning_content");
  245. caps_print_stats(content, "messages[1].reasoning_content");
  246. if (content->stats.used) {
  247. result.supports_preserve_reasoning = true;
  248. }
  249. }
  250. );
  251. JJ_DEBUG("%s\n", result.to_string().c_str());
  252. return result;
  253. }
  254. } // namespace jinja