caps.cpp 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. #include "value.h"
  2. #include "runtime.h"
  3. #include "caps.h"
  4. // note: the json dependency is only for defining input in a convenient way
  5. // we can remove it in the future when we figure out a better way to define inputs using jinja::value
  6. #include <nlohmann/json.hpp>
  7. #include <functional>
  8. #include <sstream>
  9. #define FILENAME "jinja-caps"
  10. using json = nlohmann::ordered_json;
  11. namespace jinja {
  12. using caps_json_fn = std::function<json()>;
  13. using caps_analyze_fn = std::function<void(bool, value &, value &)>;
  14. static void caps_try_execute(jinja::program & prog,
  15. const caps_json_fn & messages_fn,
  16. const caps_json_fn & tools_fn,
  17. const caps_analyze_fn & analyze_fn) {
  18. context ctx;
  19. ctx.is_get_stats = true;
  20. jinja::global_from_json(ctx, json{
  21. {"messages", messages_fn()},
  22. {"tools", tools_fn()},
  23. {"bos_token", ""},
  24. {"eos_token", ""},
  25. {"add_generation_prompt", true}
  26. }, true);
  27. auto messages = ctx.get_val("messages");
  28. auto tools = ctx.get_val("tools");
  29. bool success = false;
  30. try {
  31. jinja::runtime runtime(ctx);
  32. runtime.execute(prog);
  33. success = true;
  34. } catch (const std::exception & e) {
  35. JJ_DEBUG("Exception during execution: %s", e.what());
  36. // ignore exceptions during capability analysis
  37. }
  38. analyze_fn(success, messages, tools);
  39. }
  40. // for debugging only
  41. static void caps_print_stats(value & v, const std::string & path) {
  42. std::string ops;
  43. for (const auto & name : v->stats.ops) {
  44. ops += name + " ";
  45. }
  46. JJ_DEBUG("Value %s, type: %s %s, ops: %s",
  47. path.c_str(),
  48. v->type().c_str(),
  49. v->stats.used ? "(used)" : "",
  50. ops.c_str());
  51. }
  52. std::string caps::to_string() const {
  53. std::ostringstream ss;
  54. ss << "Caps(\n";
  55. ss << " requires_typed_content=" << requires_typed_content << "\n";
  56. ss << " supports_tools=" << supports_tools << "\n";
  57. ss << " supports_tool_calls=" << supports_tool_calls << "\n";
  58. ss << " supports_parallel_tool_calls=" << supports_parallel_tool_calls << "\n";
  59. ss << " supports_system_role=" << supports_system_role << "\n";
  60. ss << ")";
  61. return ss.str();
  62. }
  63. caps caps_get(jinja::program & prog) {
  64. caps result;
  65. static const auto has_op = [](value & v, const std::string & op_name) {
  66. return v->stats.ops.find(op_name) != v->stats.ops.end();
  67. };
  68. // case: typed content requirement
  69. caps_try_execute(
  70. prog,
  71. [&]() {
  72. // messages
  73. return json::array({
  74. {
  75. {"role", "user"},
  76. {"content", "content"}
  77. }
  78. });
  79. },
  80. [&]() {
  81. // tools
  82. return json{nullptr};
  83. },
  84. [&](bool, value & messages, value &) {
  85. auto & content = messages->at(0)->at("content");
  86. caps_print_stats(content, "messages[0].content");
  87. if (has_op(content, "selectattr") || has_op(content, "array_access")) {
  88. // accessed as an array
  89. result.requires_typed_content = true;
  90. }
  91. }
  92. );
  93. // case: system prompt support
  94. caps_try_execute(
  95. prog,
  96. [&]() {
  97. // messages
  98. return json::array({
  99. {
  100. {"role", "system"},
  101. {"content", "System message"}
  102. },
  103. {
  104. {"role", "user"},
  105. {"content", "User message"}
  106. },
  107. });
  108. },
  109. [&]() {
  110. // tools
  111. return json::array();
  112. },
  113. [&](bool, value & messages, value &) {
  114. auto & content = messages->at(0)->at("content");
  115. caps_print_stats(content, "messages[0].content");
  116. if (!content->stats.used) {
  117. result.supports_system_role = false;
  118. }
  119. }
  120. );
  121. // case: tools support
  122. caps_try_execute(
  123. prog,
  124. [&]() {
  125. // messages
  126. return json::array({
  127. {
  128. {"role", "user"},
  129. {"content", "User message"},
  130. },
  131. {
  132. {"role", "assistant"},
  133. {"content", "Assistant message"},
  134. {"tool_calls", json::array({
  135. {
  136. {"id", "call1"},
  137. {"type", "function"},
  138. {"function", {
  139. {"name", "tool1"},
  140. {"arguments", {
  141. {"arg", "value"}
  142. }}
  143. }}
  144. },
  145. {
  146. {"id", "call2"},
  147. {"type", "function"},
  148. {"function", {
  149. {"name", "tool2"},
  150. {"arguments", {
  151. {"arg", "value"}
  152. }}
  153. }}
  154. }
  155. })}
  156. },
  157. {
  158. {"role", "user"},
  159. {"content", "User message"},
  160. },
  161. });
  162. },
  163. [&]() {
  164. // tools
  165. return json::array({
  166. {
  167. {"name", "tool"},
  168. {"type", "function"},
  169. {"function", {
  170. {"name", "tool"},
  171. {"description", "Tool description"},
  172. {"parameters", {
  173. {"type", "object"},
  174. {"properties", {
  175. {"arg", {
  176. {"type", "string"},
  177. {"description", "Arg description"},
  178. }},
  179. }},
  180. {"required", json::array({ "arg" })},
  181. }},
  182. }},
  183. },
  184. });
  185. },
  186. [&](bool success, value & messages, value & tools) {
  187. if (!success) {
  188. result.supports_tool_calls = false;
  189. result.supports_tools = false;
  190. return;
  191. }
  192. auto & tool_name = tools->at(0)->at("function")->at("name");
  193. caps_print_stats(tool_name, "tools[0].function.name");
  194. if (!tool_name->stats.used) {
  195. result.supports_tools = false;
  196. }
  197. auto & tool_calls = messages->at(1)->at("tool_calls");;
  198. caps_print_stats(tool_calls, "messages[1].tool_calls");
  199. if (!tool_calls->stats.used) {
  200. result.supports_tool_calls = false;
  201. }
  202. // check for second tool call usage
  203. auto & tool_call_1 = tool_calls->at(1)->at("function");
  204. caps_print_stats(tool_call_1, "messages[1].tool_calls[1].function");
  205. if (!tool_call_1->stats.used) {
  206. result.supports_parallel_tool_calls = false;
  207. }
  208. }
  209. );
  210. JJ_DEBUG("%s\n", result.to_string().c_str());
  211. return result;
  212. }
  213. } // namespace jinja