| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406 |
- #ifdef NDEBUG
- #undef NDEBUG
- #endif
- #include "llama.cpp" // TODO: not great
- #include "grammar-parser.h"
- #include <cassert>
- int main()
- {
- grammar_parser::parse_state parsed_grammar;
- std::vector<std::pair<std::string, uint32_t>> expected = {
- {"expr", 2},
- {"expr_6", 6},
- {"expr_7", 7},
- {"ident", 8},
- {"ident_10", 10},
- {"num", 9},
- {"num_11", 11},
- {"root", 0},
- {"root_1", 1},
- {"root_5", 5},
- {"term", 4},
- {"ws", 3},
- {"ws_12", 12},
- };
- std::vector<std::vector<llama_grammar_element>> expected_rules = {
- {{LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_END, 0}},
- {
- {LLAMA_GRETYPE_RULE_REF, 2},
- {LLAMA_GRETYPE_CHAR, 61},
- {LLAMA_GRETYPE_RULE_REF, 3},
- {LLAMA_GRETYPE_RULE_REF, 4},
- {LLAMA_GRETYPE_CHAR, 10},
- {LLAMA_GRETYPE_END, 0},
- },
- {{LLAMA_GRETYPE_RULE_REF, 4}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_END, 0}},
- {{LLAMA_GRETYPE_RULE_REF, 12}, {LLAMA_GRETYPE_END, 0}},
- {
- {LLAMA_GRETYPE_RULE_REF, 8},
- {LLAMA_GRETYPE_ALT, 0},
- {LLAMA_GRETYPE_RULE_REF, 9},
- {LLAMA_GRETYPE_ALT, 0},
- {LLAMA_GRETYPE_CHAR, 40},
- {LLAMA_GRETYPE_RULE_REF, 3},
- {LLAMA_GRETYPE_RULE_REF, 2},
- {LLAMA_GRETYPE_CHAR, 41},
- {LLAMA_GRETYPE_RULE_REF, 3},
- {LLAMA_GRETYPE_END, 0},
- },
- {{LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_END, 0}},
- {
- {LLAMA_GRETYPE_CHAR, 45},
- {LLAMA_GRETYPE_CHAR_ALT, 43},
- {LLAMA_GRETYPE_CHAR_ALT, 42},
- {LLAMA_GRETYPE_CHAR_ALT, 47},
- {LLAMA_GRETYPE_RULE_REF, 4},
- {LLAMA_GRETYPE_END, 0},
- },
- {{LLAMA_GRETYPE_RULE_REF, 6}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}},
- {
- {LLAMA_GRETYPE_CHAR, 97},
- {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
- {LLAMA_GRETYPE_RULE_REF, 10},
- {LLAMA_GRETYPE_RULE_REF, 3},
- {LLAMA_GRETYPE_END, 0},
- },
- {{LLAMA_GRETYPE_RULE_REF, 11}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_END, 0}},
- {
- {LLAMA_GRETYPE_CHAR, 97},
- {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
- {LLAMA_GRETYPE_CHAR_ALT, 48},
- {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
- {LLAMA_GRETYPE_CHAR_ALT, 95},
- {LLAMA_GRETYPE_RULE_REF, 10},
- {LLAMA_GRETYPE_ALT, 0},
- {LLAMA_GRETYPE_END, 0},
- },
- {
- {LLAMA_GRETYPE_CHAR, 48},
- {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
- {LLAMA_GRETYPE_RULE_REF, 11},
- {LLAMA_GRETYPE_ALT, 0},
- {LLAMA_GRETYPE_CHAR, 48},
- {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
- {LLAMA_GRETYPE_END, 0},
- },
- {
- {LLAMA_GRETYPE_CHAR, 32},
- {LLAMA_GRETYPE_CHAR_ALT, 9},
- {LLAMA_GRETYPE_CHAR_ALT, 10},
- {LLAMA_GRETYPE_RULE_REF, 12},
- {LLAMA_GRETYPE_ALT, 0},
- {LLAMA_GRETYPE_END, 0},
- },
- };
- for (auto pair : expected)
- {
- parsed_grammar.symbol_ids[pair.first] = pair.second;
- }
- for (auto rule : expected_rules)
- {
- parsed_grammar.rules.emplace_back();
- for (auto element : rule)
- {
- parsed_grammar.rules.back().push_back(element);
- }
- }
- llama_grammar *grammar = NULL;
- std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
- grammar = llama_grammar_init(
- grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
- if (grammar == nullptr)
- {
- throw std::runtime_error("Failed to initialize llama_grammar");
- }
- std::vector<std::vector<llama_grammar_element>> expected_stacks = {
- {
- {LLAMA_GRETYPE_RULE_REF, 5},
- {LLAMA_GRETYPE_CHAR, 61},
- {LLAMA_GRETYPE_RULE_REF, 7},
- {LLAMA_GRETYPE_CHAR, 97},
- },
- {
- {LLAMA_GRETYPE_RULE_REF, 5},
- {LLAMA_GRETYPE_CHAR, 61},
- {LLAMA_GRETYPE_RULE_REF, 7},
- {LLAMA_GRETYPE_RULE_REF, 3},
- {LLAMA_GRETYPE_CHAR, 48},
- },
- {
- {LLAMA_GRETYPE_RULE_REF, 5},
- {LLAMA_GRETYPE_CHAR, 61},
- {LLAMA_GRETYPE_RULE_REF, 7},
- {LLAMA_GRETYPE_RULE_REF, 3},
- {LLAMA_GRETYPE_CHAR, 48},
- },
- {
- {LLAMA_GRETYPE_RULE_REF, 5},
- {LLAMA_GRETYPE_CHAR, 61},
- {LLAMA_GRETYPE_RULE_REF, 7},
- {LLAMA_GRETYPE_CHAR, 40},
- },
- {
- {LLAMA_GRETYPE_CHAR, 61},
- {LLAMA_GRETYPE_RULE_REF, 7},
- {LLAMA_GRETYPE_CHAR, 97},
- },
- {
- {LLAMA_GRETYPE_CHAR, 61},
- {LLAMA_GRETYPE_RULE_REF, 7},
- {LLAMA_GRETYPE_RULE_REF, 3},
- {LLAMA_GRETYPE_CHAR, 48},
- },
- {
- {LLAMA_GRETYPE_CHAR, 61},
- {LLAMA_GRETYPE_RULE_REF, 7},
- {LLAMA_GRETYPE_RULE_REF, 3},
- {LLAMA_GRETYPE_CHAR, 48},
- },
- {
- {LLAMA_GRETYPE_CHAR, 61},
- {LLAMA_GRETYPE_RULE_REF, 7},
- {LLAMA_GRETYPE_CHAR, 40},
- }};
- auto index = 0;
- for (auto stack : grammar->stacks)
- {
- // compare stack to expected_stack
- for (uint32_t i = 0; i < stack.size(); i++)
- {
- auto element = stack[i];
- auto expected_element = expected_stacks[index][i];
- // pretty print error message before asserting
- if (expected_element.type != element->type || expected_element.value != element->value)
- {
- fprintf(stderr, "index: %d\n", index);
- fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value);
- fprintf(stderr, "actual_element: %d, %u\n", element->type, element->value);
- fprintf(stderr, "expected_element != actual_element\n");
- }
- assert(expected_element.type == element->type && expected_element.value == element->value);
- }
- index++;
- }
- std::vector<llama_grammar_candidate> next_candidates;
- next_candidates.resize(24);
- for (size_t i = 0; i < 24; ++i)
- {
- uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
- cp[0] = 37 + i;
- cp[1] = 0;
- next_candidates[i] = {i, cp, {}};
- }
- std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {
- {
- {0, 37},
- {1, 38},
- {2, 39},
- {3, 40},
- {4, 41},
- {5, 42},
- {6, 43},
- {7, 44},
- {8, 45},
- {9, 46},
- {10, 47},
- {11, 48},
- {12, 49},
- {13, 50},
- {14, 51},
- {15, 52},
- {16, 53},
- {17, 54},
- {18, 55},
- {19, 56},
- {20, 57},
- {21, 58},
- {22, 59},
- {23, 60},
- },
- {
- {0, 37},
- {1, 38},
- {2, 39},
- {3, 40},
- {4, 41},
- {5, 42},
- {6, 43},
- {7, 44},
- {8, 45},
- {9, 46},
- {10, 47},
- {21, 58},
- {22, 59},
- {23, 60},
- },
- {
- {0, 37},
- {1, 38},
- {2, 39},
- {3, 40},
- {4, 41},
- {5, 42},
- {6, 43},
- {7, 44},
- {8, 45},
- {9, 46},
- {10, 47},
- {21, 58},
- {22, 59},
- {23, 60},
- },
- {
- {0, 37},
- {1, 38},
- {2, 39},
- {4, 41},
- {5, 42},
- {6, 43},
- {7, 44},
- {8, 45},
- {9, 46},
- {10, 47},
- {11, 48},
- {12, 49},
- {13, 50},
- {14, 51},
- {15, 52},
- {16, 53},
- {17, 54},
- {18, 55},
- {19, 56},
- {20, 57},
- {21, 58},
- {22, 59},
- {23, 60},
- },
- {
- {0, 37},
- {1, 38},
- {2, 39},
- {3, 40},
- {4, 41},
- {5, 42},
- {6, 43},
- {7, 44},
- {8, 45},
- {9, 46},
- {10, 47},
- {11, 48},
- {12, 49},
- {13, 50},
- {14, 51},
- {15, 52},
- {16, 53},
- {17, 54},
- {18, 55},
- {19, 56},
- {20, 57},
- {21, 58},
- {22, 59},
- {23, 60},
- },
- {
- {0, 37},
- {1, 38},
- {2, 39},
- {3, 40},
- {4, 41},
- {5, 42},
- {6, 43},
- {7, 44},
- {8, 45},
- {9, 46},
- {10, 47},
- {21, 58},
- {22, 59},
- {23, 60},
- },
- {
- {0, 37},
- {1, 38},
- {2, 39},
- {3, 40},
- {4, 41},
- {5, 42},
- {6, 43},
- {7, 44},
- {8, 45},
- {9, 46},
- {10, 47},
- {21, 58},
- {22, 59},
- {23, 60},
- },
- {
- {0, 37},
- {1, 38},
- {2, 39},
- {4, 41},
- {5, 42},
- {6, 43},
- {7, 44},
- {8, 45},
- {9, 46},
- {10, 47},
- {11, 48},
- {12, 49},
- {13, 50},
- {14, 51},
- {15, 52},
- {16, 53},
- {17, 54},
- {18, 55},
- {19, 56},
- {20, 57},
- {21, 58},
- {22, 59},
- {23, 60},
- },
- };
- std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[0], next_candidates);
- std::vector<std::vector<llama_grammar_candidate>> all_rejects;
- for (std::size_t count = 0; count < grammar->stacks.size(); ++count)
- {
- rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[count], next_candidates);
- all_rejects.push_back(rejects);
- }
- index = 0;
- for (auto rej : all_rejects)
- {
- for (uint32_t i = 0; i < rej.size(); i++)
- {
- auto element = rej[i];
- auto expected_element = expected_reject[index][i];
- assert(element.index == expected_element.first && *element.code_points == expected_element.second);
- }
- index++;
- }
- for (auto &candidate : next_candidates)
- {
- delete[] candidate.code_points;
- candidate.code_points = nullptr;
- }
- delete grammar;
- return 0;
- }
|