parser.cpp 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591
  1. #include "lexer.h"
  2. #include "runtime.h"
  3. #include "parser.h"
  4. #include <algorithm>
  5. #include <memory>
  6. #include <stdexcept>
  7. #include <string>
  8. #include <vector>
  9. #define FILENAME "jinja-parser"
  10. namespace jinja {
  11. // Helper to check type without asserting (useful for logic)
  12. template<typename T>
  13. static bool is_type(const statement_ptr & ptr) {
  14. return dynamic_cast<const T*>(ptr.get()) != nullptr;
  15. }
  16. class parser {
  17. const std::vector<token> & tokens;
  18. size_t current = 0;
  19. std::string source; // for error reporting
  20. public:
  21. parser(const std::vector<token> & t, const std::string & src) : tokens(t), source(src) {}
  22. program parse() {
  23. statements body;
  24. while (current < tokens.size()) {
  25. body.push_back(parse_any());
  26. }
  27. return program(std::move(body));
  28. }
  29. // NOTE: start_pos is the token index, used for error reporting
  30. template<typename T, typename... Args>
  31. std::unique_ptr<T> mk_stmt(size_t start_pos, Args&&... args) {
  32. auto ptr = std::make_unique<T>(std::forward<Args>(args)...);
  33. assert(start_pos < tokens.size());
  34. ptr->pos = tokens[start_pos].pos;
  35. return ptr;
  36. }
  37. private:
  38. const token & peek(size_t offset = 0) const {
  39. if (current + offset >= tokens.size()) {
  40. static const token end_token{token::eof, "", 0};
  41. return end_token;
  42. }
  43. return tokens[current + offset];
  44. }
  45. token expect(token::type type, const std::string& error) {
  46. const auto & t = peek();
  47. if (t.t != type) {
  48. throw parser_exception("Parser Error: " + error + " (Got " + t.value + ")", source, t.pos);
  49. }
  50. current++;
  51. return t;
  52. }
  53. void expect_identifier(const std::string & name) {
  54. const auto & t = peek();
  55. if (t.t != token::identifier || t.value != name) {
  56. throw parser_exception("Expected identifier: " + name, source, t.pos);
  57. }
  58. current++;
  59. }
  60. bool is(token::type type) const {
  61. return peek().t == type;
  62. }
  63. bool is_identifier(const std::string & name) const {
  64. return peek().t == token::identifier && peek().value == name;
  65. }
  66. bool is_statement(const std::vector<std::string> & names) const {
  67. if (peek(0).t != token::open_statement || peek(1).t != token::identifier) {
  68. return false;
  69. }
  70. std::string val = peek(1).value;
  71. return std::find(names.begin(), names.end(), val) != names.end();
  72. }
  73. statement_ptr parse_any() {
  74. size_t start_pos = current;
  75. switch (peek().t) {
  76. case token::comment:
  77. return mk_stmt<comment_statement>(start_pos, tokens[current++].value);
  78. case token::text:
  79. return mk_stmt<string_literal>(start_pos, tokens[current++].value);
  80. case token::open_statement:
  81. return parse_jinja_statement();
  82. case token::open_expression:
  83. return parse_jinja_expression();
  84. default:
  85. throw std::runtime_error("Unexpected token type");
  86. }
  87. }
  88. statement_ptr parse_jinja_expression() {
  89. // Consume {{ }} tokens
  90. expect(token::open_expression, "Expected {{");
  91. auto result = parse_expression();
  92. expect(token::close_expression, "Expected }}");
  93. return result;
  94. }
  95. statement_ptr parse_jinja_statement() {
  96. // Consume {% token
  97. expect(token::open_statement, "Expected {%");
  98. if (peek().t != token::identifier) {
  99. throw std::runtime_error("Unknown statement");
  100. }
  101. size_t start_pos = current;
  102. std::string name = peek().value;
  103. current++; // consume identifier
  104. statement_ptr result;
  105. if (name == "set") {
  106. result = parse_set_statement(start_pos);
  107. } else if (name == "if") {
  108. result = parse_if_statement(start_pos);
  109. // expect {% endif %}
  110. expect(token::open_statement, "Expected {%");
  111. expect_identifier("endif");
  112. expect(token::close_statement, "Expected %}");
  113. } else if (name == "macro") {
  114. result = parse_macro_statement(start_pos);
  115. // expect {% endmacro %}
  116. expect(token::open_statement, "Expected {%");
  117. expect_identifier("endmacro");
  118. expect(token::close_statement, "Expected %}");
  119. } else if (name == "for") {
  120. result = parse_for_statement(start_pos);
  121. // expect {% endfor %}
  122. expect(token::open_statement, "Expected {%");
  123. expect_identifier("endfor");
  124. expect(token::close_statement, "Expected %}");
  125. } else if (name == "break") {
  126. expect(token::close_statement, "Expected %}");
  127. result = mk_stmt<break_statement>(start_pos);
  128. } else if (name == "continue") {
  129. expect(token::close_statement, "Expected %}");
  130. result = mk_stmt<continue_statement>(start_pos);
  131. } else if (name == "call") {
  132. statements caller_args;
  133. // bool has_caller_args = false;
  134. if (is(token::open_paren)) {
  135. // Optional caller arguments, e.g. {% call(user) dump_users(...) %}
  136. caller_args = parse_args();
  137. // has_caller_args = true;
  138. }
  139. auto callee = parse_primary_expression();
  140. if (!is_type<identifier>(callee)) throw std::runtime_error("Expected identifier");
  141. auto call_args = parse_args();
  142. expect(token::close_statement, "Expected %}");
  143. statements body;
  144. while (!is_statement({"endcall"})) {
  145. body.push_back(parse_any());
  146. }
  147. expect(token::open_statement, "Expected {%");
  148. expect_identifier("endcall");
  149. expect(token::close_statement, "Expected %}");
  150. auto call_expr = mk_stmt<call_expression>(start_pos, std::move(callee), std::move(call_args));
  151. result = mk_stmt<call_statement>(start_pos, std::move(call_expr), std::move(caller_args), std::move(body));
  152. } else if (name == "filter") {
  153. auto filter_node = parse_primary_expression();
  154. if (is_type<identifier>(filter_node) && is(token::open_paren)) {
  155. filter_node = parse_call_expression(std::move(filter_node));
  156. }
  157. expect(token::close_statement, "Expected %}");
  158. statements body;
  159. while (!is_statement({"endfilter"})) {
  160. body.push_back(parse_any());
  161. }
  162. expect(token::open_statement, "Expected {%");
  163. expect_identifier("endfilter");
  164. expect(token::close_statement, "Expected %}");
  165. result = mk_stmt<filter_statement>(start_pos, std::move(filter_node), std::move(body));
  166. } else if (name == "generation" || name == "endgeneration") {
  167. // Ignore generation blocks (transformers-specific)
  168. // See https://github.com/huggingface/transformers/pull/30650 for more information.
  169. result = mk_stmt<noop_statement>(start_pos);
  170. current++;
  171. } else {
  172. throw std::runtime_error("Unknown statement: " + name);
  173. }
  174. return result;
  175. }
  176. statement_ptr parse_set_statement(size_t start_pos) {
  177. // NOTE: `set` acts as both declaration statement and assignment expression
  178. auto left = parse_expression_sequence();
  179. statement_ptr value = nullptr;
  180. statements body;
  181. if (is(token::equals)) {
  182. current++;
  183. value = parse_expression_sequence();
  184. } else {
  185. // parsing multiline set here
  186. expect(token::close_statement, "Expected %}");
  187. while (!is_statement({"endset"})) {
  188. body.push_back(parse_any());
  189. }
  190. expect(token::open_statement, "Expected {%");
  191. expect_identifier("endset");
  192. }
  193. expect(token::close_statement, "Expected %}");
  194. return mk_stmt<set_statement>(start_pos, std::move(left), std::move(value), std::move(body));
  195. }
  196. statement_ptr parse_if_statement(size_t start_pos) {
  197. auto test = parse_expression();
  198. expect(token::close_statement, "Expected %}");
  199. statements body;
  200. statements alternate;
  201. // Keep parsing 'if' body until we reach the first {% elif %} or {% else %} or {% endif %}
  202. while (!is_statement({"elif", "else", "endif"})) {
  203. body.push_back(parse_any());
  204. }
  205. if (is_statement({"elif"})) {
  206. size_t pos0 = current;
  207. ++current; // consume {%
  208. ++current; // consume 'elif'
  209. alternate.push_back(parse_if_statement(pos0)); // nested If
  210. } else if (is_statement({"else"})) {
  211. ++current; // consume {%
  212. ++current; // consume 'else'
  213. expect(token::close_statement, "Expected %}");
  214. // keep going until we hit {% endif %}
  215. while (!is_statement({"endif"})) {
  216. alternate.push_back(parse_any());
  217. }
  218. }
  219. return mk_stmt<if_statement>(start_pos, std::move(test), std::move(body), std::move(alternate));
  220. }
  221. statement_ptr parse_macro_statement(size_t start_pos) {
  222. auto name = parse_primary_expression();
  223. auto args = parse_args();
  224. expect(token::close_statement, "Expected %}");
  225. statements body;
  226. // Keep going until we hit {% endmacro
  227. while (!is_statement({"endmacro"})) {
  228. body.push_back(parse_any());
  229. }
  230. return mk_stmt<macro_statement>(start_pos, std::move(name), std::move(args), std::move(body));
  231. }
  232. statement_ptr parse_expression_sequence(bool primary = false) {
  233. size_t start_pos = current;
  234. statements exprs;
  235. exprs.push_back(primary ? parse_primary_expression() : parse_expression());
  236. bool is_tuple = is(token::comma);
  237. while (is(token::comma)) {
  238. current++; // consume comma
  239. exprs.push_back(primary ? parse_primary_expression() : parse_expression());
  240. }
  241. return is_tuple ? mk_stmt<tuple_literal>(start_pos, std::move(exprs)) : std::move(exprs[0]);
  242. }
  243. statement_ptr parse_for_statement(size_t start_pos) {
  244. // e.g., `message` in `for message in messages`
  245. auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple
  246. if (!is_identifier("in")) throw std::runtime_error("Expected 'in'");
  247. current++;
  248. // `messages` in `for message in messages`
  249. auto iterable = parse_expression();
  250. expect(token::close_statement, "Expected %}");
  251. statements body;
  252. statements alternate;
  253. // Keep going until we hit {% endfor or {% else
  254. while (!is_statement({"endfor", "else"})) {
  255. body.push_back(parse_any());
  256. }
  257. if (is_statement({"else"})) {
  258. current += 2;
  259. expect(token::close_statement, "Expected %}");
  260. while (!is_statement({"endfor"})) {
  261. alternate.push_back(parse_any());
  262. }
  263. }
  264. return mk_stmt<for_statement>(
  265. start_pos,
  266. std::move(loop_var), std::move(iterable),
  267. std::move(body), std::move(alternate));
  268. }
  269. statement_ptr parse_expression() {
  270. // Choose parse function with lowest precedence
  271. return parse_if_expression();
  272. }
  273. statement_ptr parse_if_expression() {
  274. auto a = parse_logical_or_expression();
  275. if (is_identifier("if")) {
  276. // Ternary expression
  277. size_t start_pos = current;
  278. ++current; // consume 'if'
  279. auto test = parse_logical_or_expression();
  280. if (is_identifier("else")) {
  281. // Ternary expression with else
  282. size_t pos0 = current;
  283. ++current; // consume 'else'
  284. auto false_expr = parse_if_expression(); // recurse to support chained ternaries
  285. return mk_stmt<ternary_expression>(pos0, std::move(test), std::move(a), std::move(false_expr));
  286. } else {
  287. // Select expression on iterable
  288. return mk_stmt<select_expression>(start_pos, std::move(a), std::move(test));
  289. }
  290. }
  291. return a;
  292. }
  293. statement_ptr parse_logical_or_expression() {
  294. auto left = parse_logical_and_expression();
  295. while (is_identifier("or")) {
  296. size_t start_pos = current;
  297. token op = tokens[current++];
  298. left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_and_expression());
  299. }
  300. return left;
  301. }
  302. statement_ptr parse_logical_and_expression() {
  303. auto left = parse_logical_negation_expression();
  304. while (is_identifier("and")) {
  305. size_t start_pos = current;
  306. auto op = tokens[current++];
  307. left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_negation_expression());
  308. }
  309. return left;
  310. }
  311. statement_ptr parse_logical_negation_expression() {
  312. // Try parse unary operators
  313. if (is_identifier("not")) {
  314. size_t start_pos = current;
  315. auto op = tokens[current++];
  316. return mk_stmt<unary_expression>(start_pos, op, parse_logical_negation_expression());
  317. }
  318. return parse_comparison_expression();
  319. }
  320. statement_ptr parse_comparison_expression() {
  321. // NOTE: membership has same precedence as comparison
  322. // e.g., ('a' in 'apple' == 'b' in 'banana') evaluates as ('a' in ('apple' == ('b' in 'banana')))
  323. auto left = parse_additive_expression();
  324. while (true) {
  325. token op;
  326. size_t start_pos = current;
  327. if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") {
  328. op = {token::identifier, "not in", tokens[current].pos};
  329. current += 2;
  330. } else if (is_identifier("in")) {
  331. op = tokens[current++];
  332. } else if (is(token::comparison_binary_operator)) {
  333. op = tokens[current++];
  334. } else break;
  335. left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_additive_expression());
  336. }
  337. return left;
  338. }
  339. statement_ptr parse_additive_expression() {
  340. auto left = parse_multiplicative_expression();
  341. while (is(token::additive_binary_operator)) {
  342. size_t start_pos = current;
  343. auto op = tokens[current++];
  344. left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_multiplicative_expression());
  345. }
  346. return left;
  347. }
  348. statement_ptr parse_multiplicative_expression() {
  349. auto left = parse_test_expression();
  350. while (is(token::multiplicative_binary_operator)) {
  351. size_t start_pos = current;
  352. auto op = tokens[current++];
  353. left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_test_expression());
  354. }
  355. return left;
  356. }
  357. statement_ptr parse_test_expression() {
  358. auto operand = parse_filter_expression();
  359. while (is_identifier("is")) {
  360. size_t start_pos = current;
  361. current++;
  362. bool negate = false;
  363. if (is_identifier("not")) { current++; negate = true; }
  364. auto test_id = parse_primary_expression();
  365. // FIXME: tests can also be expressed like this: if x is eq 3
  366. if (is(token::open_paren)) test_id = parse_call_expression(std::move(test_id));
  367. operand = mk_stmt<test_expression>(start_pos, std::move(operand), negate, std::move(test_id));
  368. }
  369. return operand;
  370. }
  371. statement_ptr parse_filter_expression() {
  372. auto operand = parse_call_member_expression();
  373. while (is(token::pipe)) {
  374. size_t start_pos = current;
  375. current++;
  376. auto filter = parse_primary_expression();
  377. if (is(token::open_paren)) filter = parse_call_expression(std::move(filter));
  378. operand = mk_stmt<filter_expression>(start_pos, std::move(operand), std::move(filter));
  379. }
  380. return operand;
  381. }
  382. statement_ptr parse_call_member_expression() {
  383. // Handle member expressions recursively
  384. auto member = parse_member_expression(parse_primary_expression());
  385. return is(token::open_paren)
  386. ? parse_call_expression(std::move(member)) // foo.x()
  387. : std::move(member);
  388. }
  389. statement_ptr parse_call_expression(statement_ptr callee) {
  390. size_t start_pos = current;
  391. auto expr = mk_stmt<call_expression>(start_pos, std::move(callee), parse_args());
  392. auto member = parse_member_expression(std::move(expr)); // foo.x().y
  393. return is(token::open_paren)
  394. ? parse_call_expression(std::move(member)) // foo.x()()
  395. : std::move(member);
  396. }
  397. statements parse_args() {
  398. // comma-separated arguments list
  399. expect(token::open_paren, "Expected (");
  400. statements args;
  401. while (!is(token::close_paren)) {
  402. statement_ptr arg;
  403. // unpacking: *expr
  404. if (peek().t == token::multiplicative_binary_operator && peek().value == "*") {
  405. size_t start_pos = current;
  406. ++current; // consume *
  407. arg = mk_stmt<spread_expression>(start_pos, parse_expression());
  408. } else {
  409. arg = parse_expression();
  410. if (is(token::equals)) {
  411. // keyword argument
  412. // e.g., func(x = 5, y = a or b)
  413. size_t start_pos = current;
  414. ++current; // consume equals
  415. arg = mk_stmt<keyword_argument_expression>(start_pos, std::move(arg), parse_expression());
  416. }
  417. }
  418. args.push_back(std::move(arg));
  419. if (is(token::comma)) {
  420. ++current; // consume comma
  421. }
  422. }
  423. expect(token::close_paren, "Expected )");
  424. return args;
  425. }
  426. statement_ptr parse_member_expression(statement_ptr object) {
  427. size_t start_pos = current;
  428. while (is(token::dot) || is(token::open_square_bracket)) {
  429. auto op = tokens[current++];
  430. bool computed = op.t == token::open_square_bracket;
  431. statement_ptr prop;
  432. if (computed) {
  433. prop = parse_member_expression_arguments();
  434. expect(token::close_square_bracket, "Expected ]");
  435. } else {
  436. prop = parse_primary_expression();
  437. }
  438. object = mk_stmt<member_expression>(start_pos, std::move(object), std::move(prop), computed);
  439. }
  440. return object;
  441. }
  442. statement_ptr parse_member_expression_arguments() {
  443. // NOTE: This also handles slice expressions colon-separated arguments list
  444. // e.g., ['test'], [0], [:2], [1:], [1:2], [1:2:3]
  445. statements slices;
  446. bool is_slice = false;
  447. size_t start_pos = current;
  448. while (!is(token::close_square_bracket)) {
  449. if (is(token::colon)) {
  450. // A case where a default is used
  451. // e.g., [:2] will be parsed as [undefined, 2]
  452. slices.push_back(nullptr);
  453. ++current; // consume colon
  454. is_slice = true;
  455. } else {
  456. slices.push_back(parse_expression());
  457. if (is(token::colon)) {
  458. ++current; // consume colon after expression, if it exists
  459. is_slice = true;
  460. }
  461. }
  462. }
  463. if (is_slice) {
  464. statement_ptr start = slices.size() > 0 ? std::move(slices[0]) : nullptr;
  465. statement_ptr stop = slices.size() > 1 ? std::move(slices[1]) : nullptr;
  466. statement_ptr step = slices.size() > 2 ? std::move(slices[2]) : nullptr;
  467. return mk_stmt<slice_expression>(start_pos, std::move(start), std::move(stop), std::move(step));
  468. }
  469. return std::move(slices[0]);
  470. }
  471. statement_ptr parse_primary_expression() {
  472. size_t start_pos = current;
  473. auto t = tokens[current++];
  474. switch (t.t) {
  475. case token::numeric_literal:
  476. if (t.value.find('.') != std::string::npos) {
  477. return mk_stmt<float_literal>(start_pos, std::stod(t.value));
  478. } else {
  479. return mk_stmt<integer_literal>(start_pos, std::stoll(t.value));
  480. }
  481. case token::string_literal: {
  482. std::string val = t.value;
  483. while (is(token::string_literal)) {
  484. val += tokens[current++].value;
  485. }
  486. return mk_stmt<string_literal>(start_pos, val);
  487. }
  488. case token::identifier:
  489. return mk_stmt<identifier>(start_pos, t.value);
  490. case token::open_paren: {
  491. auto expr = parse_expression_sequence();
  492. expect(token::close_paren, "Expected )");
  493. return expr;
  494. }
  495. case token::open_square_bracket: {
  496. statements vals;
  497. while (!is(token::close_square_bracket)) {
  498. vals.push_back(parse_expression());
  499. if (is(token::comma)) current++;
  500. }
  501. current++;
  502. return mk_stmt<array_literal>(start_pos, std::move(vals));
  503. }
  504. case token::open_curly_bracket: {
  505. std::vector<std::pair<statement_ptr, statement_ptr>> pairs;
  506. while (!is(token::close_curly_bracket)) {
  507. auto key = parse_expression();
  508. expect(token::colon, "Expected :");
  509. pairs.push_back({std::move(key), parse_expression()});
  510. if (is(token::comma)) current++;
  511. }
  512. current++;
  513. return mk_stmt<object_literal>(start_pos, std::move(pairs));
  514. }
  515. default:
  516. throw std::runtime_error("Unexpected token: " + t.value + " of type " + std::to_string(t.t));
  517. }
  518. }
  519. };
  520. program parse_from_tokens(const lexer_result & lexer_res) {
  521. return parser(lexer_res.tokens, lexer_res.source).parse();
  522. }
  523. } // namespace jinja