value.h 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
#pragma once
#include "string.h"
#include <algorithm>
#include <cstdint>
#include <functional>
#include <locale>
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <typeinfo>
#include <utility>
#include <vector>
  12. namespace jinja {
  13. struct value_t;
  14. using value = std::shared_ptr<value_t>;
  15. // Helper to check the type of a value
  16. template<typename T>
  17. struct extract_pointee {
  18. using type = T;
  19. };
  20. template<typename U>
  21. struct extract_pointee<std::shared_ptr<U>> {
  22. using type = U;
  23. };
  24. template<typename T>
  25. bool is_val(const value & ptr) {
  26. using PointeeType = typename extract_pointee<T>::type;
  27. return dynamic_cast<const PointeeType*>(ptr.get()) != nullptr;
  28. }
  29. template<typename T>
  30. bool is_val(const value_t * ptr) {
  31. using PointeeType = typename extract_pointee<T>::type;
  32. return dynamic_cast<const PointeeType*>(ptr) != nullptr;
  33. }
  34. template<typename T, typename... Args>
  35. std::shared_ptr<typename extract_pointee<T>::type> mk_val(Args&&... args) {
  36. using PointeeType = typename extract_pointee<T>::type;
  37. return std::make_shared<PointeeType>(std::forward<Args>(args)...);
  38. }
  39. template<typename T>
  40. const typename extract_pointee<T>::type * cast_val(const value & ptr) {
  41. using PointeeType = typename extract_pointee<T>::type;
  42. return dynamic_cast<const PointeeType*>(ptr.get());
  43. }
  44. template<typename T>
  45. typename extract_pointee<T>::type * cast_val(value & ptr) {
  46. using PointeeType = typename extract_pointee<T>::type;
  47. return dynamic_cast<PointeeType*>(ptr.get());
  48. }
  49. // End Helper
  50. struct context; // forward declaration
  51. // for converting from JSON to jinja values
  52. // example input JSON:
  53. // {
  54. // "messages": [
  55. // {"role": "user", "content": "Hello!"},
  56. // {"role": "assistant", "content": "Hi there!"}
  57. // ],
  58. // "bos_token": "<s>",
  59. // "eos_token": "</s>",
  60. // }
  61. //
  62. // to mark strings as user input, wrap them in a special object:
  63. // {
  64. // "messages": [
  65. // {
  66. // "role": "user",
  67. // "content": {"__input__": "Hello!"} // this string is user input
  68. // },
  69. // ...
  70. // ],
  71. // }
  72. //
  73. // marking input can be useful for tracking data provenance
  74. // and preventing template injection attacks
  75. //
  76. // Note: T_JSON can be nlohmann::ordered_json
  77. template<typename T_JSON>
  78. void global_from_json(context & ctx, const T_JSON & json_obj, bool mark_input);
  79. //
  80. // base value type
  81. //
  82. struct func_args; // function argument values
  83. using func_handler = std::function<value(const func_args &)>;
  84. using func_builtins = std::map<std::string, func_handler>;
  85. enum value_compare_op { eq, ge, gt, lt, ne };
  86. bool value_compare(const value & a, const value & b, value_compare_op op);
  87. struct value_t {
  88. int64_t val_int;
  89. double val_flt;
  90. string val_str;
  91. bool val_bool;
  92. std::vector<value> val_arr;
  93. struct map {
  94. // once set to true, all keys must be numeric
  95. // caveat: we only allow either all numeric keys or all non-numeric keys
  96. // for now, this only applied to for_statement in case of iterating over object keys/items
  97. bool is_key_numeric = false;
  98. std::map<std::string, value> unordered;
  99. std::vector<std::pair<std::string, value>> ordered;
  100. void insert(const std::string & key, const value & val) {
  101. if (unordered.find(key) != unordered.end()) {
  102. // if key exists, remove from ordered list
  103. ordered.erase(std::remove_if(ordered.begin(), ordered.end(),
  104. [&](const std::pair<std::string, value> & p) { return p.first == key; }),
  105. ordered.end());
  106. }
  107. unordered[key] = val;
  108. ordered.push_back({key, val});
  109. }
  110. } val_obj;
  111. func_handler val_func;
  112. // only used if ctx.is_get_stats = true
  113. struct stats_t {
  114. bool used = false;
  115. // ops can be builtin calls or operators: "array_access", "object_access"
  116. std::set<std::string> ops;
  117. } stats;
  118. value_t() = default;
  119. value_t(const value_t &) = default;
  120. virtual ~value_t() = default;
  121. virtual std::string type() const { return ""; }
  122. virtual int64_t as_int() const { throw std::runtime_error(type() + " is not an int value"); }
  123. virtual double as_float() const { throw std::runtime_error(type() + " is not a float value"); }
  124. virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); }
  125. virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); }
  126. virtual const std::vector<value> & as_array() const { throw std::runtime_error(type() + " is not an array value"); }
  127. virtual const std::vector<std::pair<std::string, value>> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); }
  128. virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); }
  129. virtual bool is_none() const { return false; }
  130. virtual bool is_undefined() const { return false; }
  131. virtual const func_builtins & get_builtins() const {
  132. throw std::runtime_error("No builtins available for type " + type());
  133. }
  134. virtual bool has_key(const std::string & key) {
  135. return val_obj.unordered.find(key) != val_obj.unordered.end();
  136. }
  137. virtual value & at(const std::string & key, value & default_val) {
  138. auto it = val_obj.unordered.find(key);
  139. if (it == val_obj.unordered.end()) {
  140. return default_val;
  141. }
  142. return val_obj.unordered.at(key);
  143. }
  144. virtual value & at(const std::string & key) {
  145. auto it = val_obj.unordered.find(key);
  146. if (it == val_obj.unordered.end()) {
  147. throw std::runtime_error("Key '" + key + "' not found in value of type " + type());
  148. }
  149. return val_obj.unordered.at(key);
  150. }
  151. virtual value & at(int64_t index, value & default_val) {
  152. if (index < 0) {
  153. index += val_arr.size();
  154. }
  155. if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
  156. return default_val;
  157. }
  158. return val_arr[index];
  159. }
  160. virtual value & at(int64_t index) {
  161. if (index < 0) {
  162. index += val_arr.size();
  163. }
  164. if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
  165. throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
  166. }
  167. return val_arr[index];
  168. }
  169. virtual std::string as_repr() const { return as_string().str(); }
  170. };
  171. //
  172. // primitive value types
  173. //
  174. struct value_int_t : public value_t {
  175. value_int_t(int64_t v) { val_int = v; }
  176. virtual std::string type() const override { return "Integer"; }
  177. virtual int64_t as_int() const override { return val_int; }
  178. virtual double as_float() const override { return static_cast<double>(val_int); }
  179. virtual string as_string() const override { return std::to_string(val_int); }
  180. virtual bool as_bool() const override {
  181. return val_int != 0;
  182. }
  183. virtual const func_builtins & get_builtins() const override;
  184. };
  185. using value_int = std::shared_ptr<value_int_t>;
  186. struct value_float_t : public value_t {
  187. value_float_t(double v) { val_flt = v; }
  188. virtual std::string type() const override { return "Float"; }
  189. virtual double as_float() const override { return val_flt; }
  190. virtual int64_t as_int() const override { return static_cast<int64_t>(val_flt); }
  191. virtual string as_string() const override {
  192. std::string out = std::to_string(val_flt);
  193. out.erase(out.find_last_not_of('0') + 1, std::string::npos); // remove trailing zeros
  194. if (out.back() == '.') out.push_back('0'); // leave one zero if no decimals
  195. return out;
  196. }
  197. virtual bool as_bool() const override {
  198. return val_flt != 0.0;
  199. }
  200. virtual const func_builtins & get_builtins() const override;
  201. };
  202. using value_float = std::shared_ptr<value_float_t>;
  203. struct value_string_t : public value_t {
  204. value_string_t() { val_str = string(); }
  205. value_string_t(const std::string & v) { val_str = string(v); }
  206. value_string_t(const string & v) { val_str = v; }
  207. virtual std::string type() const override { return "String"; }
  208. virtual string as_string() const override { return val_str; }
  209. virtual std::string as_repr() const override {
  210. std::ostringstream ss;
  211. for (const auto & part : val_str.parts) {
  212. ss << (part.is_input ? "INPUT: " : "TMPL: ") << part.val << "\n";
  213. }
  214. return ss.str();
  215. }
  216. virtual bool as_bool() const override {
  217. return val_str.length() > 0;
  218. }
  219. virtual const func_builtins & get_builtins() const override;
  220. void mark_input() {
  221. val_str.mark_input();
  222. }
  223. };
  224. using value_string = std::shared_ptr<value_string_t>;
  225. struct value_bool_t : public value_t {
  226. value_bool_t(bool v) { val_bool = v; }
  227. virtual std::string type() const override { return "Boolean"; }
  228. virtual bool as_bool() const override { return val_bool; }
  229. virtual string as_string() const override { return std::string(val_bool ? "True" : "False"); }
  230. virtual const func_builtins & get_builtins() const override;
  231. };
  232. using value_bool = std::shared_ptr<value_bool_t>;
  233. struct value_array_t : public value_t {
  234. value_array_t() = default;
  235. value_array_t(value & v) {
  236. val_arr = v->val_arr;
  237. }
  238. value_array_t(const std::vector<value> & arr) {
  239. val_arr = arr;
  240. }
  241. void reverse() { std::reverse(val_arr.begin(), val_arr.end()); }
  242. void push_back(const value & val) { val_arr.push_back(val); }
  243. void push_back(value && val) { val_arr.push_back(std::move(val)); }
  244. value pop_at(int64_t index) {
  245. if (index < 0) {
  246. index = static_cast<int64_t>(val_arr.size()) + index;
  247. }
  248. if (index < 0 || index >= static_cast<int64_t>(val_arr.size())) {
  249. throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
  250. }
  251. value val = val_arr.at(static_cast<size_t>(index));
  252. val_arr.erase(val_arr.begin() + index);
  253. return val;
  254. }
  255. virtual std::string type() const override { return "Array"; }
  256. virtual const std::vector<value> & as_array() const override { return val_arr; }
  257. virtual string as_string() const override {
  258. std::ostringstream ss;
  259. ss << "[";
  260. for (size_t i = 0; i < val_arr.size(); i++) {
  261. if (i > 0) ss << ", ";
  262. ss << val_arr.at(i)->as_repr();
  263. }
  264. ss << "]";
  265. return ss.str();
  266. }
  267. virtual bool as_bool() const override {
  268. return !val_arr.empty();
  269. }
  270. virtual const func_builtins & get_builtins() const override;
  271. };
  272. using value_array = std::shared_ptr<value_array_t>;
  273. struct value_object_t : public value_t {
  274. bool has_builtins = true; // context and loop objects do not have builtins
  275. value_object_t() = default;
  276. value_object_t(value & v) {
  277. val_obj = v->val_obj;
  278. }
  279. value_object_t(const std::map<std::string, value> & obj) {
  280. for (const auto & pair : obj) {
  281. val_obj.insert(pair.first, pair.second);
  282. }
  283. }
  284. value_object_t(const std::vector<std::pair<std::string, value>> & obj) {
  285. for (const auto & pair : obj) {
  286. val_obj.insert(pair.first, pair.second);
  287. }
  288. }
  289. void insert(const std::string & key, const value & val) {
  290. val_obj.insert(key, val);
  291. }
  292. virtual std::string type() const override { return "Object"; }
  293. virtual const std::vector<std::pair<std::string, value>> & as_ordered_object() const override { return val_obj.ordered; }
  294. virtual bool as_bool() const override {
  295. return !val_obj.unordered.empty();
  296. }
  297. virtual const func_builtins & get_builtins() const override;
  298. };
  299. using value_object = std::shared_ptr<value_object_t>;
  300. //
  301. // null and undefined types
  302. //
  303. struct value_none_t : public value_t {
  304. virtual std::string type() const override { return "None"; }
  305. virtual bool is_none() const override { return true; }
  306. virtual bool as_bool() const override { return false; }
  307. virtual std::string as_repr() const override { return type(); }
  308. virtual const func_builtins & get_builtins() const override;
  309. };
  310. using value_none = std::shared_ptr<value_none_t>;
  311. struct value_undefined_t : public value_t {
  312. std::string hint; // for debugging, to indicate where undefined came from
  313. value_undefined_t(const std::string & h = "") : hint(h) {}
  314. virtual std::string type() const override { return hint.empty() ? "Undefined" : "Undefined (hint: '" + hint + "')"; }
  315. virtual bool is_undefined() const override { return true; }
  316. virtual bool as_bool() const override { return false; }
  317. virtual std::string as_repr() const override { return type(); }
  318. virtual const func_builtins & get_builtins() const override;
  319. };
  320. using value_undefined = std::shared_ptr<value_undefined_t>;
  321. //
  322. // function type
  323. //
  324. struct func_args {
  325. public:
  326. std::string func_name; // for error messages
  327. context & ctx;
  328. func_args(context & ctx) : ctx(ctx) {}
  329. value get_kwarg(const std::string & key, value default_val) const;
  330. value get_kwarg_or_pos(const std::string & key, size_t pos) const;
  331. value get_pos(size_t pos) const;
  332. value get_pos(size_t pos, value default_val) const;
  333. const std::vector<value> & get_args() const;
  334. size_t count() const { return args.size(); }
  335. void push_back(const value & val);
  336. void push_front(const value & val);
  337. void ensure_count(size_t min, size_t max = 999) const {
  338. size_t n = args.size();
  339. if (n < min || n > max) {
  340. throw std::runtime_error("Function '" + func_name + "' expected between " + std::to_string(min) + " and " + std::to_string(max) + " arguments, got " + std::to_string(n));
  341. }
  342. }
  343. template<typename T> void ensure_val(const value & ptr) const {
  344. if (!is_val<T>(ptr)) {
  345. throw std::runtime_error("Function '" + func_name + "' expected value of type " + std::string(typeid(T).name()) + ", got " + ptr->type());
  346. }
  347. }
  348. void ensure_count(bool require0, bool require1, bool require2, bool require3) const {
  349. static auto bool_to_int = [](bool b) { return b ? 1 : 0; };
  350. size_t required = bool_to_int(require0) + bool_to_int(require1) + bool_to_int(require2) + bool_to_int(require3);
  351. ensure_count(required);
  352. }
  353. template<typename T0> void ensure_vals(bool required0 = true) const {
  354. ensure_count(required0, false, false, false);
  355. if (required0 && args.size() > 0) ensure_val<T0>(args[0]);
  356. }
  357. template<typename T0, typename T1> void ensure_vals(bool required0 = true, bool required1 = true) const {
  358. ensure_count(required0, required1, false, false);
  359. if (required0 && args.size() > 0) ensure_val<T0>(args[0]);
  360. if (required1 && args.size() > 1) ensure_val<T1>(args[1]);
  361. }
  362. template<typename T0, typename T1, typename T2> void ensure_vals(bool required0 = true, bool required1 = true, bool required2 = true) const {
  363. ensure_count(required0, required1, required2, false);
  364. if (required0 && args.size() > 0) ensure_val<T0>(args[0]);
  365. if (required1 && args.size() > 1) ensure_val<T1>(args[1]);
  366. if (required2 && args.size() > 2) ensure_val<T2>(args[2]);
  367. }
  368. template<typename T0, typename T1, typename T2, typename T3> void ensure_vals(bool required0 = true, bool required1 = true, bool required2 = true, bool required3 = true) const {
  369. ensure_count(required0, required1, required2, required3);
  370. if (required0 && args.size() > 0) ensure_val<T0>(args[0]);
  371. if (required1 && args.size() > 1) ensure_val<T1>(args[1]);
  372. if (required2 && args.size() > 2) ensure_val<T2>(args[2]);
  373. if (required3 && args.size() > 3) ensure_val<T3>(args[3]);
  374. }
  375. private:
  376. std::vector<value> args;
  377. };
  378. struct value_func_t : public value_t {
  379. std::string name;
  380. value arg0; // bound "this" argument, if any
  381. value_func_t(const std::string & name, const func_handler & func) : name(name) {
  382. val_func = func;
  383. }
  384. value_func_t(const std::string & name, const func_handler & func, const value & arg_this) : name(name), arg0(arg_this) {
  385. val_func = func;
  386. }
  387. virtual value invoke(const func_args & args) const override {
  388. func_args new_args(args); // copy
  389. new_args.func_name = name;
  390. if (arg0) {
  391. new_args.push_front(arg0);
  392. }
  393. return val_func(new_args);
  394. }
  395. virtual std::string type() const override { return "Function"; }
  396. virtual std::string as_repr() const override { return type(); }
  397. };
  398. using value_func = std::shared_ptr<value_func_t>;
  399. // special value for kwarg
  400. struct value_kwarg_t : public value_t {
  401. std::string key;
  402. value val;
  403. value_kwarg_t(const std::string & k, const value & v) : key(k), val(v) {}
  404. virtual std::string type() const override { return "KwArg"; }
  405. virtual std::string as_repr() const override { return type(); }
  406. };
  407. using value_kwarg = std::shared_ptr<value_kwarg_t>;
  408. // utils
  409. const func_builtins & global_builtins();
  410. std::string value_to_json(const value & val, int indent = -1, const std::string_view item_sep = ", ", const std::string_view key_sep = ": ");
  411. struct not_implemented_exception : public std::runtime_error {
  412. not_implemented_exception(const std::string & msg) : std::runtime_error("NotImplemented: " + msg) {}
  413. };
  414. } // namespace jinja