| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237 |
- #include "value.h"
- #include "runtime.h"
- #include "caps.h"
- // note: the json dependency is only for defining input in a convenient way
- // we can remove it in the future when we figure out a better way to define inputs using jinja::value
- #include <nlohmann/json.hpp>
- #include <functional>
- #include <sstream>
- #define FILENAME "jinja-caps"
- using json = nlohmann::ordered_json;
- namespace jinja {
- using caps_json_fn = std::function<json()>;
- using caps_analyze_fn = std::function<void(bool, value &, value &)>;
- static void caps_try_execute(jinja::program & prog,
- const caps_json_fn & messages_fn,
- const caps_json_fn & tools_fn,
- const caps_analyze_fn & analyze_fn) {
- context ctx;
- ctx.is_get_stats = true;
- jinja::global_from_json(ctx, json{
- {"messages", messages_fn()},
- {"tools", tools_fn()},
- {"bos_token", ""},
- {"eos_token", ""},
- {"add_generation_prompt", true}
- }, true);
- auto messages = ctx.get_val("messages");
- auto tools = ctx.get_val("tools");
- bool success = false;
- try {
- jinja::runtime runtime(ctx);
- runtime.execute(prog);
- success = true;
- } catch (const std::exception & e) {
- JJ_DEBUG("Exception during execution: %s", e.what());
- // ignore exceptions during capability analysis
- }
- analyze_fn(success, messages, tools);
- }
- // for debugging only
- static void caps_print_stats(value & v, const std::string & path) {
- std::string ops;
- for (const auto & name : v->stats.ops) {
- ops += name + " ";
- }
- JJ_DEBUG("Value %s, type: %s %s, ops: %s",
- path.c_str(),
- v->type().c_str(),
- v->stats.used ? "(used)" : "",
- ops.c_str());
- }
- std::string caps::to_string() const {
- std::ostringstream ss;
- ss << "Caps(\n";
- ss << " requires_typed_content=" << requires_typed_content << "\n";
- ss << " supports_tools=" << supports_tools << "\n";
- ss << " supports_tool_calls=" << supports_tool_calls << "\n";
- ss << " supports_parallel_tool_calls=" << supports_parallel_tool_calls << "\n";
- ss << " supports_system_role=" << supports_system_role << "\n";
- ss << ")";
- return ss.str();
- }
- caps caps_get(jinja::program & prog) {
- caps result;
- static const auto has_op = [](value & v, const std::string & op_name) {
- return v->stats.ops.find(op_name) != v->stats.ops.end();
- };
- // case: typed content requirement
- caps_try_execute(
- prog,
- [&]() {
- // messages
- return json::array({
- {
- {"role", "user"},
- {"content", "content"}
- }
- });
- },
- [&]() {
- // tools
- return json{nullptr};
- },
- [&](bool, value & messages, value &) {
- auto & content = messages->at(0)->at("content");
- caps_print_stats(content, "messages[0].content");
- if (has_op(content, "selectattr") || has_op(content, "array_access")) {
- // accessed as an array
- result.requires_typed_content = true;
- }
- }
- );
- // case: system prompt support
- caps_try_execute(
- prog,
- [&]() {
- // messages
- return json::array({
- {
- {"role", "system"},
- {"content", "System message"}
- },
- {
- {"role", "user"},
- {"content", "User message"}
- },
- });
- },
- [&]() {
- // tools
- return json::array();
- },
- [&](bool, value & messages, value &) {
- auto & content = messages->at(0)->at("content");
- caps_print_stats(content, "messages[0].content");
- if (!content->stats.used) {
- result.supports_system_role = false;
- }
- }
- );
- // case: tools support
- caps_try_execute(
- prog,
- [&]() {
- // messages
- return json::array({
- {
- {"role", "user"},
- {"content", "User message"},
- },
- {
- {"role", "assistant"},
- {"content", "Assistant message"},
- {"tool_calls", json::array({
- {
- {"id", "call1"},
- {"type", "function"},
- {"function", {
- {"name", "tool1"},
- {"arguments", {
- {"arg", "value"}
- }}
- }}
- },
- {
- {"id", "call2"},
- {"type", "function"},
- {"function", {
- {"name", "tool2"},
- {"arguments", {
- {"arg", "value"}
- }}
- }}
- }
- })}
- },
- {
- {"role", "user"},
- {"content", "User message"},
- },
- });
- },
- [&]() {
- // tools
- return json::array({
- {
- {"name", "tool"},
- {"type", "function"},
- {"function", {
- {"name", "tool"},
- {"description", "Tool description"},
- {"parameters", {
- {"type", "object"},
- {"properties", {
- {"arg", {
- {"type", "string"},
- {"description", "Arg description"},
- }},
- }},
- {"required", json::array({ "arg" })},
- }},
- }},
- },
- });
- },
- [&](bool success, value & messages, value & tools) {
- if (!success) {
- result.supports_tool_calls = false;
- result.supports_tools = false;
- return;
- }
- auto & tool_name = tools->at(0)->at("function")->at("name");
- caps_print_stats(tool_name, "tools[0].function.name");
- if (!tool_name->stats.used) {
- result.supports_tools = false;
- }
- auto & tool_calls = messages->at(1)->at("tool_calls");;
- caps_print_stats(tool_calls, "messages[1].tool_calls");
- if (!tool_calls->stats.used) {
- result.supports_tool_calls = false;
- }
- // check for second tool call usage
- auto & tool_call_1 = tool_calls->at(1)->at("function");
- caps_print_stats(tool_call_1, "messages[1].tool_calls[1].function");
- if (!tool_call_1->stats.used) {
- result.supports_parallel_tool_calls = false;
- }
- }
- );
- JJ_DEBUG("%s\n", result.to_string().c_str());
- return result;
- }
- } // namespace jinja
|