test-grammar-integration.cpp 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
#ifdef NDEBUG
#undef NDEBUG
#endif

#define LLAMA_API_INTERNAL

#include "ggml.h"
#include "llama.h"
#include "grammar-parser.h"
#include "unicode.h"

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>
  12. static llama_grammar* build_grammar(const std::string & grammar_str) {
  13. auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
  14. // Ensure we parsed correctly
  15. assert(!parsed_grammar.rules.empty());
  16. // Ensure we have a root node
  17. assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
  18. std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
  19. llama_grammar* grammar = llama_grammar_init(
  20. grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
  21. return grammar;
  22. }
  23. static bool match_string(const std::string & input, llama_grammar* grammar) {
  24. auto decoded = decode_utf8(input, {});
  25. const auto & code_points = decoded.first;
  26. for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
  27. auto prev_stacks = grammar->stacks;
  28. llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
  29. if (grammar->stacks.empty()) {
  30. // no stacks means that the grammar failed to match at this point
  31. return false;
  32. }
  33. }
  34. for (const auto & stack : grammar->stacks) {
  35. if (stack.empty()) {
  36. // An empty stack means that the grammar has been completed
  37. return true;
  38. }
  39. }
  40. return false;
  41. }
  42. static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
  43. fprintf(stderr, "⚫ Testing %s. Grammar: %s\n", test_desc.c_str(), grammar_str.c_str());
  44. fflush(stderr);
  45. auto grammar = build_grammar(grammar_str);
  46. // Save the original grammar stacks so that we can reset after every new string we want to test
  47. auto original_stacks = grammar->stacks;
  48. fprintf(stderr, " 🔵 Valid strings:\n");
  49. // Passing strings
  50. for (const auto & test_string : passing_strings) {
  51. fprintf(stderr, " \"%s\" ", test_string.c_str());
  52. fflush(stderr);
  53. bool matched = match_string(test_string, grammar);
  54. if (!matched) {
  55. fprintf(stderr, "❌ (failed to match)\n");
  56. } else {
  57. fprintf(stdout, "✅︎\n");
  58. }
  59. assert(matched);
  60. // Reset the grammar stacks
  61. grammar->stacks = original_stacks;
  62. }
  63. fprintf(stderr, " 🟠 Invalid strings:\n");
  64. // Failing strings
  65. for (const auto & test_string : failing_strings) {
  66. fprintf(stderr, " \"%s\" ", test_string.c_str());
  67. fflush(stderr);
  68. bool matched = match_string(test_string, grammar);
  69. if (matched) {
  70. fprintf(stderr, "❌ (incorrectly matched)\n");
  71. } else {
  72. fprintf(stdout, "✅︎\n");
  73. }
  74. assert(!matched);
  75. // Reset the grammar stacks
  76. grammar->stacks = original_stacks;
  77. }
  78. // Clean up allocated memory
  79. llama_grammar_free(grammar);
  80. }
  81. static void test_simple_grammar() {
  82. // Test case for a simple grammar
  83. test_grammar(
  84. "simple grammar",
  85. R"""(
  86. root ::= expr
  87. expr ::= term ("+" term)*
  88. term ::= number
  89. number ::= [0-9]+)""",
  90. // Passing strings
  91. {
  92. "42",
  93. "1+2+3+4+5",
  94. "123+456",
  95. },
  96. // Failing strings
  97. {
  98. "+",
  99. "/ 3",
  100. "1+2+3+4+5+",
  101. "12a45",
  102. }
  103. );
  104. }
  105. static void test_complex_grammar() {
  106. // Test case for a more complex grammar, with both failure strings and success strings
  107. test_grammar(
  108. "medium complexity grammar",
  109. // Grammar
  110. R"""(
  111. root ::= expression
  112. expression ::= term ws (("+"|"-") ws term)*
  113. term ::= factor ws (("*"|"/") ws factor)*
  114. factor ::= number | variable | "(" expression ")" | function-call
  115. number ::= [0-9]+
  116. variable ::= [a-zA-Z_][a-zA-Z0-9_]*
  117. function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
  118. ws ::= [ \t\n\r]?)""",
  119. // Passing strings
  120. {
  121. "42",
  122. "1*2*3*4*5",
  123. "x",
  124. "x+10",
  125. "x1+y2",
  126. "(a+b)*(c-d)",
  127. "func()",
  128. "func(x,y+2)",
  129. "a*(b+c)-d/e",
  130. "f(g(x),h(y,z))",
  131. "x + 10",
  132. "x1 + y2",
  133. "(a + b) * (c - d)",
  134. "func()",
  135. "func(x, y + 2)",
  136. "a * (b + c) - d / e",
  137. "f(g(x), h(y, z))",
  138. "123+456",
  139. "123*456*789-123/456+789*123",
  140. "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
  141. },
  142. // Failing strings
  143. {
  144. "+",
  145. "/ 3x",
  146. "x + + y",
  147. "a * / b",
  148. "func(,)",
  149. "func(x y)",
  150. "(a + b",
  151. "x + y)",
  152. "a + b * (c - d",
  153. "42 +",
  154. "x +",
  155. "x + 10 +",
  156. "(a + b) * (c - d",
  157. "func(",
  158. "func(x, y + 2",
  159. "a * (b + c) - d /",
  160. "f(g(x), h(y, z)",
  161. "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
  162. }
  163. );
  164. }
  165. static void test_quantifiers() {
  166. // A collection of tests to exercise * + and ? quantifiers
  167. test_grammar(
  168. "* quantifier",
  169. // Grammar
  170. R"""(root ::= "a"*)""",
  171. // Passing strings
  172. {
  173. "",
  174. "a",
  175. "aaaaa",
  176. "aaaaaaaaaaaaaaaaaa",
  177. "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
  178. },
  179. // Failing strings
  180. {
  181. "b",
  182. "ab",
  183. "aab",
  184. "ba",
  185. "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
  186. }
  187. );
  188. test_grammar(
  189. "+ quantifier",
  190. // Grammar
  191. R"""(root ::= "a"+)""",
  192. // Passing strings
  193. {
  194. "a",
  195. "aaaaa",
  196. "aaaaaaaaaaaaaaaaaa",
  197. "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
  198. },
  199. // Failing strings
  200. {
  201. "",
  202. "b",
  203. "ab",
  204. "aab",
  205. "ba",
  206. "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
  207. }
  208. );
  209. test_grammar(
  210. "? quantifier",
  211. // Grammar
  212. R"""(root ::= "a"?)""",
  213. // Passing strings
  214. {
  215. "",
  216. "a"
  217. },
  218. // Failing strings
  219. {
  220. "b",
  221. "ab",
  222. "aa",
  223. "ba",
  224. }
  225. );
  226. test_grammar(
  227. "mixed quantifiers",
  228. // Grammar
  229. R"""(
  230. root ::= cons+ vowel* cons? (vowel cons)*
  231. vowel ::= [aeiouy]
  232. cons ::= [bcdfghjklmnpqrstvwxyz]
  233. )""",
  234. // Passing strings
  235. {
  236. "yes",
  237. "no",
  238. "noyes",
  239. "crwth",
  240. "four",
  241. "bryyyy",
  242. },
  243. // Failing strings
  244. {
  245. "yess",
  246. "yesno",
  247. "forty",
  248. "catyyy",
  249. }
  250. );
  251. }
  252. static void test_failure_missing_root() {
  253. fprintf(stderr, "⚫ Testing missing root node:\n");
  254. // Test case for a grammar that is missing a root rule
  255. const std::string grammar_str = R"""(rot ::= expr
  256. expr ::= term ("+" term)*
  257. term ::= number
  258. number ::= [0-9]+)""";
  259. grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
  260. // Ensure we parsed correctly
  261. assert(!parsed_grammar.rules.empty());
  262. // Ensure we do NOT have a root node
  263. assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end());
  264. fprintf(stderr, " ✅︎ Passed\n");
  265. }
  266. static void test_failure_missing_reference() {
  267. fprintf(stderr, "⚫ Testing missing reference node:\n");
  268. // Test case for a grammar that is missing a referenced rule
  269. const std::string grammar_str =
  270. R"""(root ::= expr
  271. expr ::= term ("+" term)*
  272. term ::= numero
  273. number ::= [0-9]+)""";
  274. fprintf(stderr, " Expected error: ");
  275. grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
  276. // Ensure we did NOT parsed correctly
  277. assert(parsed_grammar.rules.empty());
  278. fprintf(stderr, " End of expected error.\n");
  279. fprintf(stderr, " ✅︎ Passed\n");
  280. }
  281. int main() {
  282. fprintf(stdout, "Running grammar integration tests...\n");
  283. test_simple_grammar();
  284. test_complex_grammar();
  285. test_quantifiers();
  286. test_failure_missing_root();
  287. test_failure_missing_reference();
  288. fprintf(stdout, "All tests passed.\n");
  289. return 0;
  290. }