test-grammar-integration.cpp 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
// Keep assert() active even in release builds: these tests rely on it.
#ifdef NDEBUG
#undef NDEBUG
#endif

#define LLAMA_API_INTERNAL

#include "ggml.h"
#include "llama.h"
#include "grammar-parser.h"
#include "unicode.h"

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>
  11. static void test_simple_grammar() {
  12. // Test case for a simple grammar
  13. const std::string grammar_str = R"""(root ::= expr
  14. expr ::= term ("+" term)*
  15. term ::= number
  16. number ::= [0-9]+)""";
  17. grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
  18. // Ensure we parsed correctly
  19. assert(!parsed_grammar.rules.empty());
  20. // Ensure we have a root node
  21. assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
  22. std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
  23. llama_grammar* grammar = llama_grammar_init(
  24. grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
  25. std::string input = "123+456";
  26. auto decoded = decode_utf8(input, {});
  27. const auto & code_points = decoded.first;
  28. for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
  29. auto prev_stacks = grammar->stacks;
  30. llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
  31. assert(!grammar->stacks.empty());
  32. }
  33. bool completed_grammar = false;
  34. for (const auto & stack : grammar->stacks) {
  35. if (stack.empty()) {
  36. completed_grammar = true;
  37. break;
  38. }
  39. }
  40. assert(completed_grammar);
  41. // Clean up allocated memory
  42. llama_grammar_free(grammar);
  43. }
  44. static void test_complex_grammar() {
  45. // Test case for a more complex grammar, with both failure strings and success strings
  46. const std::string grammar_str = R"""(root ::= expression
  47. expression ::= term ws (("+"|"-") ws term)*
  48. term ::= factor ws (("*"|"/") ws factor)*
  49. factor ::= number | variable | "(" expression ")" | function-call
  50. number ::= [0-9]+
  51. variable ::= [a-zA-Z_][a-zA-Z0-9_]*
  52. function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
  53. ws ::= [ \t\n\r]?)""";
  54. grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
  55. // Ensure we parsed correctly
  56. assert(!parsed_grammar.rules.empty());
  57. // Ensure we have a root node
  58. assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
  59. std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
  60. llama_grammar* grammar = llama_grammar_init(
  61. grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
  62. // Save the original grammar stacks so that we can reset after every new string we want to test
  63. auto original_stacks = grammar->stacks;
  64. // Test a few strings
  65. std::vector<std::string> test_strings_pass = {
  66. "42",
  67. "1*2*3*4*5",
  68. "x",
  69. "x+10",
  70. "x1+y2",
  71. "(a+b)*(c-d)",
  72. "func()",
  73. "func(x,y+2)",
  74. "a*(b+c)-d/e",
  75. "f(g(x),h(y,z))",
  76. "x + 10",
  77. "x1 + y2",
  78. "(a + b) * (c - d)",
  79. "func()",
  80. "func(x, y + 2)",
  81. "a * (b + c) - d / e",
  82. "f(g(x), h(y, z))",
  83. "123+456",
  84. "123*456*789-123/456+789*123",
  85. "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
  86. };
  87. std::vector<std::string> test_strings_fail = {
  88. "+",
  89. "/ 3x",
  90. "x + + y",
  91. "a * / b",
  92. "func(,)",
  93. "func(x y)",
  94. "(a + b",
  95. "x + y)",
  96. "a + b * (c - d",
  97. "42 +",
  98. "x +",
  99. "x + 10 +",
  100. "(a + b) * (c - d",
  101. "func(",
  102. "func(x, y + 2",
  103. "a * (b + c) - d /",
  104. "f(g(x), h(y, z)",
  105. "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
  106. };
  107. // Passing strings
  108. for (const auto & test_string : test_strings_pass) {
  109. auto decoded = decode_utf8(test_string, {});
  110. const auto & code_points = decoded.first;
  111. int pos = 0;
  112. for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
  113. ++pos;
  114. auto prev_stacks = grammar->stacks;
  115. llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
  116. // Expect that each code point will not cause the grammar to fail
  117. if (grammar->stacks.empty()) {
  118. fprintf(stdout, "Error at position %d\n", pos);
  119. fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
  120. fprintf(stderr, "Input string is %s:\n", test_string.c_str());
  121. }
  122. assert(!grammar->stacks.empty());
  123. }
  124. bool completed_grammar = false;
  125. for (const auto & stack : grammar->stacks) {
  126. if (stack.empty()) {
  127. completed_grammar = true;
  128. break;
  129. }
  130. }
  131. assert(completed_grammar);
  132. // Reset the grammar stacks
  133. grammar->stacks = original_stacks;
  134. }
  135. // Failing strings
  136. for (const auto & test_string : test_strings_fail) {
  137. auto decoded = decode_utf8(test_string, {});
  138. const auto & code_points = decoded.first;
  139. bool parse_failed = false;
  140. for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
  141. auto prev_stacks = grammar->stacks;
  142. llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
  143. if (grammar->stacks.empty()) {
  144. parse_failed = true;
  145. break;
  146. }
  147. assert(!grammar->stacks.empty());
  148. }
  149. bool completed_grammar = false;
  150. for (const auto & stack : grammar->stacks) {
  151. if (stack.empty()) {
  152. completed_grammar = true;
  153. break;
  154. }
  155. }
  156. // Ensure that the grammar is not completed, or that each string failed to match as-expected
  157. assert((!completed_grammar) || parse_failed);
  158. // Reset the grammar stacks
  159. grammar->stacks = original_stacks;
  160. }
  161. // Clean up allocated memory
  162. llama_grammar_free(grammar);
  163. }
  164. static void test_failure_missing_root() {
  165. // Test case for a grammar that is missing a root rule
  166. const std::string grammar_str = R"""(rot ::= expr
  167. expr ::= term ("+" term)*
  168. term ::= number
  169. number ::= [0-9]+)""";
  170. grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
  171. // Ensure we parsed correctly
  172. assert(!parsed_grammar.rules.empty());
  173. // Ensure we do NOT have a root node
  174. assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end());
  175. }
  176. static void test_failure_missing_reference() {
  177. // Test case for a grammar that is missing a referenced rule
  178. const std::string grammar_str = R"""(root ::= expr
  179. expr ::= term ("+" term)*
  180. term ::= numero
  181. number ::= [0-9]+)""";
  182. fprintf(stderr, "Expected error: ");
  183. grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
  184. // Ensure we did NOT parsed correctly
  185. assert(parsed_grammar.rules.empty());
  186. fprintf(stderr, "End of expected error. Test successful.\n");
  187. }
  188. int main() {
  189. test_simple_grammar();
  190. test_complex_grammar();
  191. test_failure_missing_root();
  192. test_failure_missing_reference();
  193. return 0;
  194. }