test-llama-grammar.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. #ifdef NDEBUG
  2. #undef NDEBUG
  3. #endif
  4. #include "llama.cpp" // TODO: not great
  5. #include "grammar-parser.h"
  6. #include <cassert>
  7. int main()
  8. {
  9. grammar_parser::parse_state parsed_grammar;
  10. std::vector<std::pair<std::string, uint32_t>> expected = {
  11. {"expr", 2},
  12. {"expr_6", 6},
  13. {"expr_7", 7},
  14. {"ident", 8},
  15. {"ident_10", 10},
  16. {"num", 9},
  17. {"num_11", 11},
  18. {"root", 0},
  19. {"root_1", 1},
  20. {"root_5", 5},
  21. {"term", 4},
  22. {"ws", 3},
  23. {"ws_12", 12},
  24. };
  25. std::vector<std::vector<llama_grammar_element>> expected_rules = {
  26. {{LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_END, 0}},
  27. {
  28. {LLAMA_GRETYPE_RULE_REF, 2},
  29. {LLAMA_GRETYPE_CHAR, 61},
  30. {LLAMA_GRETYPE_RULE_REF, 3},
  31. {LLAMA_GRETYPE_RULE_REF, 4},
  32. {LLAMA_GRETYPE_CHAR, 10},
  33. {LLAMA_GRETYPE_END, 0},
  34. },
  35. {{LLAMA_GRETYPE_RULE_REF, 4}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_END, 0}},
  36. {{LLAMA_GRETYPE_RULE_REF, 12}, {LLAMA_GRETYPE_END, 0}},
  37. {
  38. {LLAMA_GRETYPE_RULE_REF, 8},
  39. {LLAMA_GRETYPE_ALT, 0},
  40. {LLAMA_GRETYPE_RULE_REF, 9},
  41. {LLAMA_GRETYPE_ALT, 0},
  42. {LLAMA_GRETYPE_CHAR, 40},
  43. {LLAMA_GRETYPE_RULE_REF, 3},
  44. {LLAMA_GRETYPE_RULE_REF, 2},
  45. {LLAMA_GRETYPE_CHAR, 41},
  46. {LLAMA_GRETYPE_RULE_REF, 3},
  47. {LLAMA_GRETYPE_END, 0},
  48. },
  49. {{LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_END, 0}},
  50. {
  51. {LLAMA_GRETYPE_CHAR, 45},
  52. {LLAMA_GRETYPE_CHAR_ALT, 43},
  53. {LLAMA_GRETYPE_CHAR_ALT, 42},
  54. {LLAMA_GRETYPE_CHAR_ALT, 47},
  55. {LLAMA_GRETYPE_RULE_REF, 4},
  56. {LLAMA_GRETYPE_END, 0},
  57. },
  58. {{LLAMA_GRETYPE_RULE_REF, 6}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}},
  59. {
  60. {LLAMA_GRETYPE_CHAR, 97},
  61. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
  62. {LLAMA_GRETYPE_RULE_REF, 10},
  63. {LLAMA_GRETYPE_RULE_REF, 3},
  64. {LLAMA_GRETYPE_END, 0},
  65. },
  66. {{LLAMA_GRETYPE_RULE_REF, 11}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_END, 0}},
  67. {
  68. {LLAMA_GRETYPE_CHAR, 97},
  69. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
  70. {LLAMA_GRETYPE_CHAR_ALT, 48},
  71. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
  72. {LLAMA_GRETYPE_CHAR_ALT, 95},
  73. {LLAMA_GRETYPE_RULE_REF, 10},
  74. {LLAMA_GRETYPE_ALT, 0},
  75. {LLAMA_GRETYPE_END, 0},
  76. },
  77. {
  78. {LLAMA_GRETYPE_CHAR, 48},
  79. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
  80. {LLAMA_GRETYPE_RULE_REF, 11},
  81. {LLAMA_GRETYPE_ALT, 0},
  82. {LLAMA_GRETYPE_CHAR, 48},
  83. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
  84. {LLAMA_GRETYPE_END, 0},
  85. },
  86. {
  87. {LLAMA_GRETYPE_CHAR, 32},
  88. {LLAMA_GRETYPE_CHAR_ALT, 9},
  89. {LLAMA_GRETYPE_CHAR_ALT, 10},
  90. {LLAMA_GRETYPE_RULE_REF, 12},
  91. {LLAMA_GRETYPE_ALT, 0},
  92. {LLAMA_GRETYPE_END, 0},
  93. },
  94. };
  95. for (auto pair : expected)
  96. {
  97. parsed_grammar.symbol_ids[pair.first] = pair.second;
  98. }
  99. for (auto rule : expected_rules)
  100. {
  101. parsed_grammar.rules.emplace_back();
  102. for (auto element : rule)
  103. {
  104. parsed_grammar.rules.back().push_back(element);
  105. }
  106. }
  107. llama_grammar *grammar = NULL;
  108. std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
  109. grammar = llama_grammar_init(
  110. grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
  111. std::vector<std::vector<llama_grammar_element>> expected_stacks = {
  112. {
  113. {LLAMA_GRETYPE_RULE_REF, 5},
  114. {LLAMA_GRETYPE_CHAR, 61},
  115. {LLAMA_GRETYPE_RULE_REF, 7},
  116. {LLAMA_GRETYPE_CHAR, 97},
  117. },
  118. {
  119. {LLAMA_GRETYPE_RULE_REF, 5},
  120. {LLAMA_GRETYPE_CHAR, 61},
  121. {LLAMA_GRETYPE_RULE_REF, 7},
  122. {LLAMA_GRETYPE_RULE_REF, 3},
  123. {LLAMA_GRETYPE_CHAR, 48},
  124. },
  125. {
  126. {LLAMA_GRETYPE_RULE_REF, 5},
  127. {LLAMA_GRETYPE_CHAR, 61},
  128. {LLAMA_GRETYPE_RULE_REF, 7},
  129. {LLAMA_GRETYPE_RULE_REF, 3},
  130. {LLAMA_GRETYPE_CHAR, 48},
  131. },
  132. {
  133. {LLAMA_GRETYPE_RULE_REF, 5},
  134. {LLAMA_GRETYPE_CHAR, 61},
  135. {LLAMA_GRETYPE_RULE_REF, 7},
  136. {LLAMA_GRETYPE_CHAR, 40},
  137. },
  138. {
  139. {LLAMA_GRETYPE_CHAR, 61},
  140. {LLAMA_GRETYPE_RULE_REF, 7},
  141. {LLAMA_GRETYPE_CHAR, 97},
  142. },
  143. {
  144. {LLAMA_GRETYPE_CHAR, 61},
  145. {LLAMA_GRETYPE_RULE_REF, 7},
  146. {LLAMA_GRETYPE_RULE_REF, 3},
  147. {LLAMA_GRETYPE_CHAR, 48},
  148. },
  149. {
  150. {LLAMA_GRETYPE_CHAR, 61},
  151. {LLAMA_GRETYPE_RULE_REF, 7},
  152. {LLAMA_GRETYPE_RULE_REF, 3},
  153. {LLAMA_GRETYPE_CHAR, 48},
  154. },
  155. {
  156. {LLAMA_GRETYPE_CHAR, 61},
  157. {LLAMA_GRETYPE_RULE_REF, 7},
  158. {LLAMA_GRETYPE_CHAR, 40},
  159. }};
  160. auto index = 0;
  161. for (auto stack : grammar->stacks)
  162. {
  163. // compare stack to expected_stack
  164. for (uint32_t i = 0; i < stack.size(); i++)
  165. {
  166. auto element = stack[i];
  167. auto expected_element = expected_stacks[index][i];
  168. // pretty print error message before asserting
  169. if (expected_element.type != element->type || expected_element.value != element->value)
  170. {
  171. fprintf(stderr, "index: %d\n", index);
  172. fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value);
  173. fprintf(stderr, "actual_element: %d, %u\n", element->type, element->value);
  174. fprintf(stderr, "expected_element != actual_element\n");
  175. }
  176. assert(expected_element.type == element->type && expected_element.value == element->value);
  177. }
  178. index++;
  179. }
  180. std::vector<llama_grammar_candidate> next_candidates;
  181. next_candidates.resize(24);
  182. for (size_t i = 0; i < 24; ++i)
  183. {
  184. uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
  185. cp[0] = 37 + i;
  186. cp[1] = 0;
  187. next_candidates[i] = {i, cp, {}};
  188. }
  189. std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {
  190. {
  191. {0, 37},
  192. {1, 38},
  193. {2, 39},
  194. {3, 40},
  195. {4, 41},
  196. {5, 42},
  197. {6, 43},
  198. {7, 44},
  199. {8, 45},
  200. {9, 46},
  201. {10, 47},
  202. {11, 48},
  203. {12, 49},
  204. {13, 50},
  205. {14, 51},
  206. {15, 52},
  207. {16, 53},
  208. {17, 54},
  209. {18, 55},
  210. {19, 56},
  211. {20, 57},
  212. {21, 58},
  213. {22, 59},
  214. {23, 60},
  215. },
  216. {
  217. {0, 37},
  218. {1, 38},
  219. {2, 39},
  220. {3, 40},
  221. {4, 41},
  222. {5, 42},
  223. {6, 43},
  224. {7, 44},
  225. {8, 45},
  226. {9, 46},
  227. {10, 47},
  228. {21, 58},
  229. {22, 59},
  230. {23, 60},
  231. },
  232. {
  233. {0, 37},
  234. {1, 38},
  235. {2, 39},
  236. {3, 40},
  237. {4, 41},
  238. {5, 42},
  239. {6, 43},
  240. {7, 44},
  241. {8, 45},
  242. {9, 46},
  243. {10, 47},
  244. {21, 58},
  245. {22, 59},
  246. {23, 60},
  247. },
  248. {
  249. {0, 37},
  250. {1, 38},
  251. {2, 39},
  252. {4, 41},
  253. {5, 42},
  254. {6, 43},
  255. {7, 44},
  256. {8, 45},
  257. {9, 46},
  258. {10, 47},
  259. {11, 48},
  260. {12, 49},
  261. {13, 50},
  262. {14, 51},
  263. {15, 52},
  264. {16, 53},
  265. {17, 54},
  266. {18, 55},
  267. {19, 56},
  268. {20, 57},
  269. {21, 58},
  270. {22, 59},
  271. {23, 60},
  272. },
  273. {
  274. {0, 37},
  275. {1, 38},
  276. {2, 39},
  277. {3, 40},
  278. {4, 41},
  279. {5, 42},
  280. {6, 43},
  281. {7, 44},
  282. {8, 45},
  283. {9, 46},
  284. {10, 47},
  285. {11, 48},
  286. {12, 49},
  287. {13, 50},
  288. {14, 51},
  289. {15, 52},
  290. {16, 53},
  291. {17, 54},
  292. {18, 55},
  293. {19, 56},
  294. {20, 57},
  295. {21, 58},
  296. {22, 59},
  297. {23, 60},
  298. },
  299. {
  300. {0, 37},
  301. {1, 38},
  302. {2, 39},
  303. {3, 40},
  304. {4, 41},
  305. {5, 42},
  306. {6, 43},
  307. {7, 44},
  308. {8, 45},
  309. {9, 46},
  310. {10, 47},
  311. {21, 58},
  312. {22, 59},
  313. {23, 60},
  314. },
  315. {
  316. {0, 37},
  317. {1, 38},
  318. {2, 39},
  319. {3, 40},
  320. {4, 41},
  321. {5, 42},
  322. {6, 43},
  323. {7, 44},
  324. {8, 45},
  325. {9, 46},
  326. {10, 47},
  327. {21, 58},
  328. {22, 59},
  329. {23, 60},
  330. },
  331. {
  332. {0, 37},
  333. {1, 38},
  334. {2, 39},
  335. {4, 41},
  336. {5, 42},
  337. {6, 43},
  338. {7, 44},
  339. {8, 45},
  340. {9, 46},
  341. {10, 47},
  342. {11, 48},
  343. {12, 49},
  344. {13, 50},
  345. {14, 51},
  346. {15, 52},
  347. {16, 53},
  348. {17, 54},
  349. {18, 55},
  350. {19, 56},
  351. {20, 57},
  352. {21, 58},
  353. {22, 59},
  354. {23, 60},
  355. },
  356. };
  357. std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[0], next_candidates);
  358. std::vector<std::vector<llama_grammar_candidate>> all_rejects;
  359. for (std::size_t count = 0; count < grammar->stacks.size(); ++count)
  360. {
  361. rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[count], next_candidates);
  362. all_rejects.push_back(rejects);
  363. }
  364. index = 0;
  365. for (auto rej : all_rejects)
  366. {
  367. for (uint32_t i = 0; i < rej.size(); i++)
  368. {
  369. auto element = rej[i];
  370. auto expected_element = expected_reject[index][i];
  371. assert(element.index == expected_element.first && *element.code_points == expected_element.second);
  372. }
  373. index++;
  374. }
  375. for (auto &candidate : next_candidates)
  376. {
  377. delete[] candidate.code_points;
  378. candidate.code_points = nullptr;
  379. }
  380. delete grammar;
  381. return 0;
  382. }