test-llama-grammar.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. #ifdef NDEBUG
  2. #undef NDEBUG
  3. #endif
  4. #include "llama.cpp" // TODO: not great
  5. #include "grammar-parser.h"
  6. #include <cassert>
  7. int main()
  8. {
  9. grammar_parser::parse_state parsed_grammar;
  10. std::vector<std::pair<std::string, uint32_t>> expected = {
  11. {"expr", 2},
  12. {"expr_6", 6},
  13. {"expr_7", 7},
  14. {"ident", 8},
  15. {"ident_10", 10},
  16. {"num", 9},
  17. {"num_11", 11},
  18. {"root", 0},
  19. {"root_1", 1},
  20. {"root_5", 5},
  21. {"term", 4},
  22. {"ws", 3},
  23. {"ws_12", 12},
  24. };
  25. std::vector<std::vector<llama_grammar_element>> expected_rules = {
  26. {{LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_END, 0}},
  27. {
  28. {LLAMA_GRETYPE_RULE_REF, 2},
  29. {LLAMA_GRETYPE_CHAR, 61},
  30. {LLAMA_GRETYPE_RULE_REF, 3},
  31. {LLAMA_GRETYPE_RULE_REF, 4},
  32. {LLAMA_GRETYPE_CHAR, 10},
  33. {LLAMA_GRETYPE_END, 0},
  34. },
  35. {{LLAMA_GRETYPE_RULE_REF, 4}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_END, 0}},
  36. {{LLAMA_GRETYPE_RULE_REF, 12}, {LLAMA_GRETYPE_END, 0}},
  37. {
  38. {LLAMA_GRETYPE_RULE_REF, 8},
  39. {LLAMA_GRETYPE_ALT, 0},
  40. {LLAMA_GRETYPE_RULE_REF, 9},
  41. {LLAMA_GRETYPE_ALT, 0},
  42. {LLAMA_GRETYPE_CHAR, 40},
  43. {LLAMA_GRETYPE_RULE_REF, 3},
  44. {LLAMA_GRETYPE_RULE_REF, 2},
  45. {LLAMA_GRETYPE_CHAR, 41},
  46. {LLAMA_GRETYPE_RULE_REF, 3},
  47. {LLAMA_GRETYPE_END, 0},
  48. },
  49. {{LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_END, 0}},
  50. {
  51. {LLAMA_GRETYPE_CHAR, 45},
  52. {LLAMA_GRETYPE_CHAR_ALT, 43},
  53. {LLAMA_GRETYPE_CHAR_ALT, 42},
  54. {LLAMA_GRETYPE_CHAR_ALT, 47},
  55. {LLAMA_GRETYPE_RULE_REF, 4},
  56. {LLAMA_GRETYPE_END, 0},
  57. },
  58. {{LLAMA_GRETYPE_RULE_REF, 6}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}},
  59. {
  60. {LLAMA_GRETYPE_CHAR, 97},
  61. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
  62. {LLAMA_GRETYPE_RULE_REF, 10},
  63. {LLAMA_GRETYPE_RULE_REF, 3},
  64. {LLAMA_GRETYPE_END, 0},
  65. },
  66. {{LLAMA_GRETYPE_RULE_REF, 11}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_END, 0}},
  67. {
  68. {LLAMA_GRETYPE_CHAR, 97},
  69. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
  70. {LLAMA_GRETYPE_CHAR_ALT, 48},
  71. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
  72. {LLAMA_GRETYPE_CHAR_ALT, 95},
  73. {LLAMA_GRETYPE_RULE_REF, 10},
  74. {LLAMA_GRETYPE_ALT, 0},
  75. {LLAMA_GRETYPE_END, 0},
  76. },
  77. {
  78. {LLAMA_GRETYPE_CHAR, 48},
  79. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
  80. {LLAMA_GRETYPE_RULE_REF, 11},
  81. {LLAMA_GRETYPE_ALT, 0},
  82. {LLAMA_GRETYPE_CHAR, 48},
  83. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
  84. {LLAMA_GRETYPE_END, 0},
  85. },
  86. {
  87. {LLAMA_GRETYPE_CHAR, 32},
  88. {LLAMA_GRETYPE_CHAR_ALT, 9},
  89. {LLAMA_GRETYPE_CHAR_ALT, 10},
  90. {LLAMA_GRETYPE_RULE_REF, 12},
  91. {LLAMA_GRETYPE_ALT, 0},
  92. {LLAMA_GRETYPE_END, 0},
  93. },
  94. };
  95. for (auto pair : expected)
  96. {
  97. parsed_grammar.symbol_ids[pair.first] = pair.second;
  98. }
  99. for (auto rule : expected_rules)
  100. {
  101. parsed_grammar.rules.emplace_back();
  102. for (auto element : rule)
  103. {
  104. parsed_grammar.rules.back().push_back(element);
  105. }
  106. }
  107. llama_grammar *grammar = NULL;
  108. std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
  109. grammar = llama_grammar_init(
  110. grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
  111. if (grammar == nullptr)
  112. {
  113. throw std::runtime_error("Failed to initialize llama_grammar");
  114. }
  115. std::vector<std::vector<llama_grammar_element>> expected_stacks = {
  116. {
  117. {LLAMA_GRETYPE_RULE_REF, 5},
  118. {LLAMA_GRETYPE_CHAR, 61},
  119. {LLAMA_GRETYPE_RULE_REF, 7},
  120. {LLAMA_GRETYPE_CHAR, 97},
  121. },
  122. {
  123. {LLAMA_GRETYPE_RULE_REF, 5},
  124. {LLAMA_GRETYPE_CHAR, 61},
  125. {LLAMA_GRETYPE_RULE_REF, 7},
  126. {LLAMA_GRETYPE_RULE_REF, 3},
  127. {LLAMA_GRETYPE_CHAR, 48},
  128. },
  129. {
  130. {LLAMA_GRETYPE_RULE_REF, 5},
  131. {LLAMA_GRETYPE_CHAR, 61},
  132. {LLAMA_GRETYPE_RULE_REF, 7},
  133. {LLAMA_GRETYPE_RULE_REF, 3},
  134. {LLAMA_GRETYPE_CHAR, 48},
  135. },
  136. {
  137. {LLAMA_GRETYPE_RULE_REF, 5},
  138. {LLAMA_GRETYPE_CHAR, 61},
  139. {LLAMA_GRETYPE_RULE_REF, 7},
  140. {LLAMA_GRETYPE_CHAR, 40},
  141. },
  142. {
  143. {LLAMA_GRETYPE_CHAR, 61},
  144. {LLAMA_GRETYPE_RULE_REF, 7},
  145. {LLAMA_GRETYPE_CHAR, 97},
  146. },
  147. {
  148. {LLAMA_GRETYPE_CHAR, 61},
  149. {LLAMA_GRETYPE_RULE_REF, 7},
  150. {LLAMA_GRETYPE_RULE_REF, 3},
  151. {LLAMA_GRETYPE_CHAR, 48},
  152. },
  153. {
  154. {LLAMA_GRETYPE_CHAR, 61},
  155. {LLAMA_GRETYPE_RULE_REF, 7},
  156. {LLAMA_GRETYPE_RULE_REF, 3},
  157. {LLAMA_GRETYPE_CHAR, 48},
  158. },
  159. {
  160. {LLAMA_GRETYPE_CHAR, 61},
  161. {LLAMA_GRETYPE_RULE_REF, 7},
  162. {LLAMA_GRETYPE_CHAR, 40},
  163. }};
  164. auto index = 0;
  165. for (auto stack : grammar->stacks)
  166. {
  167. // compare stack to expected_stack
  168. for (uint32_t i = 0; i < stack.size(); i++)
  169. {
  170. auto element = stack[i];
  171. auto expected_element = expected_stacks[index][i];
  172. // pretty print error message before asserting
  173. if (expected_element.type != element->type || expected_element.value != element->value)
  174. {
  175. fprintf(stderr, "index: %d\n", index);
  176. fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value);
  177. fprintf(stderr, "actual_element: %d, %u\n", element->type, element->value);
  178. fprintf(stderr, "expected_element != actual_element\n");
  179. }
  180. assert(expected_element.type == element->type && expected_element.value == element->value);
  181. }
  182. index++;
  183. }
  184. std::vector<llama_grammar_candidate> next_candidates;
  185. next_candidates.resize(24);
  186. for (size_t i = 0; i < 24; ++i)
  187. {
  188. uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
  189. cp[0] = 37 + i;
  190. cp[1] = 0;
  191. next_candidates[i] = {i, cp, {}};
  192. }
  193. std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {
  194. {
  195. {0, 37},
  196. {1, 38},
  197. {2, 39},
  198. {3, 40},
  199. {4, 41},
  200. {5, 42},
  201. {6, 43},
  202. {7, 44},
  203. {8, 45},
  204. {9, 46},
  205. {10, 47},
  206. {11, 48},
  207. {12, 49},
  208. {13, 50},
  209. {14, 51},
  210. {15, 52},
  211. {16, 53},
  212. {17, 54},
  213. {18, 55},
  214. {19, 56},
  215. {20, 57},
  216. {21, 58},
  217. {22, 59},
  218. {23, 60},
  219. },
  220. {
  221. {0, 37},
  222. {1, 38},
  223. {2, 39},
  224. {3, 40},
  225. {4, 41},
  226. {5, 42},
  227. {6, 43},
  228. {7, 44},
  229. {8, 45},
  230. {9, 46},
  231. {10, 47},
  232. {21, 58},
  233. {22, 59},
  234. {23, 60},
  235. },
  236. {
  237. {0, 37},
  238. {1, 38},
  239. {2, 39},
  240. {3, 40},
  241. {4, 41},
  242. {5, 42},
  243. {6, 43},
  244. {7, 44},
  245. {8, 45},
  246. {9, 46},
  247. {10, 47},
  248. {21, 58},
  249. {22, 59},
  250. {23, 60},
  251. },
  252. {
  253. {0, 37},
  254. {1, 38},
  255. {2, 39},
  256. {4, 41},
  257. {5, 42},
  258. {6, 43},
  259. {7, 44},
  260. {8, 45},
  261. {9, 46},
  262. {10, 47},
  263. {11, 48},
  264. {12, 49},
  265. {13, 50},
  266. {14, 51},
  267. {15, 52},
  268. {16, 53},
  269. {17, 54},
  270. {18, 55},
  271. {19, 56},
  272. {20, 57},
  273. {21, 58},
  274. {22, 59},
  275. {23, 60},
  276. },
  277. {
  278. {0, 37},
  279. {1, 38},
  280. {2, 39},
  281. {3, 40},
  282. {4, 41},
  283. {5, 42},
  284. {6, 43},
  285. {7, 44},
  286. {8, 45},
  287. {9, 46},
  288. {10, 47},
  289. {11, 48},
  290. {12, 49},
  291. {13, 50},
  292. {14, 51},
  293. {15, 52},
  294. {16, 53},
  295. {17, 54},
  296. {18, 55},
  297. {19, 56},
  298. {20, 57},
  299. {21, 58},
  300. {22, 59},
  301. {23, 60},
  302. },
  303. {
  304. {0, 37},
  305. {1, 38},
  306. {2, 39},
  307. {3, 40},
  308. {4, 41},
  309. {5, 42},
  310. {6, 43},
  311. {7, 44},
  312. {8, 45},
  313. {9, 46},
  314. {10, 47},
  315. {21, 58},
  316. {22, 59},
  317. {23, 60},
  318. },
  319. {
  320. {0, 37},
  321. {1, 38},
  322. {2, 39},
  323. {3, 40},
  324. {4, 41},
  325. {5, 42},
  326. {6, 43},
  327. {7, 44},
  328. {8, 45},
  329. {9, 46},
  330. {10, 47},
  331. {21, 58},
  332. {22, 59},
  333. {23, 60},
  334. },
  335. {
  336. {0, 37},
  337. {1, 38},
  338. {2, 39},
  339. {4, 41},
  340. {5, 42},
  341. {6, 43},
  342. {7, 44},
  343. {8, 45},
  344. {9, 46},
  345. {10, 47},
  346. {11, 48},
  347. {12, 49},
  348. {13, 50},
  349. {14, 51},
  350. {15, 52},
  351. {16, 53},
  352. {17, 54},
  353. {18, 55},
  354. {19, 56},
  355. {20, 57},
  356. {21, 58},
  357. {22, 59},
  358. {23, 60},
  359. },
  360. };
  361. std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[0], next_candidates);
  362. std::vector<std::vector<llama_grammar_candidate>> all_rejects;
  363. for (std::size_t count = 0; count < grammar->stacks.size(); ++count)
  364. {
  365. rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[count], next_candidates);
  366. all_rejects.push_back(rejects);
  367. }
  368. index = 0;
  369. for (auto rej : all_rejects)
  370. {
  371. for (uint32_t i = 0; i < rej.size(); i++)
  372. {
  373. auto element = rej[i];
  374. auto expected_element = expected_reject[index][i];
  375. assert(element.index == expected_element.first && *element.code_points == expected_element.second);
  376. }
  377. index++;
  378. }
  379. for (auto &candidate : next_candidates)
  380. {
  381. delete[] candidate.code_points;
  382. candidate.code_points = nullptr;
  383. }
  384. delete grammar;
  385. return 0;
  386. }