| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482 |
- #ifdef NDEBUG
- #undef NDEBUG
- #endif
- #define LLAMA_API_INTERNAL
- #include "ggml.h"
- #include "llama.h"
- #include "grammar-parser.h"
- #include "unicode.h"
- #include <cassert>
- #include <string>
- #include <vector>
- static llama_grammar* build_grammar(const std::string & grammar_str) {
- auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
- // Ensure we parsed correctly
- assert(!parsed_grammar.rules.empty());
- // Ensure we have a root node
- assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
- std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
- llama_grammar* grammar = llama_grammar_init(
- grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
- return grammar;
- }
- static bool test_build_grammar_fails(const std::string & grammar_str) {
- fprintf(stderr, "⚫ Testing failure for grammar: %s\n", grammar_str.c_str());
- bool grammar_fails = false;
- try {
- build_grammar(grammar_str);
- fprintf(stderr, " ❌ Expected build failure, but succeeded\n");
- } catch (const std::exception & err) {
- grammar_fails = true;
- fprintf(stdout, " ✅︎\n");
- }
- return grammar_fails;
- }
- static bool match_string(const std::string & input, llama_grammar* grammar) {
- auto decoded = decode_utf8(input, {});
- const auto & code_points = decoded.first;
- for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
- auto prev_stacks = grammar->stacks;
- llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
- if (grammar->stacks.empty()) {
- // no stacks means that the grammar failed to match at this point
- return false;
- }
- }
- for (const auto & stack : grammar->stacks) {
- if (stack.empty()) {
- // An empty stack means that the grammar has been completed
- return true;
- }
- }
- return false;
- }
- static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
- fprintf(stderr, "⚫ Testing %s. Grammar: %s\n", test_desc.c_str(), grammar_str.c_str());
- fflush(stderr);
- auto grammar = build_grammar(grammar_str);
- // Save the original grammar stacks so that we can reset after every new string we want to test
- auto original_stacks = grammar->stacks;
- fprintf(stderr, " 🔵 Valid strings:\n");
- // Passing strings
- for (const auto & test_string : passing_strings) {
- fprintf(stderr, " \"%s\" ", test_string.c_str());
- fflush(stderr);
- bool matched = match_string(test_string, grammar);
- if (!matched) {
- fprintf(stderr, "❌ (failed to match)\n");
- } else {
- fprintf(stdout, "✅︎\n");
- }
- assert(matched);
- // Reset the grammar stacks
- grammar->stacks = original_stacks;
- }
- fprintf(stderr, " 🟠 Invalid strings:\n");
- // Failing strings
- for (const auto & test_string : failing_strings) {
- fprintf(stderr, " \"%s\" ", test_string.c_str());
- fflush(stderr);
- bool matched = match_string(test_string, grammar);
- if (matched) {
- fprintf(stderr, "❌ (incorrectly matched)\n");
- } else {
- fprintf(stdout, "✅︎\n");
- }
- assert(!matched);
- // Reset the grammar stacks
- grammar->stacks = original_stacks;
- }
- // Clean up allocated memory
- llama_grammar_free(grammar);
- }
- static void test_simple_grammar() {
- // Test case for a simple grammar
- test_grammar(
- "simple grammar",
- R"""(
- root ::= expr
- expr ::= term ("+" term)*
- term ::= number
- number ::= [0-9]+)""",
- // Passing strings
- {
- "42",
- "1+2+3+4+5",
- "123+456",
- },
- // Failing strings
- {
- "+",
- "/ 3",
- "1+2+3+4+5+",
- "12a45",
- }
- );
- }
- static void test_complex_grammar() {
- // Test case for a more complex grammar, with both failure strings and success strings
- test_grammar(
- "medium complexity grammar",
- // Grammar
- R"""(
- root ::= expression
- expression ::= term ws (("+"|"-") ws term)*
- term ::= factor ws (("*"|"/") ws factor)*
- factor ::= number | variable | "(" expression ")" | function-call
- number ::= [0-9]+
- variable ::= [a-zA-Z_][a-zA-Z0-9_]*
- function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
- ws ::= [ \t\n\r]?)""",
- // Passing strings
- {
- "42",
- "1*2*3*4*5",
- "x",
- "x+10",
- "x1+y2",
- "(a+b)*(c-d)",
- "func()",
- "func(x,y+2)",
- "a*(b+c)-d/e",
- "f(g(x),h(y,z))",
- "x + 10",
- "x1 + y2",
- "(a + b) * (c - d)",
- "func()",
- "func(x, y + 2)",
- "a * (b + c) - d / e",
- "f(g(x), h(y, z))",
- "123+456",
- "123*456*789-123/456+789*123",
- "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
- },
- // Failing strings
- {
- "+",
- "/ 3x",
- "x + + y",
- "a * / b",
- "func(,)",
- "func(x y)",
- "(a + b",
- "x + y)",
- "a + b * (c - d",
- "42 +",
- "x +",
- "x + 10 +",
- "(a + b) * (c - d",
- "func(",
- "func(x, y + 2",
- "a * (b + c) - d /",
- "f(g(x), h(y, z)",
- "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
- }
- );
- }
- static void test_special_chars() {
- // A collection of tests to exercise special characters such as "."
- test_grammar(
- "special characters",
- // Grammar
- R"""(
- root ::= ... "abc" ...
- )""",
- // Passing strings
- {
- "abcabcabc",
- "aaaabcccc",
- // NOTE: Also ensures that multi-byte characters still count as a single character
- "🔵🟠✅abc❌🟠🔵"
- },
- // Failing strings
- {
- "aaabcccc",
- "aaaaabcccc",
- "aaaabccc",
- "aaaabccccc",
- "🔵🟠✅❌abc❌✅🟠🔵"
- "🔵🟠abc🟠🔵"
- }
- );
- }
- static void test_quantifiers() {
- // A collection of tests to exercise * + and ? quantifiers
- test_grammar(
- "* quantifier",
- // Grammar
- R"""(root ::= "a"*)""",
- // Passing strings
- {
- "",
- "a",
- "aaaaa",
- "aaaaaaaaaaaaaaaaaa",
- "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
- },
- // Failing strings
- {
- "b",
- "ab",
- "aab",
- "ba",
- "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
- }
- );
- test_grammar(
- "+ quantifier",
- // Grammar
- R"""(root ::= "a"+)""",
- // Passing strings
- {
- "a",
- "aaaaa",
- "aaaaaaaaaaaaaaaaaa",
- "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
- },
- // Failing strings
- {
- "",
- "b",
- "ab",
- "aab",
- "ba",
- "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
- }
- );
- test_grammar(
- "? quantifier",
- // Grammar
- R"""(root ::= "a"?)""",
- // Passing strings
- {
- "",
- "a"
- },
- // Failing strings
- {
- "b",
- "ab",
- "aa",
- "ba",
- }
- );
- test_grammar(
- "mixed quantifiers",
- // Grammar
- R"""(
- root ::= cons+ vowel* cons? (vowel cons)*
- vowel ::= [aeiouy]
- cons ::= [bcdfghjklmnpqrstvwxyz]
- )""",
- // Passing strings
- {
- "yes",
- "no",
- "noyes",
- "crwth",
- "four",
- "bryyyy",
- },
- // Failing strings
- {
- "yess",
- "yesno",
- "forty",
- "catyyy",
- }
- );
- test_grammar(
- "simple exact repetition",
- // Grammar
- R"""(
- root ::= [ab]{4}
- )""",
- // Passing strings
- {
- "aaaa",
- "bbbb",
- "abab",
- },
- // Failing strings
- {
- "a",
- "b",
- "aaaaa",
- }
- );
- test_grammar(
- "simple min repetition",
- // Grammar
- R"""(
- root ::= [ab]{4,}
- )""",
- // Passing strings
- {
- "aaaa",
- "aaaaab",
- "bbbb",
- "ababab",
- },
- // Failing strings
- {
- "",
- "aba",
- }
- );
- test_grammar(
- "simple max repetition",
- // Grammar
- R"""(
- root ::= [ab]{0,4}
- )""",
- // Passing strings
- {
- "",
- "a",
- "aa",
- "aaa",
- "aaab",
- },
- // Failing strings
- {
- "aaaaa",
- }
- );
- test_grammar(
- "min / max repetition",
- // Grammar
- R"""(
- root ::= ("0x" [A-F0-9]{2} " "?){3,5}
- )""",
- // Passing strings
- {
- "0xFF 0x12 0xAB",
- "0xFF 0x12 0xAB 0x00 0x00",
- },
- // Failing strings
- {
- "",
- "0xFF",
- "0xFF 0x12",
- "0xFF 0x12 0xAB 0x00 0x00 0x00",
- }
- );
- }
- static void test_failure_missing_root() {
- fprintf(stderr, "⚫ Testing missing root node:\n");
- // Test case for a grammar that is missing a root rule
- const std::string grammar_str = R"""(rot ::= expr
- expr ::= term ("+" term)*
- term ::= number
- number ::= [0-9]+)""";
- grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
- // Ensure we parsed correctly
- assert(!parsed_grammar.rules.empty());
- // Ensure we do NOT have a root node
- assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end());
- fprintf(stderr, " ✅︎ Passed\n");
- }
- static void test_failure_missing_reference() {
- fprintf(stderr, "⚫ Testing missing reference node:\n");
- // Test case for a grammar that is missing a referenced rule
- const std::string grammar_str =
- R"""(root ::= expr
- expr ::= term ("+" term)*
- term ::= numero
- number ::= [0-9]+)""";
- fprintf(stderr, " Expected error: ");
- grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
- // Ensure we did NOT parsed correctly
- assert(parsed_grammar.rules.empty());
- fprintf(stderr, " End of expected error.\n");
- fprintf(stderr, " ✅︎ Passed\n");
- }
- static void test_failure_left_recursion() {
- fprintf(stderr, "⚫ Testing left recursion detection:\n");
- // Test simple left recursion detection
- const std::string simple_str = R"""(root ::= "a" | root "a")""";
- assert(test_build_grammar_fails(simple_str));
- // Test more complicated left recursion detection
- const std::string medium_str = R"""(
- root ::= asdf
- asdf ::= "a" | asdf "a"
- )""";
- assert(test_build_grammar_fails(medium_str));
- // Test even more complicated left recursion detection
- const std::string hard_str = R"""(
- root ::= asdf
- asdf ::= "a" | foo "b"
- foo ::= "c" | asdf "d" | "e")""";
- assert(test_build_grammar_fails(hard_str));
- // Test yet even more complicated left recursion detection
- const std::string hardest_str = R"""(
- root ::= asdf
- asdf ::= "a" | foo "b"
- foo ::= "c" | empty asdf "d" | "e"
- empty ::= "blah" | )""";
- assert(test_build_grammar_fails(hardest_str));
- fprintf(stderr, " ✅︎ Passed\n");
- }
- int main() {
- fprintf(stdout, "Running grammar integration tests...\n");
- test_simple_grammar();
- test_complex_grammar();
- test_special_chars();
- test_quantifiers();
- test_failure_missing_root();
- test_failure_missing_reference();
- test_failure_left_recursion();
- fprintf(stdout, "All tests passed.\n");
- return 0;
- }
|