// test-llama-grammar.cpp
// This test relies on assert(), so force it on even in release builds.
#ifdef NDEBUG
#undef NDEBUG
#endif
#include "llama.h"
#include "llama-grammar.h"
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
  8. int main()
  9. {
  10. llama_grammar_parser parsed_grammar;
  11. std::vector<std::pair<std::string, uint32_t>> expected = {
  12. {"expr", 2},
  13. {"expr_6", 6},
  14. {"expr_7", 7},
  15. {"ident", 8},
  16. {"ident_10", 10},
  17. {"num", 9},
  18. {"num_11", 11},
  19. {"root", 0},
  20. {"root_1", 1},
  21. {"root_5", 5},
  22. {"term", 4},
  23. {"ws", 3},
  24. {"ws_12", 12},
  25. };
  26. std::vector<std::vector<llama_grammar_element>> expected_rules = {
  27. {{LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_END, 0}},
  28. {
  29. {LLAMA_GRETYPE_RULE_REF, 2},
  30. {LLAMA_GRETYPE_CHAR, 61},
  31. {LLAMA_GRETYPE_RULE_REF, 3},
  32. {LLAMA_GRETYPE_RULE_REF, 4},
  33. {LLAMA_GRETYPE_CHAR, 10},
  34. {LLAMA_GRETYPE_END, 0},
  35. },
  36. {{LLAMA_GRETYPE_RULE_REF, 4}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_END, 0}},
  37. {{LLAMA_GRETYPE_RULE_REF, 12}, {LLAMA_GRETYPE_END, 0}},
  38. {
  39. {LLAMA_GRETYPE_RULE_REF, 8},
  40. {LLAMA_GRETYPE_ALT, 0},
  41. {LLAMA_GRETYPE_RULE_REF, 9},
  42. {LLAMA_GRETYPE_ALT, 0},
  43. {LLAMA_GRETYPE_CHAR, 40},
  44. {LLAMA_GRETYPE_RULE_REF, 3},
  45. {LLAMA_GRETYPE_RULE_REF, 2},
  46. {LLAMA_GRETYPE_CHAR, 41},
  47. {LLAMA_GRETYPE_RULE_REF, 3},
  48. {LLAMA_GRETYPE_END, 0},
  49. },
  50. {{LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_END, 0}},
  51. {
  52. {LLAMA_GRETYPE_CHAR, 45},
  53. {LLAMA_GRETYPE_CHAR_ALT, 43},
  54. {LLAMA_GRETYPE_CHAR_ALT, 42},
  55. {LLAMA_GRETYPE_CHAR_ALT, 47},
  56. {LLAMA_GRETYPE_RULE_REF, 4},
  57. {LLAMA_GRETYPE_END, 0},
  58. },
  59. {{LLAMA_GRETYPE_RULE_REF, 6}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}},
  60. {
  61. {LLAMA_GRETYPE_CHAR, 97},
  62. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
  63. {LLAMA_GRETYPE_RULE_REF, 10},
  64. {LLAMA_GRETYPE_RULE_REF, 3},
  65. {LLAMA_GRETYPE_END, 0},
  66. },
  67. {{LLAMA_GRETYPE_RULE_REF, 11}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_END, 0}},
  68. {
  69. {LLAMA_GRETYPE_CHAR, 97},
  70. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
  71. {LLAMA_GRETYPE_CHAR_ALT, 48},
  72. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
  73. {LLAMA_GRETYPE_CHAR_ALT, 95},
  74. {LLAMA_GRETYPE_RULE_REF, 10},
  75. {LLAMA_GRETYPE_ALT, 0},
  76. {LLAMA_GRETYPE_END, 0},
  77. },
  78. {
  79. {LLAMA_GRETYPE_CHAR, 48},
  80. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
  81. {LLAMA_GRETYPE_RULE_REF, 11},
  82. {LLAMA_GRETYPE_ALT, 0},
  83. {LLAMA_GRETYPE_CHAR, 48},
  84. {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
  85. {LLAMA_GRETYPE_END, 0},
  86. },
  87. {
  88. {LLAMA_GRETYPE_CHAR, 32},
  89. {LLAMA_GRETYPE_CHAR_ALT, 9},
  90. {LLAMA_GRETYPE_CHAR_ALT, 10},
  91. {LLAMA_GRETYPE_RULE_REF, 12},
  92. {LLAMA_GRETYPE_ALT, 0},
  93. {LLAMA_GRETYPE_END, 0},
  94. },
  95. };
  96. for (auto pair : expected)
  97. {
  98. parsed_grammar.symbol_ids[pair.first] = pair.second;
  99. }
  100. for (auto rule : expected_rules)
  101. {
  102. parsed_grammar.rules.emplace_back();
  103. for (auto element : rule)
  104. {
  105. parsed_grammar.rules.back().push_back(element);
  106. }
  107. }
  108. std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
  109. llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
  110. if (grammar == nullptr) {
  111. throw std::runtime_error("Failed to initialize llama_grammar");
  112. }
  113. std::vector<std::vector<llama_grammar_element>> expected_stacks = {
  114. {
  115. {LLAMA_GRETYPE_RULE_REF, 5},
  116. {LLAMA_GRETYPE_CHAR, 61},
  117. {LLAMA_GRETYPE_RULE_REF, 7},
  118. {LLAMA_GRETYPE_CHAR, 97},
  119. },
  120. {
  121. {LLAMA_GRETYPE_RULE_REF, 5},
  122. {LLAMA_GRETYPE_CHAR, 61},
  123. {LLAMA_GRETYPE_RULE_REF, 7},
  124. {LLAMA_GRETYPE_RULE_REF, 3},
  125. {LLAMA_GRETYPE_CHAR, 48},
  126. },
  127. {
  128. {LLAMA_GRETYPE_RULE_REF, 5},
  129. {LLAMA_GRETYPE_CHAR, 61},
  130. {LLAMA_GRETYPE_RULE_REF, 7},
  131. {LLAMA_GRETYPE_RULE_REF, 3},
  132. {LLAMA_GRETYPE_CHAR, 48},
  133. },
  134. {
  135. {LLAMA_GRETYPE_RULE_REF, 5},
  136. {LLAMA_GRETYPE_CHAR, 61},
  137. {LLAMA_GRETYPE_RULE_REF, 7},
  138. {LLAMA_GRETYPE_CHAR, 40},
  139. },
  140. {
  141. {LLAMA_GRETYPE_CHAR, 61},
  142. {LLAMA_GRETYPE_RULE_REF, 7},
  143. {LLAMA_GRETYPE_CHAR, 97},
  144. },
  145. {
  146. {LLAMA_GRETYPE_CHAR, 61},
  147. {LLAMA_GRETYPE_RULE_REF, 7},
  148. {LLAMA_GRETYPE_RULE_REF, 3},
  149. {LLAMA_GRETYPE_CHAR, 48},
  150. },
  151. {
  152. {LLAMA_GRETYPE_CHAR, 61},
  153. {LLAMA_GRETYPE_RULE_REF, 7},
  154. {LLAMA_GRETYPE_RULE_REF, 3},
  155. {LLAMA_GRETYPE_CHAR, 48},
  156. },
  157. {
  158. {LLAMA_GRETYPE_CHAR, 61},
  159. {LLAMA_GRETYPE_RULE_REF, 7},
  160. {LLAMA_GRETYPE_CHAR, 40},
  161. }};
  162. auto index = 0;
  163. for (const llama_grammar_stack & stack : llama_grammar_get_stacks(grammar))
  164. {
  165. // compare stack to expected_stack
  166. for (uint32_t i = 0; i < stack.size(); i++)
  167. {
  168. const llama_grammar_element * element = stack[i];
  169. const llama_grammar_element & expected_element = expected_stacks[index][i];
  170. // pretty print error message before asserting
  171. if (expected_element.type != element->type || expected_element.value != element->value)
  172. {
  173. fprintf(stderr, "index: %d\n", index);
  174. fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value);
  175. fprintf(stderr, "actual_element: %d, %u\n", element->type, element->value);
  176. fprintf(stderr, "expected_element != actual_element\n");
  177. }
  178. assert(expected_element.type == element->type && expected_element.value == element->value);
  179. }
  180. index++;
  181. }
  182. std::vector<llama_grammar_candidate> next_candidates;
  183. next_candidates.resize(24);
  184. for (size_t i = 0; i < 24; ++i)
  185. {
  186. uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
  187. cp[0] = 37 + i;
  188. cp[1] = 0;
  189. next_candidates[i] = {i, cp, {}};
  190. }
  191. std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {
  192. {
  193. {0, 37},
  194. {1, 38},
  195. {2, 39},
  196. {3, 40},
  197. {4, 41},
  198. {5, 42},
  199. {6, 43},
  200. {7, 44},
  201. {8, 45},
  202. {9, 46},
  203. {10, 47},
  204. {11, 48},
  205. {12, 49},
  206. {13, 50},
  207. {14, 51},
  208. {15, 52},
  209. {16, 53},
  210. {17, 54},
  211. {18, 55},
  212. {19, 56},
  213. {20, 57},
  214. {21, 58},
  215. {22, 59},
  216. {23, 60},
  217. },
  218. {
  219. {0, 37},
  220. {1, 38},
  221. {2, 39},
  222. {3, 40},
  223. {4, 41},
  224. {5, 42},
  225. {6, 43},
  226. {7, 44},
  227. {8, 45},
  228. {9, 46},
  229. {10, 47},
  230. {21, 58},
  231. {22, 59},
  232. {23, 60},
  233. },
  234. {
  235. {0, 37},
  236. {1, 38},
  237. {2, 39},
  238. {3, 40},
  239. {4, 41},
  240. {5, 42},
  241. {6, 43},
  242. {7, 44},
  243. {8, 45},
  244. {9, 46},
  245. {10, 47},
  246. {21, 58},
  247. {22, 59},
  248. {23, 60},
  249. },
  250. {
  251. {0, 37},
  252. {1, 38},
  253. {2, 39},
  254. {4, 41},
  255. {5, 42},
  256. {6, 43},
  257. {7, 44},
  258. {8, 45},
  259. {9, 46},
  260. {10, 47},
  261. {11, 48},
  262. {12, 49},
  263. {13, 50},
  264. {14, 51},
  265. {15, 52},
  266. {16, 53},
  267. {17, 54},
  268. {18, 55},
  269. {19, 56},
  270. {20, 57},
  271. {21, 58},
  272. {22, 59},
  273. {23, 60},
  274. },
  275. {
  276. {0, 37},
  277. {1, 38},
  278. {2, 39},
  279. {3, 40},
  280. {4, 41},
  281. {5, 42},
  282. {6, 43},
  283. {7, 44},
  284. {8, 45},
  285. {9, 46},
  286. {10, 47},
  287. {11, 48},
  288. {12, 49},
  289. {13, 50},
  290. {14, 51},
  291. {15, 52},
  292. {16, 53},
  293. {17, 54},
  294. {18, 55},
  295. {19, 56},
  296. {20, 57},
  297. {21, 58},
  298. {22, 59},
  299. {23, 60},
  300. },
  301. {
  302. {0, 37},
  303. {1, 38},
  304. {2, 39},
  305. {3, 40},
  306. {4, 41},
  307. {5, 42},
  308. {6, 43},
  309. {7, 44},
  310. {8, 45},
  311. {9, 46},
  312. {10, 47},
  313. {21, 58},
  314. {22, 59},
  315. {23, 60},
  316. },
  317. {
  318. {0, 37},
  319. {1, 38},
  320. {2, 39},
  321. {3, 40},
  322. {4, 41},
  323. {5, 42},
  324. {6, 43},
  325. {7, 44},
  326. {8, 45},
  327. {9, 46},
  328. {10, 47},
  329. {21, 58},
  330. {22, 59},
  331. {23, 60},
  332. },
  333. {
  334. {0, 37},
  335. {1, 38},
  336. {2, 39},
  337. {4, 41},
  338. {5, 42},
  339. {6, 43},
  340. {7, 44},
  341. {8, 45},
  342. {9, 46},
  343. {10, 47},
  344. {11, 48},
  345. {12, 49},
  346. {13, 50},
  347. {14, 51},
  348. {15, 52},
  349. {16, 53},
  350. {17, 54},
  351. {18, 55},
  352. {19, 56},
  353. {20, 57},
  354. {21, 58},
  355. {22, 59},
  356. {23, 60},
  357. },
  358. };
  359. std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[0], next_candidates);
  360. std::vector<std::vector<llama_grammar_candidate>> all_rejects;
  361. for (std::size_t count = 0; count < llama_grammar_get_stacks(grammar).size(); ++count)
  362. {
  363. rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[count], next_candidates);
  364. all_rejects.push_back(rejects);
  365. }
  366. index = 0;
  367. for (auto rej : all_rejects)
  368. {
  369. for (uint32_t i = 0; i < rej.size(); i++)
  370. {
  371. auto element = rej[i];
  372. auto expected_element = expected_reject[index][i];
  373. assert(element.index == expected_element.first && *element.code_points == expected_element.second);
  374. }
  375. index++;
  376. }
  377. for (auto &candidate : next_candidates)
  378. {
  379. delete[] candidate.code_points;
  380. candidate.code_points = nullptr;
  381. }
  382. llama_grammar_free_impl(grammar);
  383. return 0;
  384. }