test-grammar-integration.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. #ifdef NDEBUG
  2. #undef NDEBUG
  3. #endif
  4. #define LLAMA_API_INTERNAL
  5. #include "ggml.h"
  6. #include "llama.h"
  7. #include "grammar-parser.h"
  8. #include "unicode.h"
  9. #include <cassert>
  10. #include <string>
  11. #include <vector>
  12. static llama_grammar* build_grammar(const std::string & grammar_str) {
  13. auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
  14. // Ensure we parsed correctly
  15. assert(!parsed_grammar.rules.empty());
  16. // Ensure we have a root node
  17. assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
  18. std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
  19. llama_grammar* grammar = llama_grammar_init(
  20. grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
  21. return grammar;
  22. }
  23. static bool test_build_grammar_fails(const std::string & grammar_str) {
  24. fprintf(stderr, "⚫ Testing failure for grammar: %s\n", grammar_str.c_str());
  25. bool grammar_fails = false;
  26. try {
  27. build_grammar(grammar_str);
  28. fprintf(stderr, " ❌ Expected build failure, but succeeded\n");
  29. } catch (const std::exception & err) {
  30. grammar_fails = true;
  31. fprintf(stdout, " ✅︎\n");
  32. }
  33. return grammar_fails;
  34. }
  35. static bool match_string(const std::string & input, llama_grammar* grammar) {
  36. auto decoded = decode_utf8(input, {});
  37. const auto & code_points = decoded.first;
  38. for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
  39. auto prev_stacks = grammar->stacks;
  40. llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
  41. if (grammar->stacks.empty()) {
  42. // no stacks means that the grammar failed to match at this point
  43. return false;
  44. }
  45. }
  46. for (const auto & stack : grammar->stacks) {
  47. if (stack.empty()) {
  48. // An empty stack means that the grammar has been completed
  49. return true;
  50. }
  51. }
  52. return false;
  53. }
  54. static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
  55. fprintf(stderr, "⚫ Testing %s. Grammar: %s\n", test_desc.c_str(), grammar_str.c_str());
  56. fflush(stderr);
  57. auto grammar = build_grammar(grammar_str);
  58. // Save the original grammar stacks so that we can reset after every new string we want to test
  59. auto original_stacks = grammar->stacks;
  60. fprintf(stderr, " 🔵 Valid strings:\n");
  61. // Passing strings
  62. for (const auto & test_string : passing_strings) {
  63. fprintf(stderr, " \"%s\" ", test_string.c_str());
  64. fflush(stderr);
  65. bool matched = match_string(test_string, grammar);
  66. if (!matched) {
  67. fprintf(stderr, "❌ (failed to match)\n");
  68. } else {
  69. fprintf(stdout, "✅︎\n");
  70. }
  71. assert(matched);
  72. // Reset the grammar stacks
  73. grammar->stacks = original_stacks;
  74. }
  75. fprintf(stderr, " 🟠 Invalid strings:\n");
  76. // Failing strings
  77. for (const auto & test_string : failing_strings) {
  78. fprintf(stderr, " \"%s\" ", test_string.c_str());
  79. fflush(stderr);
  80. bool matched = match_string(test_string, grammar);
  81. if (matched) {
  82. fprintf(stderr, "❌ (incorrectly matched)\n");
  83. } else {
  84. fprintf(stdout, "✅︎\n");
  85. }
  86. assert(!matched);
  87. // Reset the grammar stacks
  88. grammar->stacks = original_stacks;
  89. }
  90. // Clean up allocated memory
  91. llama_grammar_free(grammar);
  92. }
  93. static void test_simple_grammar() {
  94. // Test case for a simple grammar
  95. test_grammar(
  96. "simple grammar",
  97. R"""(
  98. root ::= expr
  99. expr ::= term ("+" term)*
  100. term ::= number
  101. number ::= [0-9]+)""",
  102. // Passing strings
  103. {
  104. "42",
  105. "1+2+3+4+5",
  106. "123+456",
  107. },
  108. // Failing strings
  109. {
  110. "+",
  111. "/ 3",
  112. "1+2+3+4+5+",
  113. "12a45",
  114. }
  115. );
  116. }
  117. static void test_complex_grammar() {
  118. // Test case for a more complex grammar, with both failure strings and success strings
  119. test_grammar(
  120. "medium complexity grammar",
  121. // Grammar
  122. R"""(
  123. root ::= expression
  124. expression ::= term ws (("+"|"-") ws term)*
  125. term ::= factor ws (("*"|"/") ws factor)*
  126. factor ::= number | variable | "(" expression ")" | function-call
  127. number ::= [0-9]+
  128. variable ::= [a-zA-Z_][a-zA-Z0-9_]*
  129. function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
  130. ws ::= [ \t\n\r]?)""",
  131. // Passing strings
  132. {
  133. "42",
  134. "1*2*3*4*5",
  135. "x",
  136. "x+10",
  137. "x1+y2",
  138. "(a+b)*(c-d)",
  139. "func()",
  140. "func(x,y+2)",
  141. "a*(b+c)-d/e",
  142. "f(g(x),h(y,z))",
  143. "x + 10",
  144. "x1 + y2",
  145. "(a + b) * (c - d)",
  146. "func()",
  147. "func(x, y + 2)",
  148. "a * (b + c) - d / e",
  149. "f(g(x), h(y, z))",
  150. "123+456",
  151. "123*456*789-123/456+789*123",
  152. "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
  153. },
  154. // Failing strings
  155. {
  156. "+",
  157. "/ 3x",
  158. "x + + y",
  159. "a * / b",
  160. "func(,)",
  161. "func(x y)",
  162. "(a + b",
  163. "x + y)",
  164. "a + b * (c - d",
  165. "42 +",
  166. "x +",
  167. "x + 10 +",
  168. "(a + b) * (c - d",
  169. "func(",
  170. "func(x, y + 2",
  171. "a * (b + c) - d /",
  172. "f(g(x), h(y, z)",
  173. "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
  174. }
  175. );
  176. }
  177. static void test_quantifiers() {
  178. // A collection of tests to exercise * + and ? quantifiers
  179. test_grammar(
  180. "* quantifier",
  181. // Grammar
  182. R"""(root ::= "a"*)""",
  183. // Passing strings
  184. {
  185. "",
  186. "a",
  187. "aaaaa",
  188. "aaaaaaaaaaaaaaaaaa",
  189. "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
  190. },
  191. // Failing strings
  192. {
  193. "b",
  194. "ab",
  195. "aab",
  196. "ba",
  197. "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
  198. }
  199. );
  200. test_grammar(
  201. "+ quantifier",
  202. // Grammar
  203. R"""(root ::= "a"+)""",
  204. // Passing strings
  205. {
  206. "a",
  207. "aaaaa",
  208. "aaaaaaaaaaaaaaaaaa",
  209. "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
  210. },
  211. // Failing strings
  212. {
  213. "",
  214. "b",
  215. "ab",
  216. "aab",
  217. "ba",
  218. "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
  219. }
  220. );
  221. test_grammar(
  222. "? quantifier",
  223. // Grammar
  224. R"""(root ::= "a"?)""",
  225. // Passing strings
  226. {
  227. "",
  228. "a"
  229. },
  230. // Failing strings
  231. {
  232. "b",
  233. "ab",
  234. "aa",
  235. "ba",
  236. }
  237. );
  238. test_grammar(
  239. "mixed quantifiers",
  240. // Grammar
  241. R"""(
  242. root ::= cons+ vowel* cons? (vowel cons)*
  243. vowel ::= [aeiouy]
  244. cons ::= [bcdfghjklmnpqrstvwxyz]
  245. )""",
  246. // Passing strings
  247. {
  248. "yes",
  249. "no",
  250. "noyes",
  251. "crwth",
  252. "four",
  253. "bryyyy",
  254. },
  255. // Failing strings
  256. {
  257. "yess",
  258. "yesno",
  259. "forty",
  260. "catyyy",
  261. }
  262. );
  263. test_grammar(
  264. "simple exact repetition",
  265. // Grammar
  266. R"""(
  267. root ::= [ab]{4}
  268. )""",
  269. // Passing strings
  270. {
  271. "aaaa",
  272. "bbbb",
  273. "abab",
  274. },
  275. // Failing strings
  276. {
  277. "a",
  278. "b",
  279. "aaaaa",
  280. }
  281. );
  282. test_grammar(
  283. "simple min repetition",
  284. // Grammar
  285. R"""(
  286. root ::= [ab]{4,}
  287. )""",
  288. // Passing strings
  289. {
  290. "aaaa",
  291. "aaaaab",
  292. "bbbb",
  293. "ababab",
  294. },
  295. // Failing strings
  296. {
  297. "",
  298. "aba",
  299. }
  300. );
  301. test_grammar(
  302. "simple max repetition",
  303. // Grammar
  304. R"""(
  305. root ::= [ab]{0,4}
  306. )""",
  307. // Passing strings
  308. {
  309. "",
  310. "a",
  311. "aa",
  312. "aaa",
  313. "aaab",
  314. },
  315. // Failing strings
  316. {
  317. "aaaaa",
  318. }
  319. );
  320. test_grammar(
  321. "min / max repetition",
  322. // Grammar
  323. R"""(
  324. root ::= ("0x" [A-F0-9]{2} " "?){3,5}
  325. )""",
  326. // Passing strings
  327. {
  328. "0xFF 0x12 0xAB",
  329. "0xFF 0x12 0xAB 0x00 0x00",
  330. },
  331. // Failing strings
  332. {
  333. "",
  334. "0xFF",
  335. "0xFF 0x12",
  336. "0xFF 0x12 0xAB 0x00 0x00 0x00",
  337. }
  338. );
  339. }
  340. static void test_failure_missing_root() {
  341. fprintf(stderr, "⚫ Testing missing root node:\n");
  342. // Test case for a grammar that is missing a root rule
  343. const std::string grammar_str = R"""(rot ::= expr
  344. expr ::= term ("+" term)*
  345. term ::= number
  346. number ::= [0-9]+)""";
  347. grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
  348. // Ensure we parsed correctly
  349. assert(!parsed_grammar.rules.empty());
  350. // Ensure we do NOT have a root node
  351. assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end());
  352. fprintf(stderr, " ✅︎ Passed\n");
  353. }
  354. static void test_failure_missing_reference() {
  355. fprintf(stderr, "⚫ Testing missing reference node:\n");
  356. // Test case for a grammar that is missing a referenced rule
  357. const std::string grammar_str =
  358. R"""(root ::= expr
  359. expr ::= term ("+" term)*
  360. term ::= numero
  361. number ::= [0-9]+)""";
  362. fprintf(stderr, " Expected error: ");
  363. grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
  364. // Ensure we did NOT parsed correctly
  365. assert(parsed_grammar.rules.empty());
  366. fprintf(stderr, " End of expected error.\n");
  367. fprintf(stderr, " ✅︎ Passed\n");
  368. }
  369. static void test_failure_left_recursion() {
  370. fprintf(stderr, "⚫ Testing left recursion detection:\n");
  371. // Test simple left recursion detection
  372. const std::string simple_str = R"""(root ::= "a" | root "a")""";
  373. assert(test_build_grammar_fails(simple_str));
  374. // Test more complicated left recursion detection
  375. const std::string medium_str = R"""(
  376. root ::= asdf
  377. asdf ::= "a" | asdf "a"
  378. )""";
  379. assert(test_build_grammar_fails(medium_str));
  380. // Test even more complicated left recursion detection
  381. const std::string hard_str = R"""(
  382. root ::= asdf
  383. asdf ::= "a" | foo "b"
  384. foo ::= "c" | asdf "d" | "e")""";
  385. assert(test_build_grammar_fails(hard_str));
  386. // Test yet even more complicated left recursion detection
  387. const std::string hardest_str = R"""(
  388. root ::= asdf
  389. asdf ::= "a" | foo "b"
  390. foo ::= "c" | empty asdf "d" | "e"
  391. empty ::= "blah" | )""";
  392. assert(test_build_grammar_fails(hardest_str));
  393. fprintf(stderr, " ✅︎ Passed\n");
  394. }
  395. int main() {
  396. fprintf(stdout, "Running grammar integration tests...\n");
  397. test_simple_grammar();
  398. test_complex_grammar();
  399. test_quantifiers();
  400. test_failure_missing_root();
  401. test_failure_missing_reference();
  402. test_failure_left_recursion();
  403. fprintf(stdout, "All tests passed.\n");
  404. return 0;
  405. }