llama-grammar.cpp 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437
  1. #include "llama-grammar.h"
  2. #include "llama-impl.h"
  3. #include "llama-vocab.h"
  4. #include "llama-sampling.h"
  5. #include <cmath>
  6. #include <algorithm>
  7. #include <cstdint>
  8. #include <stdexcept>
  9. #define MAX_REPETITION_THRESHOLD 2000
  10. //
  11. // helpers
  12. //
  // Decodes the single UTF-8 sequence that starts at `src`.
  // Returns {code point, pointer just past the bytes consumed}.
  // The sequence length is taken from the lead byte alone and continuation bytes are not validated,
  // so the input is assumed to be well-formed UTF-8. Decoding stops early at a NUL byte, so a sequence
  // truncated by the string terminator never reads past it.
  static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
      // total sequence length, indexed by the high nibble of the lead byte
      static const int seq_len[16] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };

      const uint8_t lead    = static_cast<uint8_t>(src[0]);
      const int     n_bytes = seq_len[lead >> 4];

      // keep only the payload bits of the lead byte
      uint32_t code_point = lead & ((1u << (8 - n_bytes)) - 1u);

      const char * cur  = src + 1;
      const char * stop = src + n_bytes; // may lie past the terminator; never dereferenced there
      while (cur < stop && *cur != '\0') {
          code_point = (code_point << 6) + (static_cast<uint8_t>(*cur) & 0x3Fu);
          ++cur;
      }
      return { code_point, cur };
  }
  // Decodes the bytes of `src` into code points, first finishing the incomplete UTF-8 sequence
  // described by `partial_start` (left over from a previous call):
  //   partial_start.value    - payload bits accumulated so far
  //   partial_start.n_remain - continuation bytes still expected; when it is not > 0 nothing is
  //                            resumed and decoding starts with a fresh lead byte
  // Scanning goes through src.c_str(), so it stops at the first NUL byte.
  // Returns {code points followed by one terminating 0 element, partial state}. The partial state
  // is the {value, n_remain} of a sequence left incomplete at the end of `src`. It has n_remain == 0
  // when the last sequence completed, and equals `partial_start` itself when `src` is empty.
  // Invalid input (a non-continuation byte while resuming, or a continuation byte where a lead byte
  // is expected) yields {{0}, {0, -1}}.
  // NOTE(review): only the resumed bytes are checked to be continuation bytes (10xxxxxx). The bytes
  // consumed after a fresh lead byte (inner loop below) are not, so e.g. "\xC3A" folds 'A' into a
  // bogus code point instead of being reported as invalid. Confirm whether that is intended.
  28. static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
  29. const std::string & src,
  30. llama_partial_utf8 partial_start) {
  // total sequence length indexed by the high nibble of the lead byte; 0 marks continuation
  // bytes (10xxxxxx), which cannot start a sequence
  31. static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
  32. const char * pos = src.c_str();
  33. std::vector<uint32_t> code_points;
  34. // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
  35. code_points.reserve(src.size() + 1);
  36. uint32_t value = partial_start.value;
  37. int n_remain = partial_start.n_remain;
  38. // continue previous decode, if applicable
  39. while (*pos != 0 && n_remain > 0) {
  40. uint8_t next_byte = static_cast<uint8_t>(*pos);
  41. if ((next_byte >> 6) != 2) {
  42. // invalid sequence, abort
  43. code_points.push_back(0);
  44. return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
  45. }
  46. value = (value << 6) + (next_byte & 0x3F);
  47. ++pos;
  48. --n_remain;
  49. }
  // emit the resumed code point only if it was actually completed within `src`
  50. if (partial_start.n_remain > 0 && n_remain == 0) {
  51. code_points.push_back(value);
  52. }
  53. // decode any subsequent utf-8 sequences, which may end in an incomplete one
  54. while (*pos != 0) {
  55. uint8_t first_byte = static_cast<uint8_t>(*pos);
  56. uint8_t highbits = first_byte >> 4;
  57. n_remain = lookup[highbits] - 1;
  58. if (n_remain < 0) {
  59. // invalid sequence, abort
  60. code_points.clear();
  61. code_points.push_back(0);
  62. return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
  63. }
  // strip the length-marker bits of the lead byte, keeping its payload bits
  64. uint8_t mask = (1 << (7 - n_remain)) - 1;
  65. value = first_byte & mask;
  66. ++pos;
  // fold in continuation bytes; hitting NUL early leaves an incomplete trailing sequence
  67. while (*pos != 0 && n_remain > 0) {
  68. value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
  69. ++pos;
  70. --n_remain;
  71. }
  // incomplete sequences are not emitted; they are reported through the returned partial state
  72. if (n_remain == 0) {
  73. code_points.push_back(value);
  74. }
  75. }
  // terminating 0 element
  76. code_points.push_back(0);
  77. return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
  78. }
  // True iff `c` is an ASCII decimal digit ('0'..'9'); unlike std::isdigit it ignores the C locale.
  static bool is_digit_char(char c) {
      return c >= '0' && c <= '9';
  }
  // True for characters allowed in a grammar rule name: ASCII letters, ASCII digits and '-'.
  // '_' is not accepted. Synthesized rule names ("<base>_<id>") use it as a separator, so they can
  // never collide with a user-written name.
  static bool is_word_char(char c) {
      if (c == '-') {
          return true;
      }
      if (c >= '0' && c <= '9') {
          return true;
      }
      return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
  }
  // Parses exactly `size` hexadecimal digits (either case) starting at `src`.
  // Returns {value, pointer just past the digits}.
  // Throws std::runtime_error if a non-hex character or the end of the string appears before
  // `size` digits have been read.
  static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
      const char * const end = src + size;
      const char * cur = src;
      uint32_t value = 0;
      while (cur < end && *cur) {
          const char ch = *cur;
          uint32_t digit;
          if ('0' <= ch && ch <= '9') {
              digit = ch - '0';
          } else if ('a' <= ch && ch <= 'f') {
              digit = ch - 'a' + 10;
          } else if ('A' <= ch && ch <= 'F') {
              digit = ch - 'A' + 10;
          } else {
              break; // leaves cur short of end, reported below
          }
          value = (value << 4) | digit;
          ++cur;
      }
      if (cur != end) {
          throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
      }
      return std::make_pair(value, cur);
  }
  // Skips insignificant text starting at `src`: spaces, tabs, and '#' comments. A comment runs up to,
  // but not including, the line break that ends it. Line breaks ('\r', '\n') are skipped only when
  // `newline_ok` is true. Returns a pointer to the first character that was not skipped.
  static const char * parse_space(const char * src, bool newline_ok) {
      const char * cur = src;
      for (;;) {
          const char ch = *cur;
          if (ch == '#') {
              // comment: consume everything up to the end of the line (or of the input)
              while (*cur != '\0' && *cur != '\r' && *cur != '\n') {
                  ++cur;
              }
          } else if (ch == ' ' || ch == '\t' || (newline_ok && (ch == '\r' || ch == '\n'))) {
              ++cur;
          } else {
              return cur;
          }
      }
  }
  121. static const char * parse_name(const char * src) {
  122. const char * pos = src;
  123. while (is_word_char(*pos)) {
  124. pos++;
  125. }
  126. if (pos == src) {
  127. throw std::runtime_error(std::string("expecting name at ") + src);
  128. }
  129. return pos;
  130. }
  131. static const char * parse_int(const char * src) {
  132. const char * pos = src;
  133. while (is_digit_char(*pos)) {
  134. pos++;
  135. }
  136. if (pos == src) {
  137. throw std::runtime_error(std::string("expecting integer at ") + src);
  138. }
  139. return pos;
  140. }
  141. static std::pair<uint32_t, const char *> parse_char(const char * src) {
  142. if (*src == '\\') {
  143. switch (src[1]) {
  144. case 'x': return parse_hex(src + 2, 2);
  145. case 'u': return parse_hex(src + 2, 4);
  146. case 'U': return parse_hex(src + 2, 8);
  147. case 't': return std::make_pair('\t', src + 2);
  148. case 'r': return std::make_pair('\r', src + 2);
  149. case 'n': return std::make_pair('\n', src + 2);
  150. case '\\':
  151. case '"':
  152. case '[':
  153. case ']':
  154. return std::make_pair(src[1], src + 2);
  155. default:
  156. throw std::runtime_error(std::string("unknown escape at ") + src);
  157. }
  158. } else if (*src) {
  159. return decode_utf8(src);
  160. }
  161. throw std::runtime_error("unexpected end of input");
  162. }
  163. static std::pair<uint32_t, const char *> parse_token(const llama_vocab * vocab, const char * src) {
  164. const char * pos = src;
  165. if (*pos != '<') {
  166. throw std::runtime_error(std::string("expecting '<' at ") + pos);
  167. }
  168. pos++;
  169. // Parse <[id]>
  170. if (*pos == '[') {
  171. pos++;
  172. const char * int_end = parse_int(pos);
  173. uint32_t token_id = std::stoul(std::string(pos, int_end - pos));
  174. pos = int_end;
  175. if (*pos != ']') {
  176. throw std::runtime_error(std::string("expecting ']' at ") + pos);
  177. }
  178. pos++;
  179. if (*pos != '>') {
  180. throw std::runtime_error(std::string("expecting '>' at ") + pos);
  181. }
  182. pos++;
  183. return std::make_pair(token_id, pos);
  184. }
  185. if (vocab == nullptr) {
  186. throw std::runtime_error(std::string("no vocab to parse token at ") + src);
  187. }
  188. // Parse <token> and tokenize to obtain the token id
  189. while (*pos != 0 && *pos != '>') {
  190. pos++;
  191. }
  192. if (*pos != '>') {
  193. throw std::runtime_error(std::string("expecting '>' at ") + pos);
  194. }
  195. pos++;
  196. llama_token tokens[2];
  197. int32_t n_tokens = vocab->tokenize(src, static_cast<int32_t>(pos - src), tokens, 2, false, true);
  198. if (n_tokens != 1) {
  199. // must tokenize to exactly 1 token
  200. throw std::runtime_error("invalid token '" + std::string(src, pos - src) + "'");
  201. }
  202. return std::make_pair(tokens[0], pos);
  203. }
  // Writes code point `c` to `file` for grammar debug output.
  // Printable ASCII (0x20..0x7E) is written verbatim. Everything else (control characters, DEL 0x7F
  // and non-ASCII code points) is written as "<U+XXXX>", with at least four uppercase hex digits.
  static void print_grammar_char(FILE * file, uint32_t c) {
      // 0x7F (DEL) is a control character; the previous `c <= 0x7f` bound wrote it out raw
      if (0x20 <= c && c < 0x7f) {
          fprintf(file, "%c", static_cast<char>(c));
      } else {
          // cop out of encoding UTF-8
          // %X expects unsigned int, so convert explicitly rather than rely on uint32_t's typedef
          fprintf(file, "<U+%04X>", static_cast<unsigned int>(c));
      }
  }
  212. static bool is_char_element(llama_grammar_element elem) {
  213. switch (elem.type) {
  214. case LLAMA_GRETYPE_CHAR: return true;
  215. case LLAMA_GRETYPE_CHAR_NOT: return true;
  216. case LLAMA_GRETYPE_CHAR_ALT: return true;
  217. case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
  218. case LLAMA_GRETYPE_CHAR_ANY: return true;
  219. default: return false;
  220. }
  221. }
  222. static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
  223. for (auto elem : rule) {
  224. switch (elem.type) {
  225. case LLAMA_GRETYPE_END: fprintf(file, "END"); break;
  226. case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
  227. case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
  228. case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
  229. case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
  230. case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
  231. case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
  232. case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
  233. case LLAMA_GRETYPE_TOKEN: fprintf(file, "TOKEN"); break;
  234. case LLAMA_GRETYPE_TOKEN_NOT: fprintf(file, "TOKEN_NOT"); break;
  235. }
  236. switch (elem.type) {
  237. case LLAMA_GRETYPE_END:
  238. case LLAMA_GRETYPE_ALT:
  239. case LLAMA_GRETYPE_RULE_REF:
  240. fprintf(file, "(%u) ", elem.value);
  241. break;
  242. case LLAMA_GRETYPE_CHAR:
  243. case LLAMA_GRETYPE_CHAR_NOT:
  244. case LLAMA_GRETYPE_CHAR_RNG_UPPER:
  245. case LLAMA_GRETYPE_CHAR_ALT:
  246. case LLAMA_GRETYPE_CHAR_ANY:
  247. fprintf(file, "(\"");
  248. print_grammar_char(file, elem.value);
  249. fprintf(file, "\") ");
  250. break;
  251. case LLAMA_GRETYPE_TOKEN:
  252. fprintf(file, "<[");
  253. fprintf(file, "%u", elem.value);
  254. fprintf(file, "]> ");
  255. break;
  256. case LLAMA_GRETYPE_TOKEN_NOT:
  257. fprintf(file, "!");
  258. fprintf(file, "<[");
  259. fprintf(file, "%u", elem.value);
  260. fprintf(file, "]> ");
  261. break;
  262. }
  263. }
  264. fprintf(file, "\n");
  265. }
  266. static void print_rule(
  267. FILE * file,
  268. uint32_t rule_id,
  269. const llama_grammar_rule & rule,
  270. const std::map<uint32_t, std::string> & symbol_id_names) {
  271. if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
  272. throw std::runtime_error(
  273. "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
  274. }
  275. fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
  276. for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
  277. llama_grammar_element elem = rule[i];
  278. switch (elem.type) {
  279. case LLAMA_GRETYPE_END:
  280. throw std::runtime_error(
  281. "unexpected end of rule: " + std::to_string(rule_id) + "," +
  282. std::to_string(i));
  283. case LLAMA_GRETYPE_ALT:
  284. fprintf(file, "| ");
  285. break;
  286. case LLAMA_GRETYPE_RULE_REF:
  287. fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
  288. break;
  289. case LLAMA_GRETYPE_CHAR:
  290. fprintf(file, "[");
  291. print_grammar_char(file, elem.value);
  292. break;
  293. case LLAMA_GRETYPE_CHAR_NOT:
  294. fprintf(file, "[^");
  295. print_grammar_char(file, elem.value);
  296. break;
  297. case LLAMA_GRETYPE_CHAR_RNG_UPPER:
  298. if (i == 0 || !is_char_element(rule[i - 1])) {
  299. throw std::runtime_error(
  300. "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
  301. std::to_string(rule_id) + "," + std::to_string(i));
  302. }
  303. fprintf(file, "-");
  304. print_grammar_char(file, elem.value);
  305. break;
  306. case LLAMA_GRETYPE_CHAR_ALT:
  307. if (i == 0 || !is_char_element(rule[i - 1])) {
  308. throw std::runtime_error(
  309. "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
  310. std::to_string(rule_id) + "," + std::to_string(i));
  311. }
  312. print_grammar_char(file, elem.value);
  313. break;
  314. case LLAMA_GRETYPE_CHAR_ANY:
  315. fprintf(file, ".");
  316. break;
  317. case LLAMA_GRETYPE_TOKEN:
  318. fprintf(file, "<[");
  319. fprintf(file, "%u", elem.value);
  320. fprintf(file, "]> ");
  321. break;
  322. case LLAMA_GRETYPE_TOKEN_NOT:
  323. fprintf(file, "!");
  324. fprintf(file, "<[");
  325. fprintf(file, "%u", elem.value);
  326. fprintf(file, "]> ");
  327. break;
  328. }
  329. if (is_char_element(elem)) {
  330. switch (rule[i + 1].type) {
  331. case LLAMA_GRETYPE_CHAR_ALT:
  332. case LLAMA_GRETYPE_CHAR_RNG_UPPER:
  333. case LLAMA_GRETYPE_CHAR_ANY:
  334. break;
  335. default:
  336. fprintf(file, "] ");
  337. }
  338. }
  339. }
  340. fprintf(file, "\n");
  341. }
  342. //
  343. // implementation
  344. //
  345. uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) {
  346. uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
  347. auto result = symbol_ids.emplace(std::string(src, len), next_id);
  348. return result.first->second;
  349. }
  350. uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) {
  351. uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
  352. symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
  353. return next_id;
  354. }
  355. void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) {
  356. if (rules.size() <= rule_id) {
  357. rules.resize(rule_id + 1);
  358. }
  359. rules[rule_id] = rule;
  360. }
  361. const char * llama_grammar_parser::parse_alternates(
  362. const char * src,
  363. const std::string & rule_name,
  364. uint32_t rule_id,
  365. bool is_nested) {
  366. llama_grammar_rule rule;
  367. const char * pos = parse_sequence(src, rule_name, rule, is_nested);
  368. while (*pos == '|') {
  369. rule.push_back({LLAMA_GRETYPE_ALT, 0});
  370. pos = parse_space(pos + 1, true);
  371. pos = parse_sequence(pos, rule_name, rule, is_nested);
  372. }
  373. rule.push_back({LLAMA_GRETYPE_END, 0});
  374. add_rule(rule_id, rule);
  375. return pos;
  376. }
  377. const char * llama_grammar_parser::parse_sequence(
  378. const char * src,
  379. const std::string & rule_name,
  380. llama_grammar_rule & rule,
  381. bool is_nested) {
  382. size_t last_sym_start = rule.size();
  383. const char * pos = src;
  384. // use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
  385. // (though it's technically the same as -1 now)
  386. auto handle_repetitions = [&](uint64_t min_times, uint64_t max_times) {
  387. bool no_max = max_times == UINT64_MAX;
  388. if (last_sym_start == rule.size()) {
  389. throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
  390. }
  391. // apply transformation to previous symbol (last_sym_start to end) according to
  392. // the following rewrite rules:
  393. // S{m,n} --> S S S (m times) S'(n-m)
  394. // S'(x) ::= S S'(x-1) |
  395. // (... n-m definitions of these S' rules ...)
  396. // S'(1) ::= S |
  397. // S{m,} --> S S S (m times) S'
  398. // S' ::= S S' |
  399. // S* --> S{0,}
  400. // --> S' ::= S S' |
  401. // S+ --> S{1,}
  402. // --> S S'
  403. // S' ::= S S' |
  404. // S? --> S{0,1}
  405. // --> S'
  406. // S' ::= S |
  407. llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
  408. if (min_times == 0) {
  409. rule.resize(last_sym_start);
  410. } else {
  411. // Repeat the previous elements (min_times - 1) times
  412. for (uint64_t i = 1; i < min_times; i++) {
  413. rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
  414. }
  415. }
  416. uint32_t last_rec_rule_id = 0;
  417. auto n_opt = no_max ? 1 : max_times - min_times;
  418. llama_grammar_rule rec_rule(prev_rule);
  419. for (uint64_t i = 0; i < n_opt; i++) {
  420. rec_rule.resize(prev_rule.size());
  421. uint32_t rec_rule_id = generate_symbol_id( rule_name);
  422. if (i > 0 || no_max) {
  423. rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, no_max ? rec_rule_id : last_rec_rule_id});
  424. }
  425. rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
  426. rec_rule.push_back({LLAMA_GRETYPE_END, 0});
  427. add_rule( rec_rule_id, rec_rule);
  428. last_rec_rule_id = rec_rule_id;
  429. }
  430. if (n_opt > 0) {
  431. rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
  432. }
  433. };
  434. while (*pos) {
  435. if (*pos == '"') { // literal string
  436. pos++;
  437. last_sym_start = rule.size();
  438. while (*pos != '"') {
  439. if (!*pos) {
  440. throw std::runtime_error("unexpected end of input");
  441. }
  442. auto char_pair = parse_char(pos);
  443. pos = char_pair.second;
  444. rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
  445. }
  446. pos = parse_space(pos + 1, is_nested);
  447. } else if (*pos == '[') { // char range(s)
  448. pos++;
  449. enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
  450. if (*pos == '^') {
  451. pos++;
  452. start_type = LLAMA_GRETYPE_CHAR_NOT;
  453. }
  454. last_sym_start = rule.size();
  455. while (*pos != ']') {
  456. if (!*pos) {
  457. throw std::runtime_error("unexpected end of input");
  458. }
  459. auto char_pair = parse_char(pos);
  460. pos = char_pair.second;
  461. enum llama_gretype type = last_sym_start < rule.size()
  462. ? LLAMA_GRETYPE_CHAR_ALT
  463. : start_type;
  464. rule.push_back({type, char_pair.first});
  465. if (pos[0] == '-' && pos[1] != ']') {
  466. if (!pos[1]) {
  467. throw std::runtime_error("unexpected end of input");
  468. }
  469. auto endchar_pair = parse_char(pos + 1);
  470. pos = endchar_pair.second;
  471. rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
  472. }
  473. }
  474. pos = parse_space(pos + 1, is_nested);
  475. } else if (*pos == '<' || *pos == '!') { // token
  476. auto type = LLAMA_GRETYPE_TOKEN;
  477. if (*pos == '!') { // token inverse
  478. type = LLAMA_GRETYPE_TOKEN_NOT;
  479. pos++;
  480. }
  481. auto token_pair = parse_token(vocab, pos);
  482. const char * token_end = token_pair.second;
  483. last_sym_start = rule.size();
  484. rule.push_back({type, token_pair.first});
  485. pos = parse_space(token_end, is_nested);
  486. } else if (is_word_char(*pos)) { // rule reference
  487. const char * name_end = parse_name(pos);
  488. uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
  489. pos = parse_space(name_end, is_nested);
  490. last_sym_start = rule.size();
  491. rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
  492. } else if (*pos == '(') { // grouping
  493. // parse nested alternates into synthesized rule
  494. pos = parse_space(pos + 1, true);
  495. uint32_t sub_rule_id = generate_symbol_id(rule_name);
  496. pos = parse_alternates(pos, rule_name, sub_rule_id, true);
  497. last_sym_start = rule.size();
  498. // output reference to synthesized rule
  499. rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
  500. if (*pos != ')') {
  501. throw std::runtime_error(std::string("expecting ')' at ") + pos);
  502. }
  503. pos = parse_space(pos + 1, is_nested);
  504. } else if (*pos == '.') { // any char
  505. last_sym_start = rule.size();
  506. rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
  507. pos = parse_space(pos + 1, is_nested);
  508. } else if (*pos == '*') {
  509. pos = parse_space(pos + 1, is_nested);
  510. handle_repetitions(0, -1);
  511. } else if (*pos == '+') {
  512. pos = parse_space(pos + 1, is_nested);
  513. handle_repetitions(1, -1);
  514. } else if (*pos == '?') {
  515. pos = parse_space(pos + 1, is_nested);
  516. handle_repetitions(0, 1);
  517. } else if (*pos == '{') {
  518. pos = parse_space(pos + 1, is_nested);
  519. if (!is_digit_char(*pos)) {
  520. throw std::runtime_error(std::string("expecting an int at ") + pos);
  521. }
  522. const char * int_end = parse_int(pos);
  523. uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
  524. pos = parse_space(int_end, is_nested);
  525. uint64_t max_times = UINT64_MAX; // default: no max limit
  526. if (*pos == '}') {
  527. max_times = min_times;
  528. pos = parse_space(pos + 1, is_nested);
  529. } else if (*pos == ',') {
  530. pos = parse_space(pos + 1, is_nested);
  531. if (is_digit_char(*pos)) {
  532. const char * int_end = parse_int(pos);
  533. max_times = std::stoul(std::string(pos, int_end - pos));
  534. pos = parse_space(int_end, is_nested);
  535. }
  536. if (*pos != '}') {
  537. throw std::runtime_error(std::string("expecting '}' at ") + pos);
  538. }
  539. pos = parse_space(pos + 1, is_nested);
  540. } else {
  541. throw std::runtime_error(std::string("expecting ',' at ") + pos);
  542. }
  543. bool has_max = max_times != UINT64_MAX;
  544. if (min_times > MAX_REPETITION_THRESHOLD || (has_max && max_times > MAX_REPETITION_THRESHOLD)) {
  545. throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
  546. }
  547. handle_repetitions(min_times, max_times);
  548. } else {
  549. break;
  550. }
  551. }
  552. return pos;
  553. }
  554. const char * llama_grammar_parser::parse_rule(const char * src) {
  555. const char * name_end = parse_name(src);
  556. const char * pos = parse_space(name_end, false);
  557. size_t name_len = name_end - src;
  558. uint32_t rule_id = get_symbol_id(src, name_len);
  559. const std::string name(src, name_len);
  560. if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
  561. throw std::runtime_error(std::string("expecting ::= at ") + pos);
  562. }
  563. pos = parse_space(pos + 3, true);
  564. pos = parse_alternates(pos, name, rule_id, false);
  565. if (*pos == '\r') {
  566. pos += pos[1] == '\n' ? 2 : 1;
  567. } else if (*pos == '\n') {
  568. pos++;
  569. } else if (*pos) {
  570. throw std::runtime_error(std::string("expecting newline or end at ") + pos);
  571. }
  572. return parse_space(pos, true);
  573. }
  574. bool llama_grammar_parser::parse(const char * src) {
  575. try {
  576. const char * pos = parse_space(src, true);
  577. while (*pos) {
  578. pos = parse_rule(pos);
  579. }
  580. // Validate the state to ensure that all rules are defined
  581. for (const auto & rule : rules) {
  582. if (rule.empty()) {
  583. throw std::runtime_error("Undefined rule");
  584. }
  585. for (const auto & elem : rule) {
  586. if (elem.type == LLAMA_GRETYPE_RULE_REF) {
  587. // Ensure that the rule at that location exists
  588. if (elem.value >= rules.size() || rules[elem.value].empty()) {
  589. // Get the name of the rule that is missing
  590. for (const auto & kv : symbol_ids) {
  591. if (kv.second == elem.value) {
  592. throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
  593. }
  594. }
  595. }
  596. }
  597. }
  598. }
  599. } catch (const std::exception & err) {
  600. fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src);
  601. rules.clear();
  602. return false;
  603. }
  604. return true;
  605. }
  606. void llama_grammar_parser::print(FILE * file) {
  607. try {
  608. std::map<uint32_t, std::string> symbol_id_names;
  609. for (const auto & kv : symbol_ids) {
  610. symbol_id_names[kv.second] = kv.first;
  611. }
  612. for (size_t i = 0, end = rules.size(); i < end; i++) {
  613. // fprintf(file, "%zu: ", i);
  614. // print_rule_binary(file, rules[i]);
  615. print_rule(file, uint32_t(i), rules[i], symbol_id_names);
  616. // fprintf(file, "\n");
  617. }
  618. } catch (const std::exception & err) {
  619. fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
  620. }
  621. }
  622. llama_grammar_stack llama_grammar_parser::c_rules() const {
  623. llama_grammar_stack ret;
  624. ret.reserve(rules.size());
  625. for (const auto & rule : rules) {
  626. ret.push_back(rule.data());
  627. }
  628. return ret;
  629. }
  630. // returns true iff pos points to the end of one of the definitions of a rule
  631. static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
  632. switch (pos->type) {
  633. case LLAMA_GRETYPE_END: return true; // NOLINT
  634. case LLAMA_GRETYPE_ALT: return true; // NOLINT
  635. default: return false;
  636. }
  637. }
  638. // returns true iff chr satisfies the char range at pos (regular or inverse range)
  639. // asserts that pos is pointing to a char range element
  640. static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
  641. const llama_grammar_element * pos,
  642. const uint32_t chr) {
  643. bool found = false;
  644. bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
  645. GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT
  646. do {
  647. if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
  648. // inclusive range, e.g. [a-z]
  649. found = found || (pos->value <= chr && chr <= pos[1].value);
  650. pos += 2;
  651. } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
  652. // Any character matches "."
  653. found = true;
  654. pos += 1;
  655. } else {
  656. // exact char match, e.g. [a] or "a"
  657. found = found || pos->value == chr;
  658. pos += 1;
  659. }
  660. } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
  661. return std::make_pair(found == is_positive_char, pos);
  662. }
  663. // returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
  664. // range at pos (regular or inverse range)
  665. // asserts that pos is pointing to a char range element
  // `pos` points at the first element of a char class: CHAR, CHAR_NOT or CHAR_ANY, optionally followed
  // by CHAR_ALT / CHAR_RNG_UPPER elements.
  // `partial_utf8` holds the payload bits decoded so far (`value`) and the number of continuation
  // bytes still expected (`n_remain`; negative means the sequence was invalid).
  // NOTE(review): for a negated class ([^...]) this returns false as soon as ANY possible completion
  // falls in the excluded set, even when other completions would be allowed. For [^...] it is
  // therefore stricter than the "iff" above: e.g. [^é] rejects a pending 0xC3 lead byte although
  // 0xC3 0xA0 ('à') would match. Confirm whether that is intended.
  666. static bool llama_grammar_match_partial_char(
  667. const llama_grammar_element * pos,
  668. const llama_partial_utf8 partial_utf8) {
  669. bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
  670. GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
  671. uint32_t partial_value = partial_utf8.value;
  672. int n_remain = partial_utf8.n_remain;
  673. // invalid sequence or 7-bit char split across 2 bytes (overlong)
  674. if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
  675. return false;
  676. }
  677. // range of possible code points this partial UTF-8 sequence could complete to
  678. uint32_t low = partial_value << (n_remain * 6);
  679. uint32_t high = low | ((1 << (n_remain * 6)) - 1);
  // A zero prefix would only be reachable through an overlong encoding. Raise `low` to the smallest
  // code point needing a 3-byte (n_remain == 2) or 4-byte (n_remain == 3) sequence.
  // NOTE(review): this treats n_remain == 2 as "3-byte sequence right after its lead byte"; a 4-byte
  // sequence with one continuation byte consumed also has n_remain == 2.
  680. if (low == 0) {
  681. if (n_remain == 2) {
  682. low = 1 << 11;
  683. } else if (n_remain == 3) {
  684. low = 1 << 16;
  685. }
  686. }
  // Walk the class and return at the first element that overlaps [low, high]: true for a positive
  // class, false for a negated one.
  687. do {
  688. if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
  689. // inclusive range, e.g. [a-z]
  690. if (pos->value <= high && low <= pos[1].value) {
  691. return is_positive_char;
  692. }
  693. pos += 2;
  694. } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
  695. // Any character matches "."
  696. return true;
  697. } else {
  698. // exact char match, e.g. [a] or "a"
  699. if (low <= pos->value && pos->value <= high) {
  700. return is_positive_char;
  701. }
  702. pos += 1;
  703. }
  704. } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
  // no element overlaps any possible completion: only a negated class can still be satisfied
  705. return !is_positive_char;
  706. }
  707. // returns true iff token matches the rule at pos (regular or inverse)
  708. // asserts that pos is pointing to a token element
  709. static bool llama_grammar_match_token(
  710. const llama_grammar_element * pos,
  711. const llama_token token) {
  712. GGML_ASSERT(pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT);
  713. if (pos->type == LLAMA_GRETYPE_TOKEN) {
  714. return pos->value == static_cast<uint32_t>(token);
  715. }
  716. if (pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
  717. return pos->value != static_cast<uint32_t>(token);
  718. }
  719. return false;
  720. }
  721. // transforms a grammar pushdown stack into N possible stacks, all ending
  722. // at a character range (terminal element)
  // `stack` holds pointers into `rules`: back() is the element to match next, and the entries below
  // it are continuations to resume afterwards.
  // Each resulting stack is appended to `new_stacks` unless an equal stack is already present. A
  // resulting stack either has a CHAR, CHAR_NOT, CHAR_ANY, TOKEN or TOKEN_NOT element on top, or is
  // empty (nothing left to match).
  // Duplicate detection is a linear std::find, so the cost grows with new_stacks.size().
  // NOTE(review): RULE_REF expansion recurses into the first element of every alternative, so a
  // left-recursive rule (e.g. `a ::= a "x"`) would recurse without bound. Presumably such grammars
  // are rejected before reaching this function; confirm.
  723. static void llama_grammar_advance_stack(
  724. const llama_grammar_rules & rules,
  725. const llama_grammar_stack & stack,
  726. llama_grammar_stacks & new_stacks) {
  // nothing left to match: keep the empty stack (once) as a valid end state
  727. if (stack.empty()) {
  728. if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
  729. new_stacks.emplace_back(stack);
  730. }
  731. return;
  732. }
  733. const llama_grammar_element * pos = stack.back();
  734. switch (pos->type) {
  // replace the rule reference by each alternative of the referenced rule in turn
  735. case LLAMA_GRETYPE_RULE_REF: {
  736. const size_t rule_id = static_cast<size_t>(pos->value);
  737. const llama_grammar_element * subpos = rules[rule_id].data();
  738. do {
  739. // init new stack without the top (pos)
  740. llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
  741. if (!llama_grammar_is_end_of_sequence(pos + 1)) {
  742. // if this rule ref is followed by another element, add that to stack
  743. new_stack.push_back(pos + 1);
  744. }
  // pushed after the continuation so the alternative ends up on top and is matched first
  745. if (!llama_grammar_is_end_of_sequence(subpos)) {
  746. // if alternate is nonempty, add to stack
  747. new_stack.push_back(subpos);
  748. }
  // the new top may itself be a RULE_REF, so keep expanding until a terminal is on top
  749. llama_grammar_advance_stack(rules, new_stack, new_stacks);
  750. while (!llama_grammar_is_end_of_sequence(subpos)) {
  751. // scan to end of alternate def
  752. subpos++;
  753. }
  754. if (subpos->type == LLAMA_GRETYPE_ALT) {
  755. // there's another alternate def of this rule to process
  756. subpos++;
  757. } else {
  758. break;
  759. }
  760. } while (true);
  761. break;
  762. }
  // a terminal is on top: this stack is ready to be matched against input
  763. case LLAMA_GRETYPE_CHAR:
  764. case LLAMA_GRETYPE_CHAR_NOT:
  765. case LLAMA_GRETYPE_CHAR_ANY:
  766. case LLAMA_GRETYPE_TOKEN:
  767. case LLAMA_GRETYPE_TOKEN_NOT:
  768. if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
  769. // only add the stack if it's not a duplicate of one we already have
  770. new_stacks.emplace_back(stack);
  771. }
  772. break;
  773. default:
  774. // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
  775. // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
  776. // those
  777. GGML_ABORT("fatal error");
  778. }
  779. }
  780. static llama_grammar_candidates llama_grammar_reject_candidates(
  781. const llama_grammar_rules & rules,
  782. const llama_grammar_stacks & stacks,
  783. const llama_grammar_candidates & candidates) {
  784. GGML_ASSERT(!stacks.empty()); // REVIEW
  785. if (candidates.empty()) {
  786. return {};
  787. }
  788. auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
  789. for (size_t i = 1, size = stacks.size(); i < size; ++i) {
  790. rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
  791. }
  792. return rejects;
  793. }
  794. static bool llama_grammar_detect_left_recursion(
  795. const llama_grammar_rules & rules,
  796. size_t rule_index,
  797. std::vector<bool> * rules_visited,
  798. std::vector<bool> * rules_in_progress,
  799. std::vector<bool> * rules_may_be_empty) {
  800. if ((*rules_in_progress)[rule_index]) {
  801. return true;
  802. }
  803. (*rules_in_progress)[rule_index] = true;
  804. const llama_grammar_rule & rule = rules[rule_index];
  805. // First check if the rule might produce the empty string. This could be done combined with the second
  806. // step but it's more readable as two steps.
  807. bool at_rule_start = true;
  808. for (size_t i = 0; i < rule.size(); i++) {
  809. if (llama_grammar_is_end_of_sequence(&rule[i])) {
  810. if (at_rule_start) {
  811. (*rules_may_be_empty)[rule_index] = true;
  812. break;
  813. }
  814. at_rule_start = true;
  815. } else {
  816. at_rule_start = false;
  817. }
  818. }
  819. // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
  820. // be empty)
  821. bool recurse_into_nonterminal = true;
  822. for (size_t i = 0; i < rule.size(); i++) {
  823. if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
  824. if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
  825. return true;
  826. }
  827. if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
  828. recurse_into_nonterminal = false;
  829. }
  830. } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
  831. recurse_into_nonterminal = true;
  832. } else {
  833. recurse_into_nonterminal = false;
  834. }
  835. }
  836. (*rules_in_progress)[rule_index] = false;
  837. (*rules_visited)[rule_index] = true;
  838. return false;
  839. }
  840. const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
  841. return grammar->rules;
  842. }
  843. llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
  844. return grammar->stacks;
  845. }
  846. static void llama_grammar_accept_chr(
  847. struct llama_grammar & grammar,
  848. const llama_grammar_stack & stack,
  849. uint32_t chr,
  850. llama_grammar_stacks & new_stacks) {
  851. if (stack.empty()) {
  852. return;
  853. }
  854. const llama_grammar_element * pos = stack.back();
  855. // ignore if this turns into a token
  856. if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
  857. return;
  858. }
  859. auto match = llama_grammar_match_char(pos, chr);
  860. if (match.first) {
  861. llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
  862. if (!llama_grammar_is_end_of_sequence(match.second)) {
  863. new_stack.push_back(match.second);
  864. }
  865. llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks);
  866. }
  867. }
  868. void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
  869. llama_grammar_stacks stacks_new;
  870. stacks_new.reserve(grammar->stacks.size());
  871. for (const auto & stack : grammar->stacks) {
  872. llama_grammar_accept_chr(*grammar, stack, chr, stacks_new);
  873. }
  874. grammar->stacks = std::move(stacks_new);
  875. }
  876. llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
  877. const llama_grammar_rules & rules,
  878. const llama_grammar_stack & stack,
  879. const llama_grammar_candidates & candidates) {
  880. llama_grammar_candidates rejects;
  881. rejects.reserve(candidates.size());
  882. if (stack.empty()) {
  883. for (const auto & tok : candidates) {
  884. if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
  885. rejects.push_back(tok);
  886. }
  887. }
  888. return rejects;
  889. }
  890. const llama_grammar_element * stack_pos = stack.back();
  891. // if the top of the stack is a token rule, then we only need to check the token id
  892. if (stack_pos->type == LLAMA_GRETYPE_TOKEN || stack_pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
  893. for (const auto & tok : candidates) {
  894. if (*tok.code_points == 0) {
  895. // reached the end of a token consumed by char rules, reject iff it ended
  896. // in a partial response
  897. if (tok.partial_utf8.n_remain != 0) {
  898. rejects.push_back(tok);
  899. }
  900. } else if (!llama_grammar_match_token(stack_pos, tok.id)) {
  901. rejects.push_back(tok);
  902. }
  903. }
  904. return rejects;
  905. }
  906. llama_grammar_candidates next_candidates;
  907. next_candidates.reserve(candidates.size());
  908. for (const auto & tok : candidates) {
  909. if (*tok.code_points == 0) {
  910. // reached end of full codepoints in token, reject iff it ended in a partial sequence
  911. // that cannot satisfy this position in grammar
  912. if (tok.partial_utf8.n_remain != 0 &&
  913. !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
  914. rejects.push_back(tok);
  915. }
  916. } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
  917. next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8, tok.id });
  918. } else {
  919. rejects.push_back(tok);
  920. }
  921. }
  922. const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;
  923. // update top of stack to next element, if any
  924. llama_grammar_stack stack_after(stack.begin(), stack.end() - 1);
  925. if (!llama_grammar_is_end_of_sequence(stack_pos_after)) {
  926. stack_after.push_back(stack_pos_after);
  927. }
  928. llama_grammar_stacks next_stacks;
  929. llama_grammar_advance_stack(rules, stack_after, next_stacks);
  930. auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
  931. for (const auto & tok : next_rejects) {
  932. rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8, tok.id });
  933. }
  934. return rejects;
  935. }
  936. ////////////////////
  937. struct llama_grammar * llama_grammar_init_impl(
  938. const struct llama_vocab * vocab,
  939. const llama_grammar_element ** rules,
  940. size_t n_rules,
  941. size_t start_rule_index) {
  942. const llama_grammar_element * pos;
  943. // copy rule definitions into vectors
  944. llama_grammar_rules vec_rules(n_rules);
  945. for (size_t i = 0; i < n_rules; i++) {
  946. for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
  947. vec_rules[i].push_back(*pos);
  948. }
  949. vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
  950. }
  951. // Check for left recursion
  952. std::vector<bool> rules_visited(n_rules);
  953. std::vector<bool> rules_in_progress(n_rules);
  954. std::vector<bool> rules_may_be_empty(n_rules);
  955. for (size_t i = 0; i < n_rules; i++) {
  956. if (rules_visited[i]) {
  957. continue;
  958. }
  959. if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
  960. LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
  961. return nullptr;
  962. }
  963. }
  964. // loop over alternates of start rule to build initial stacks
  965. llama_grammar_stacks stacks;
  966. pos = vec_rules[start_rule_index].data();
  967. do {
  968. llama_grammar_stack stack;
  969. if (!llama_grammar_is_end_of_sequence(pos)) {
  970. // if alternate is nonempty, add to stack
  971. stack.push_back(pos);
  972. }
  973. llama_grammar_advance_stack(vec_rules, stack, stacks);
  974. while (!llama_grammar_is_end_of_sequence(pos)) {
  975. // scan to end of alternate def
  976. pos++;
  977. }
  978. if (pos->type == LLAMA_GRETYPE_ALT) {
  979. // there's another alternate def of this rule to process
  980. pos++;
  981. } else {
  982. break;
  983. }
  984. } while (true);
  985. // Important: vec_rules has to be moved here, not copied, because stacks contains
  986. // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
  987. // then the pointers would be invalidated when the local vec_rules goes out of scope.
  988. return new llama_grammar {
  989. vocab,
  990. std::move(vec_rules),
  991. std::move(stacks),
  992. /* .partial_utf8 = */ {},
  993. /* .lazy = */ false,
  994. /* .awaiting_trigger = */ false,
  995. /* .trigger_buffer = */ "",
  996. /* .trigger_buffer_positions = */ {},
  997. /* .trigger_tokens = */ {},
  998. /* .trigger_patterns = */ {},
  999. };
  1000. }
  1001. struct llama_grammar * llama_grammar_init_impl(
  1002. const struct llama_vocab * vocab,
  1003. const char * grammar_str,
  1004. const char * grammar_root,
  1005. bool lazy,
  1006. const char ** trigger_patterns,
  1007. size_t num_trigger_patterns,
  1008. const llama_token * trigger_tokens,
  1009. size_t num_trigger_tokens) {
  1010. llama_grammar_parser parser(vocab);
  1011. // if there is a grammar, parse it
  1012. // rules will be empty (default) if there are parse errors
  1013. if (!parser.parse(grammar_str) || parser.rules.empty()) {
  1014. fprintf(stderr, "%s: failed to parse grammar\n", __func__);
  1015. return nullptr;
  1016. }
  1017. // Ensure that there is a "root" node.
  1018. if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) {
  1019. fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
  1020. return nullptr;
  1021. }
  1022. std::vector<const llama_grammar_element *> grammar_rules(parser.c_rules());
  1023. const size_t n_rules = grammar_rules.size();
  1024. const size_t start_rule_index = parser.symbol_ids.at(grammar_root);
  1025. const llama_grammar_element * pos;
  1026. // copy rule definitions into vectors
  1027. llama_grammar_rules vec_rules(n_rules);
  1028. for (size_t i = 0; i < n_rules; i++) {
  1029. for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
  1030. vec_rules[i].push_back(*pos);
  1031. }
  1032. vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
  1033. }
  1034. // Check for left recursion
  1035. std::vector<bool> rules_visited(n_rules);
  1036. std::vector<bool> rules_in_progress(n_rules);
  1037. std::vector<bool> rules_may_be_empty(n_rules);
  1038. for (size_t i = 0; i < n_rules; i++) {
  1039. if (rules_visited[i]) {
  1040. continue;
  1041. }
  1042. if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
  1043. LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
  1044. return nullptr;
  1045. }
  1046. }
  1047. // loop over alternates of start rule to build initial stacks
  1048. llama_grammar_stacks stacks;
  1049. pos = vec_rules[start_rule_index].data();
  1050. do {
  1051. llama_grammar_stack stack;
  1052. if (!llama_grammar_is_end_of_sequence(pos)) {
  1053. // if alternate is nonempty, add to stack
  1054. stack.push_back(pos);
  1055. }
  1056. llama_grammar_advance_stack(vec_rules, stack, stacks);
  1057. while (!llama_grammar_is_end_of_sequence(pos)) {
  1058. // scan to end of alternate def
  1059. pos++;
  1060. }
  1061. if (pos->type == LLAMA_GRETYPE_ALT) {
  1062. // there's another alternate def of this rule to process
  1063. pos++;
  1064. } else {
  1065. break;
  1066. }
  1067. } while (true);
  1068. std::vector<llama_token> vec_trigger_tokens;
  1069. std::vector<llama_grammar_trigger_pattern> vec_trigger_patterns;
  1070. for (size_t i = 0; i < num_trigger_tokens; i++) {
  1071. GGML_ASSERT(trigger_tokens != nullptr);
  1072. vec_trigger_tokens.push_back(trigger_tokens[i]);
  1073. }
  1074. for (size_t i = 0; i < num_trigger_patterns; i++) {
  1075. GGML_ASSERT(trigger_patterns != nullptr);
  1076. auto & trigger = vec_trigger_patterns.emplace_back();
  1077. trigger.pattern = trigger_patterns[i];
  1078. trigger.regex = std::regex(trigger.pattern);
  1079. }
  1080. // Important: vec_rules has to be moved here, not copied, because stacks contains
  1081. // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
  1082. // then the pointers would be invalidated when the local vec_rules goes out of scope.
  1083. return new llama_grammar {
  1084. vocab,
  1085. std::move(vec_rules),
  1086. std::move(stacks),
  1087. /* .partial_utf8 = */ {},
  1088. /* .lazy = */ lazy,
  1089. /* .awaiting_trigger = */ lazy,
  1090. /* .trigger_buffer = */ "",
  1091. /* .trigger_buffer_positions = */ {},
  1092. std::move(vec_trigger_tokens),
  1093. std::move(vec_trigger_patterns),
  1094. };
  1095. }
  // Destroys a grammar allocated with `new` by llama_grammar_init_impl / llama_grammar_clone_impl.
  // nullptr is accepted: deleting a null pointer does nothing.
  void llama_grammar_free_impl(struct llama_grammar * grammar) {
      delete grammar;
  }
  1102. struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
  1103. auto * result = new llama_grammar {
  1104. grammar.vocab,
  1105. grammar.rules,
  1106. grammar.stacks,
  1107. grammar.partial_utf8,
  1108. grammar.lazy,
  1109. grammar.awaiting_trigger,
  1110. grammar.trigger_buffer,
  1111. grammar.trigger_buffer_positions,
  1112. grammar.trigger_tokens,
  1113. grammar.trigger_patterns,
  1114. };
  1115. // redirect elements in stacks to point to new rules
  1116. for (size_t is = 0; is < result->stacks.size(); is++) {
  1117. for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
  1118. for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
  1119. for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
  1120. if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
  1121. result->stacks[is][ie] = &result->rules[ir0][ir1];
  1122. }
  1123. }
  1124. }
  1125. }
  1126. }
  1127. return result;
  1128. }
  1129. void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
  1130. GGML_ASSERT(grammar.vocab != nullptr);
  1131. if (grammar.awaiting_trigger) {
  1132. return;
  1133. }
  1134. bool allow_eog = false;
  1135. for (const auto & stack : grammar.stacks) {
  1136. if (stack.empty()) {
  1137. allow_eog = true;
  1138. break;
  1139. }
  1140. }
  1141. std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
  1142. candidates_decoded.reserve(cur_p->size);
  1143. llama_grammar_candidates candidates_grammar;
  1144. candidates_grammar.reserve(cur_p->size);
  1145. for (size_t i = 0; i < cur_p->size; ++i) {
  1146. const llama_token id = cur_p->data[i].id;
  1147. const std::string & piece = grammar.vocab->token_to_piece(id);
  1148. if (grammar.vocab->is_eog(id)) {
  1149. if (!allow_eog) {
  1150. cur_p->data[i].logit = -INFINITY;
  1151. }
  1152. } else if (piece.empty() || piece[0] == 0) {
  1153. cur_p->data[i].logit = -INFINITY;
  1154. } else {
  1155. candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
  1156. candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second, id });
  1157. }
  1158. }
  1159. const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
  1160. for (const auto & reject : rejects) {
  1161. cur_p->data[reject.index].logit = -INFINITY;
  1162. }
  1163. }
  1164. void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
  1165. GGML_ASSERT(grammar.vocab != nullptr);
  1166. const auto & piece = grammar.vocab->token_to_piece(token);
  1167. if (grammar.awaiting_trigger) {
  1168. if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
  1169. grammar.awaiting_trigger = false;
  1170. grammar.trigger_buffer.clear();
  1171. llama_grammar_accept_token(grammar, token, piece);
  1172. LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
  1173. return;
  1174. } else {
  1175. auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size());
  1176. grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
  1177. grammar.trigger_buffer += piece;
  1178. std::smatch match;
  1179. for (const auto & trigger_pattern : grammar.trigger_patterns) {
  1180. if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
  1181. grammar.awaiting_trigger = false;
  1182. // get from the first matched capturing group to the end of the string
  1183. size_t start = std::string::npos;
  1184. for (auto i = 1u; i < match.size(); i++) {
  1185. if (match.length(i) > 0) {
  1186. start = match.position(i);
  1187. break;
  1188. }
  1189. }
  1190. if (start == std::string::npos) {
  1191. start = match.position(0);
  1192. }
  1193. // replay tokens that overlap with [start, end)
  1194. for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
  1195. auto [tok_start, tok_end] = tok_pos;
  1196. if (tok_end <= start) {
  1197. continue;
  1198. }
  1199. size_t piece_start = (tok_start < start) ? start : tok_start; // allow for partial token pieces
  1200. size_t piece_len = tok_end - piece_start;
  1201. auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len);
  1202. llama_grammar_accept_token(grammar, tok, tok_piece);
  1203. }
  1204. auto constrained_str = grammar.trigger_buffer.substr(start);
  1205. grammar.trigger_buffer.clear();
  1206. grammar.trigger_buffer_positions.clear();
  1207. LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
  1208. return;
  1209. }
  1210. }
  1211. LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str());
  1212. return;
  1213. }
  1214. }
  1215. if (grammar.vocab->is_eog(token)) {
  1216. for (const auto & stack : grammar.stacks) {
  1217. if (stack.empty()) {
  1218. return;
  1219. }
  1220. }
  1221. GGML_ABORT("fatal error");
  1222. }
  1223. llama_grammar_accept_token(grammar, token, piece);
  1224. }
  1225. void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
  1226. // Note terminating 0 in decoded string
  1227. const auto decoded = decode_utf8(piece, grammar.partial_utf8);
  1228. const auto & code_points = decoded.first;
  1229. for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
  1230. llama_grammar_accept(&grammar, *it);
  1231. }
  1232. grammar.partial_utf8 = decoded.second;
  1233. if (grammar.stacks.empty()) {
  1234. throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
  1235. }
  1236. }
  1237. void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) {
  1238. // Note terminating 0 in decoded string
  1239. const auto decoded = decode_utf8(piece, grammar.partial_utf8);
  1240. const auto & code_points = decoded.first;
  1241. llama_grammar_stacks stacks_new;
  1242. stacks_new.reserve(grammar.stacks.size());
  1243. for (const auto & stack : grammar.stacks) {
  1244. if (stack.empty()) {
  1245. continue;
  1246. }
  1247. const llama_grammar_element * pos = stack.back();
  1248. if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
  1249. if (llama_grammar_match_token(pos, token)) {
  1250. llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
  1251. if (!llama_grammar_is_end_of_sequence(pos + 1)) {
  1252. new_stack.push_back(pos + 1);
  1253. }
  1254. llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new);
  1255. }
  1256. } else {
  1257. llama_grammar_stacks current_stacks = {stack};
  1258. for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
  1259. llama_grammar_stacks next_stacks;
  1260. for (const auto & cur_stack : current_stacks) {
  1261. llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks);
  1262. }
  1263. current_stacks = std::move(next_stacks);
  1264. if (current_stacks.empty()) {
  1265. break;
  1266. }
  1267. }
  1268. for (auto & surviving_stack : current_stacks) {
  1269. if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) {
  1270. stacks_new.emplace_back(surviving_stack);
  1271. }
  1272. }
  1273. }
  1274. }
  1275. grammar.stacks = std::move(stacks_new);
  1276. grammar.partial_utf8 = decoded.second;
  1277. if (grammar.stacks.empty()) {
  1278. throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")");
  1279. }
  1280. }