// test-llama-grammar.cpp — checks llama grammar stack initialisation and per-stack candidate rejection.
  // Undefine NDEBUG so the assert() checks below stay active in release builds.
  #ifdef NDEBUG
  #undef NDEBUG
  #endif

  // NOTE(review): presumably exposes the internal grammar helpers declared in llama.h — confirm.
  #define LLAMA_API_INTERNAL
  #include "llama.h"
  #include "grammar-parser.h"

  #include <cassert>
  #include <cstddef>
  #include <cstdint>
  #include <cstdio>
  #include <stdexcept>
  #include <string>
  #include <utility>
  #include <vector>
  9. int main()
  10. {
  11. grammar_parser::parse_state parsed_grammar;
  12. std::vector<std::pair<std::string, uint32_t>> expected = {
  13. {"expr", 2},
  14. {"expr_6", 6},
  15. {"expr_7", 7},
  16. {"ident", 8},
  17. {"ident_10", 10},
  18. {"num", 9},
  19. {"num_11", 11},
  20. {"root", 0},
  21. {"root_1", 1},
  22. {"root_5", 5},
  23. {"term", 4},
  24. {"ws", 3},
  25. {"ws_12", 12},
  26. };
  27. std::vector<std::vector<llama_grammar_element>> expected_rules = {
  28. {{LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_END, 0}},
  29. {
  30. {LLAMA_GRETYPE_RULE_REF, 2},
  31. {LLAMA_GRETYPE_CHAR, 61},
  32. {LLAMA_GRETYPE_RULE_REF, 3},
  33. {LLAMA_GRETYPE_RULE_REF, 4},
  34. {LLAMA_GRETYPE_CHAR, 10},
  35. {LLAMA_GRETYPE_END, 0},
  36. },
  37. {{LLAMA_GRETYPE_RULE_REF, 4}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_END, 0}},
  38. {{LLAMA_GRETYPE_RULE_REF, 12}, {LLAMA_GRETYPE_END, 0}},
  39. {
  40. {LLAMA_GRETYPE_RULE_REF, 8},
  41. {LLAMA_GRETYPE_ALT, 0},
  42. {LLAMA_GRETYPE_RULE_REF, 9},
  43. {LLAMA_GRETYPE_ALT, 0},
  44. {LLAMA_GRETYPE_CHAR, 40},
  45. {LLAMA_GRETYPE_RULE_REF, 3},
  46. {LLAMA_GRETYPE_RULE_REF, 2},
  47. {LLAMA_GRETYPE_CHAR, 41},
  48. {LLAMA_GRETYPE_RULE_REF, 3},
  49. {LLAMA_GRETYPE_END, 0},
  50. },
  51. {{LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_END, 0}},
  52. {
  53. {LLAMA_GRETYPE_CHAR, 45},
  54. {LLAMA_GRETYPE_CHAR_ALT, 43},
  55. {LLAMA_GRETYPE_CHAR_ALT, 42},
  56. {LLAMA_GRETYPE_CHAR_ALT, 47},
  57. {LLAMA_GRETYPE_RULE_REF, 4},
  58. {LLAMA_GRETYPE_END, 0},
  59. },
  60. {{LLAMA_GRETYPE_RULE_REF, 6}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}},
  61. {
  62. {LLAMA_GRETYPE_CHAR, 97},
  63. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
  64. {LLAMA_GRETYPE_RULE_REF, 10},
  65. {LLAMA_GRETYPE_RULE_REF, 3},
  66. {LLAMA_GRETYPE_END, 0},
  67. },
  68. {{LLAMA_GRETYPE_RULE_REF, 11}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_END, 0}},
  69. {
  70. {LLAMA_GRETYPE_CHAR, 97},
  71. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
  72. {LLAMA_GRETYPE_CHAR_ALT, 48},
  73. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
  74. {LLAMA_GRETYPE_CHAR_ALT, 95},
  75. {LLAMA_GRETYPE_RULE_REF, 10},
  76. {LLAMA_GRETYPE_ALT, 0},
  77. {LLAMA_GRETYPE_END, 0},
  78. },
  79. {
  80. {LLAMA_GRETYPE_CHAR, 48},
  81. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
  82. {LLAMA_GRETYPE_RULE_REF, 11},
  83. {LLAMA_GRETYPE_ALT, 0},
  84. {LLAMA_GRETYPE_CHAR, 48},
  85. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
  86. {LLAMA_GRETYPE_END, 0},
  87. },
  88. {
  89. {LLAMA_GRETYPE_CHAR, 32},
  90. {LLAMA_GRETYPE_CHAR_ALT, 9},
  91. {LLAMA_GRETYPE_CHAR_ALT, 10},
  92. {LLAMA_GRETYPE_RULE_REF, 12},
  93. {LLAMA_GRETYPE_ALT, 0},
  94. {LLAMA_GRETYPE_END, 0},
  95. },
  96. };
  97. for (auto pair : expected)
  98. {
  99. parsed_grammar.symbol_ids[pair.first] = pair.second;
  100. }
  101. for (auto rule : expected_rules)
  102. {
  103. parsed_grammar.rules.emplace_back();
  104. for (auto element : rule)
  105. {
  106. parsed_grammar.rules.back().push_back(element);
  107. }
  108. }
  109. llama_grammar * grammar = NULL;
  110. std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
  111. grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
  112. if (grammar == nullptr)
  113. {
  114. throw std::runtime_error("Failed to initialize llama_grammar");
  115. }
  116. std::vector<std::vector<llama_grammar_element>> expected_stacks = {
  117. {
  118. {LLAMA_GRETYPE_RULE_REF, 5},
  119. {LLAMA_GRETYPE_CHAR, 61},
  120. {LLAMA_GRETYPE_RULE_REF, 7},
  121. {LLAMA_GRETYPE_CHAR, 97},
  122. },
  123. {
  124. {LLAMA_GRETYPE_RULE_REF, 5},
  125. {LLAMA_GRETYPE_CHAR, 61},
  126. {LLAMA_GRETYPE_RULE_REF, 7},
  127. {LLAMA_GRETYPE_RULE_REF, 3},
  128. {LLAMA_GRETYPE_CHAR, 48},
  129. },
  130. {
  131. {LLAMA_GRETYPE_RULE_REF, 5},
  132. {LLAMA_GRETYPE_CHAR, 61},
  133. {LLAMA_GRETYPE_RULE_REF, 7},
  134. {LLAMA_GRETYPE_RULE_REF, 3},
  135. {LLAMA_GRETYPE_CHAR, 48},
  136. },
  137. {
  138. {LLAMA_GRETYPE_RULE_REF, 5},
  139. {LLAMA_GRETYPE_CHAR, 61},
  140. {LLAMA_GRETYPE_RULE_REF, 7},
  141. {LLAMA_GRETYPE_CHAR, 40},
  142. },
  143. {
  144. {LLAMA_GRETYPE_CHAR, 61},
  145. {LLAMA_GRETYPE_RULE_REF, 7},
  146. {LLAMA_GRETYPE_CHAR, 97},
  147. },
  148. {
  149. {LLAMA_GRETYPE_CHAR, 61},
  150. {LLAMA_GRETYPE_RULE_REF, 7},
  151. {LLAMA_GRETYPE_RULE_REF, 3},
  152. {LLAMA_GRETYPE_CHAR, 48},
  153. },
  154. {
  155. {LLAMA_GRETYPE_CHAR, 61},
  156. {LLAMA_GRETYPE_RULE_REF, 7},
  157. {LLAMA_GRETYPE_RULE_REF, 3},
  158. {LLAMA_GRETYPE_CHAR, 48},
  159. },
  160. {
  161. {LLAMA_GRETYPE_CHAR, 61},
  162. {LLAMA_GRETYPE_RULE_REF, 7},
  163. {LLAMA_GRETYPE_CHAR, 40},
  164. }};
  165. auto index = 0;
  166. for (auto stack : llama_grammar_get_stacks(grammar))
  167. {
  168. // compare stack to expected_stack
  169. for (uint32_t i = 0; i < stack.size(); i++)
  170. {
  171. auto element = stack[i];
  172. auto expected_element = expected_stacks[index][i];
  173. // pretty print error message before asserting
  174. if (expected_element.type != element->type || expected_element.value != element->value)
  175. {
  176. fprintf(stderr, "index: %d\n", index);
  177. fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value);
  178. fprintf(stderr, "actual_element: %d, %u\n", element->type, element->value);
  179. fprintf(stderr, "expected_element != actual_element\n");
  180. }
  181. assert(expected_element.type == element->type && expected_element.value == element->value);
  182. }
  183. index++;
  184. }
  185. std::vector<llama_grammar_candidate> next_candidates;
  186. next_candidates.resize(24);
  187. for (size_t i = 0; i < 24; ++i)
  188. {
  189. uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
  190. cp[0] = 37 + i;
  191. cp[1] = 0;
  192. next_candidates[i] = {i, cp, {}};
  193. }
  194. std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {
  195. {
  196. {0, 37},
  197. {1, 38},
  198. {2, 39},
  199. {3, 40},
  200. {4, 41},
  201. {5, 42},
  202. {6, 43},
  203. {7, 44},
  204. {8, 45},
  205. {9, 46},
  206. {10, 47},
  207. {11, 48},
  208. {12, 49},
  209. {13, 50},
  210. {14, 51},
  211. {15, 52},
  212. {16, 53},
  213. {17, 54},
  214. {18, 55},
  215. {19, 56},
  216. {20, 57},
  217. {21, 58},
  218. {22, 59},
  219. {23, 60},
  220. },
  221. {
  222. {0, 37},
  223. {1, 38},
  224. {2, 39},
  225. {3, 40},
  226. {4, 41},
  227. {5, 42},
  228. {6, 43},
  229. {7, 44},
  230. {8, 45},
  231. {9, 46},
  232. {10, 47},
  233. {21, 58},
  234. {22, 59},
  235. {23, 60},
  236. },
  237. {
  238. {0, 37},
  239. {1, 38},
  240. {2, 39},
  241. {3, 40},
  242. {4, 41},
  243. {5, 42},
  244. {6, 43},
  245. {7, 44},
  246. {8, 45},
  247. {9, 46},
  248. {10, 47},
  249. {21, 58},
  250. {22, 59},
  251. {23, 60},
  252. },
  253. {
  254. {0, 37},
  255. {1, 38},
  256. {2, 39},
  257. {4, 41},
  258. {5, 42},
  259. {6, 43},
  260. {7, 44},
  261. {8, 45},
  262. {9, 46},
  263. {10, 47},
  264. {11, 48},
  265. {12, 49},
  266. {13, 50},
  267. {14, 51},
  268. {15, 52},
  269. {16, 53},
  270. {17, 54},
  271. {18, 55},
  272. {19, 56},
  273. {20, 57},
  274. {21, 58},
  275. {22, 59},
  276. {23, 60},
  277. },
  278. {
  279. {0, 37},
  280. {1, 38},
  281. {2, 39},
  282. {3, 40},
  283. {4, 41},
  284. {5, 42},
  285. {6, 43},
  286. {7, 44},
  287. {8, 45},
  288. {9, 46},
  289. {10, 47},
  290. {11, 48},
  291. {12, 49},
  292. {13, 50},
  293. {14, 51},
  294. {15, 52},
  295. {16, 53},
  296. {17, 54},
  297. {18, 55},
  298. {19, 56},
  299. {20, 57},
  300. {21, 58},
  301. {22, 59},
  302. {23, 60},
  303. },
  304. {
  305. {0, 37},
  306. {1, 38},
  307. {2, 39},
  308. {3, 40},
  309. {4, 41},
  310. {5, 42},
  311. {6, 43},
  312. {7, 44},
  313. {8, 45},
  314. {9, 46},
  315. {10, 47},
  316. {21, 58},
  317. {22, 59},
  318. {23, 60},
  319. },
  320. {
  321. {0, 37},
  322. {1, 38},
  323. {2, 39},
  324. {3, 40},
  325. {4, 41},
  326. {5, 42},
  327. {6, 43},
  328. {7, 44},
  329. {8, 45},
  330. {9, 46},
  331. {10, 47},
  332. {21, 58},
  333. {22, 59},
  334. {23, 60},
  335. },
  336. {
  337. {0, 37},
  338. {1, 38},
  339. {2, 39},
  340. {4, 41},
  341. {5, 42},
  342. {6, 43},
  343. {7, 44},
  344. {8, 45},
  345. {9, 46},
  346. {10, 47},
  347. {11, 48},
  348. {12, 49},
  349. {13, 50},
  350. {14, 51},
  351. {15, 52},
  352. {16, 53},
  353. {17, 54},
  354. {18, 55},
  355. {19, 56},
  356. {20, 57},
  357. {21, 58},
  358. {22, 59},
  359. {23, 60},
  360. },
  361. };
  362. std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[0], next_candidates);
  363. std::vector<std::vector<llama_grammar_candidate>> all_rejects;
  364. for (std::size_t count = 0; count < llama_grammar_get_stacks(grammar).size(); ++count)
  365. {
  366. rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[count], next_candidates);
  367. all_rejects.push_back(rejects);
  368. }
  369. index = 0;
  370. for (auto rej : all_rejects)
  371. {
  372. for (uint32_t i = 0; i < rej.size(); i++)
  373. {
  374. auto element = rej[i];
  375. auto expected_element = expected_reject[index][i];
  376. assert(element.index == expected_element.first && *element.code_points == expected_element.second);
  377. }
  378. index++;
  379. }
  380. for (auto &candidate : next_candidates)
  381. {
  382. delete[] candidate.code_points;
  383. candidate.code_points = nullptr;
  384. }
  385. llama_grammar_free(grammar);
  386. return 0;
  387. }