// test-grammar-parser.cpp (8.2 KB)

#ifdef NDEBUG
#undef NDEBUG
#endif

#include "llama.h"
#include "examples/grammar-parser.cpp"

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>
  7. int main()
  8. {
  9. grammar_parser::parse_state parsed_grammar;
  10. const char *grammar_bytes = R"""(root ::= (expr "=" term "\n")+
  11. expr ::= term ([-+*/] term)*
  12. term ::= [0-9]+)""";
  13. parsed_grammar = grammar_parser::parse(grammar_bytes);
  14. std::vector<std::pair<std::string, uint32_t>> expected = {
  15. {"expr", 2},
  16. {"expr_5", 5},
  17. {"expr_6", 6},
  18. {"root", 0},
  19. {"root_1", 1},
  20. {"root_4", 4},
  21. {"term", 3},
  22. {"term_7", 7},
  23. };
  24. uint32_t index = 0;
  25. for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it)
  26. {
  27. std::string key = it->first;
  28. uint32_t value = it->second;
  29. std::pair<std::string, uint32_t> expected_pair = expected[index];
  30. // pretty print error message before asserting
  31. if (expected_pair.first != key || expected_pair.second != value)
  32. {
  33. fprintf(stderr, "expected_pair: %s, %d\n", expected_pair.first.c_str(), expected_pair.second);
  34. fprintf(stderr, "actual_pair: %s, %d\n", key.c_str(), value);
  35. fprintf(stderr, "expected_pair != actual_pair\n");
  36. }
  37. assert(expected_pair.first == key && expected_pair.second == value);
  38. index++;
  39. }
  40. std::vector<llama_grammar_element> expected_rules = {
  41. {LLAMA_GRETYPE_RULE_REF, 4},
  42. {LLAMA_GRETYPE_END, 0},
  43. {LLAMA_GRETYPE_RULE_REF, 2},
  44. {LLAMA_GRETYPE_CHAR, 61},
  45. {LLAMA_GRETYPE_RULE_REF, 3},
  46. {LLAMA_GRETYPE_CHAR, 10},
  47. {LLAMA_GRETYPE_END, 0},
  48. {LLAMA_GRETYPE_RULE_REF, 3},
  49. {LLAMA_GRETYPE_RULE_REF, 6},
  50. {LLAMA_GRETYPE_END, 0},
  51. {LLAMA_GRETYPE_RULE_REF, 7},
  52. {LLAMA_GRETYPE_END, 0},
  53. {LLAMA_GRETYPE_RULE_REF, 1},
  54. {LLAMA_GRETYPE_RULE_REF, 4},
  55. {LLAMA_GRETYPE_ALT, 0},
  56. {LLAMA_GRETYPE_RULE_REF, 1},
  57. {LLAMA_GRETYPE_END, 0},
  58. {LLAMA_GRETYPE_CHAR, 45},
  59. {LLAMA_GRETYPE_CHAR_ALT, 43},
  60. {LLAMA_GRETYPE_CHAR_ALT, 42},
  61. {LLAMA_GRETYPE_CHAR_ALT, 47},
  62. {LLAMA_GRETYPE_RULE_REF, 3},
  63. {LLAMA_GRETYPE_END, 0},
  64. {LLAMA_GRETYPE_RULE_REF, 5},
  65. {LLAMA_GRETYPE_RULE_REF, 6},
  66. {LLAMA_GRETYPE_ALT, 0},
  67. {LLAMA_GRETYPE_END, 0},
  68. {LLAMA_GRETYPE_CHAR, 48},
  69. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
  70. {LLAMA_GRETYPE_RULE_REF, 7},
  71. {LLAMA_GRETYPE_ALT, 0},
  72. {LLAMA_GRETYPE_CHAR, 48},
  73. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
  74. {LLAMA_GRETYPE_END, 0},
  75. };
  76. index = 0;
  77. for (auto rule : parsed_grammar.rules)
  78. {
  79. // compare rule to expected rule
  80. for (uint32_t i = 0; i < rule.size(); i++)
  81. {
  82. llama_grammar_element element = rule[i];
  83. llama_grammar_element expected_element = expected_rules[index];
  84. // pretty print error message before asserting
  85. if (expected_element.type != element.type || expected_element.value != element.value)
  86. {
  87. fprintf(stderr, "index: %d\n", index);
  88. fprintf(stderr, "expected_element: %d, %d\n", expected_element.type, expected_element.value);
  89. fprintf(stderr, "actual_element: %d, %d\n", element.type, element.value);
  90. fprintf(stderr, "expected_element != actual_element\n");
  91. }
  92. assert(expected_element.type == element.type && expected_element.value == element.value);
  93. index++;
  94. }
  95. }
  96. const char *longer_grammar_bytes = R"""(
  97. root ::= (expr "=" ws term "\n")+
  98. expr ::= term ([-+*/] term)*
  99. term ::= ident | num | "(" ws expr ")" ws
  100. ident ::= [a-z] [a-z0-9_]* ws
  101. num ::= [0-9]+ ws
  102. ws ::= [ \t\n]*
  103. )""";
  104. parsed_grammar = grammar_parser::parse(longer_grammar_bytes);
  105. expected = {
  106. {"expr", 2},
  107. {"expr_6", 6},
  108. {"expr_7", 7},
  109. {"ident", 8},
  110. {"ident_10", 10},
  111. {"num", 9},
  112. {"num_11", 11},
  113. {"root", 0},
  114. {"root_1", 1},
  115. {"root_5", 5},
  116. {"term", 4},
  117. {"ws", 3},
  118. {"ws_12", 12},
  119. };
  120. index = 0;
  121. for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it)
  122. {
  123. std::string key = it->first;
  124. uint32_t value = it->second;
  125. std::pair<std::string, uint32_t> expected_pair = expected[index];
  126. // pretty print error message before asserting
  127. if (expected_pair.first != key || expected_pair.second != value)
  128. {
  129. fprintf(stderr, "expected_pair: %s, %d\n", expected_pair.first.c_str(), expected_pair.second);
  130. fprintf(stderr, "actual_pair: %s, %d\n", key.c_str(), value);
  131. fprintf(stderr, "expected_pair != actual_pair\n");
  132. }
  133. assert(expected_pair.first == key && expected_pair.second == value);
  134. index++;
  135. }
  136. expected_rules = {
  137. {LLAMA_GRETYPE_RULE_REF, 5},
  138. {LLAMA_GRETYPE_END, 0},
  139. {LLAMA_GRETYPE_RULE_REF, 2},
  140. {LLAMA_GRETYPE_CHAR, 61},
  141. {LLAMA_GRETYPE_RULE_REF, 3},
  142. {LLAMA_GRETYPE_RULE_REF, 4},
  143. {LLAMA_GRETYPE_CHAR, 10},
  144. {LLAMA_GRETYPE_END, 0},
  145. {LLAMA_GRETYPE_RULE_REF, 4},
  146. {LLAMA_GRETYPE_RULE_REF, 7},
  147. {LLAMA_GRETYPE_END, 0},
  148. {LLAMA_GRETYPE_RULE_REF, 12},
  149. {LLAMA_GRETYPE_END, 0},
  150. {LLAMA_GRETYPE_RULE_REF, 8},
  151. {LLAMA_GRETYPE_ALT, 0},
  152. {LLAMA_GRETYPE_RULE_REF, 9},
  153. {LLAMA_GRETYPE_ALT, 0},
  154. {LLAMA_GRETYPE_CHAR, 40},
  155. {LLAMA_GRETYPE_RULE_REF, 3},
  156. {LLAMA_GRETYPE_RULE_REF, 2},
  157. {LLAMA_GRETYPE_CHAR, 41},
  158. {LLAMA_GRETYPE_RULE_REF, 3},
  159. {LLAMA_GRETYPE_END, 0},
  160. {LLAMA_GRETYPE_RULE_REF, 1},
  161. {LLAMA_GRETYPE_RULE_REF, 5},
  162. {LLAMA_GRETYPE_ALT, 0},
  163. {LLAMA_GRETYPE_RULE_REF, 1},
  164. {LLAMA_GRETYPE_END, 0},
  165. {LLAMA_GRETYPE_CHAR, 45},
  166. {LLAMA_GRETYPE_CHAR_ALT, 43},
  167. {LLAMA_GRETYPE_CHAR_ALT, 42},
  168. {LLAMA_GRETYPE_CHAR_ALT, 47},
  169. {LLAMA_GRETYPE_RULE_REF, 4},
  170. {LLAMA_GRETYPE_END, 0},
  171. {LLAMA_GRETYPE_RULE_REF, 6},
  172. {LLAMA_GRETYPE_RULE_REF, 7},
  173. {LLAMA_GRETYPE_ALT, 0},
  174. {LLAMA_GRETYPE_END, 0},
  175. {LLAMA_GRETYPE_CHAR, 97},
  176. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
  177. {LLAMA_GRETYPE_RULE_REF, 10},
  178. {LLAMA_GRETYPE_RULE_REF, 3},
  179. {LLAMA_GRETYPE_END, 0},
  180. {LLAMA_GRETYPE_RULE_REF, 11},
  181. {LLAMA_GRETYPE_RULE_REF, 3},
  182. {LLAMA_GRETYPE_END, 0},
  183. {LLAMA_GRETYPE_CHAR, 97},
  184. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
  185. {LLAMA_GRETYPE_CHAR_ALT, 48},
  186. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
  187. {LLAMA_GRETYPE_CHAR_ALT, 95},
  188. {LLAMA_GRETYPE_RULE_REF, 10},
  189. {LLAMA_GRETYPE_ALT, 0},
  190. {LLAMA_GRETYPE_END, 0},
  191. {LLAMA_GRETYPE_CHAR, 48},
  192. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
  193. {LLAMA_GRETYPE_RULE_REF, 11},
  194. {LLAMA_GRETYPE_ALT, 0},
  195. {LLAMA_GRETYPE_CHAR, 48},
  196. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
  197. {LLAMA_GRETYPE_END, 0},
  198. {LLAMA_GRETYPE_CHAR, 32},
  199. {LLAMA_GRETYPE_CHAR_ALT, 9},
  200. {LLAMA_GRETYPE_CHAR_ALT, 10},
  201. {LLAMA_GRETYPE_RULE_REF, 12},
  202. {LLAMA_GRETYPE_ALT, 0},
  203. {LLAMA_GRETYPE_END, 0},
  204. };
  205. index = 0;
  206. for (auto rule : parsed_grammar.rules)
  207. {
  208. // compare rule to expected rule
  209. for (uint32_t i = 0; i < rule.size(); i++)
  210. {
  211. llama_grammar_element element = rule[i];
  212. llama_grammar_element expected_element = expected_rules[index];
  213. // pretty print error message before asserting
  214. if (expected_element.type != element.type || expected_element.value != element.value)
  215. {
  216. fprintf(stderr, "index: %d\n", index);
  217. fprintf(stderr, "expected_element: %d, %d\n", expected_element.type, expected_element.value);
  218. fprintf(stderr, "actual_element: %d, %d\n", element.type, element.value);
  219. fprintf(stderr, "expected_element != actual_element\n");
  220. }
  221. assert(expected_element.type == element.type && expected_element.value == element.value);
  222. index++;
  223. }
  224. }
  225. return 0;
  226. }