llama-grammar.cpp 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138
  1. #include "llama-grammar.h"
  2. #include "llama-vocab.h"
  3. #include "llama-sampling.h"
  4. #include <cmath>
  5. #include <algorithm>
  6. #include <stdexcept>
  7. //
  8. // helpers
  9. //
  // Decodes a single UTF-8 code point starting at `src`.
  // NOTE: assumes valid utf8 (but checks for overrun): continuation bytes are not validated, but
  // decoding stops early at a NUL terminator, so a truncated sequence is never read past it.
  // Returns the decoded value and a pointer just past the bytes consumed.
  static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
      // total sequence length, indexed by the high nibble of the lead byte
      static const int seq_len_by_nibble[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
      const uint8_t lead    = static_cast<uint8_t>(src[0]);
      const int     n_bytes = seq_len_by_nibble[lead >> 4];
      uint32_t code_point   = lead & static_cast<uint8_t>((1 << (8 - n_bytes)) - 1);
      const char * const stop = src + n_bytes; // may overrun! only compared against, never read past NUL
      const char * cur = src + 1;
      while (cur < stop && *cur != '\0') {
          code_point = (code_point << 6) + (static_cast<uint8_t>(*cur) & 0x3F);
          ++cur;
      }
      return { code_point, cur };
  }
  25. static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
  26. const std::string & src,
  27. llama_partial_utf8 partial_start) {
  28. static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
  29. const char * pos = src.c_str();
  30. std::vector<uint32_t> code_points;
  31. // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
  32. code_points.reserve(src.size() + 1);
  33. uint32_t value = partial_start.value;
  34. int n_remain = partial_start.n_remain;
  35. // continue previous decode, if applicable
  36. while (*pos != 0 && n_remain > 0) {
  37. uint8_t next_byte = static_cast<uint8_t>(*pos);
  38. if ((next_byte >> 6) != 2) {
  39. // invalid sequence, abort
  40. code_points.push_back(0);
  41. return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
  42. }
  43. value = (value << 6) + (next_byte & 0x3F);
  44. ++pos;
  45. --n_remain;
  46. }
  47. if (partial_start.n_remain > 0 && n_remain == 0) {
  48. code_points.push_back(value);
  49. }
  50. // decode any subsequent utf-8 sequences, which may end in an incomplete one
  51. while (*pos != 0) {
  52. uint8_t first_byte = static_cast<uint8_t>(*pos);
  53. uint8_t highbits = first_byte >> 4;
  54. n_remain = lookup[highbits] - 1;
  55. if (n_remain < 0) {
  56. // invalid sequence, abort
  57. code_points.clear();
  58. code_points.push_back(0);
  59. return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
  60. }
  61. uint8_t mask = (1 << (7 - n_remain)) - 1;
  62. value = first_byte & mask;
  63. ++pos;
  64. while (*pos != 0 && n_remain > 0) {
  65. value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
  66. ++pos;
  67. --n_remain;
  68. }
  69. if (n_remain == 0) {
  70. code_points.push_back(value);
  71. }
  72. }
  73. code_points.push_back(0);
  74. return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
  75. }
  // True for the ASCII decimal digits '0'..'9'.
  static bool is_digit_char(char c) {
      return !(c < '0' || c > '9');
  }
  79. static bool is_word_char(char c) {
  80. return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
  81. }
  // Parses exactly `size` hexadecimal digits (either case) starting at `src`.
  // Returns the value and a pointer just past the digits.
  // Throws std::runtime_error if a non-hex character or the end of input comes first.
  static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
      const char * const limit = src + size;
      const char * cur = src;
      uint32_t result = 0;
      while (cur < limit && *cur != '\0') {
          const char ch = *cur;
          uint32_t digit;
          if (ch >= '0' && ch <= '9') {
              digit = static_cast<uint32_t>(ch - '0');
          } else if (ch >= 'a' && ch <= 'f') {
              digit = static_cast<uint32_t>(ch - 'a' + 10);
          } else if (ch >= 'A' && ch <= 'F') {
              digit = static_cast<uint32_t>(ch - 'A' + 10);
          } else {
              break;
          }
          result = (result << 4) + digit;
          ++cur;
      }
      if (cur != limit) {
          throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
      }
      return std::make_pair(result, cur);
  }
  // Skips spaces, tabs and '#' comments (up to, not including, the line break) starting at `src`.
  // Line breaks ('\r', '\n') are skipped only when `newline_ok` is true.
  // Returns a pointer to the first character that is not skipped.
  static const char * parse_space(const char * src, bool newline_ok) {
      const char * p = src;
      for (;;) {
          const char ch = *p;
          if (ch == '#') {
              // comment: consume through end of line; the break itself is handled on the next pass
              while (*p != '\0' && *p != '\r' && *p != '\n') {
                  ++p;
              }
          } else if (ch == ' ' || ch == '\t' || (newline_ok && (ch == '\r' || ch == '\n'))) {
              ++p;
          } else {
              return p;
          }
      }
  }
  118. static const char * parse_name(const char * src) {
  119. const char * pos = src;
  120. while (is_word_char(*pos)) {
  121. pos++;
  122. }
  123. if (pos == src) {
  124. throw std::runtime_error(std::string("expecting name at ") + src);
  125. }
  126. return pos;
  127. }
  128. static const char * parse_int(const char * src) {
  129. const char * pos = src;
  130. while (is_digit_char(*pos)) {
  131. pos++;
  132. }
  133. if (pos == src) {
  134. throw std::runtime_error(std::string("expecting integer at ") + src);
  135. }
  136. return pos;
  137. }
  138. static std::pair<uint32_t, const char *> parse_char(const char * src) {
  139. if (*src == '\\') {
  140. switch (src[1]) {
  141. case 'x': return parse_hex(src + 2, 2);
  142. case 'u': return parse_hex(src + 2, 4);
  143. case 'U': return parse_hex(src + 2, 8);
  144. case 't': return std::make_pair('\t', src + 2);
  145. case 'r': return std::make_pair('\r', src + 2);
  146. case 'n': return std::make_pair('\n', src + 2);
  147. case '\\':
  148. case '"':
  149. case '[':
  150. case ']':
  151. return std::make_pair(src[1], src + 2);
  152. default:
  153. throw std::runtime_error(std::string("unknown escape at ") + src);
  154. }
  155. } else if (*src) {
  156. return decode_utf8(src);
  157. }
  158. throw std::runtime_error("unexpected end of input");
  159. }
  // Writes one grammar character for human-readable output. Printable ASCII (0x20-0x7E) is
  // written verbatim. Anything else (control characters, DEL 0x7F, and all non-ASCII code points)
  // is written as <U+XXXX>. Grammar metacharacters such as '"' or ']' are not escaped, so the
  // output is meant for inspection rather than for parsing back.
  static void print_grammar_char(FILE * file, uint32_t c) {
      // 0x7F (DEL) is a control character, so the printable range ends at 0x7E ('~')
      if (0x20 <= c && c <= 0x7e) {
          fprintf(file, "%c", static_cast<char>(c));
      } else {
          // cop out of encoding UTF-8; cast so the argument matches %X on every platform
          fprintf(file, "<U+%04X>", static_cast<unsigned int>(c));
      }
  }
  168. static bool is_char_element(llama_grammar_element elem) {
  169. switch (elem.type) {
  170. case LLAMA_GRETYPE_CHAR: return true;
  171. case LLAMA_GRETYPE_CHAR_NOT: return true;
  172. case LLAMA_GRETYPE_CHAR_ALT: return true;
  173. case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
  174. case LLAMA_GRETYPE_CHAR_ANY: return true;
  175. default: return false;
  176. }
  177. }
  178. static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
  179. for (auto elem : rule) {
  180. switch (elem.type) {
  181. case LLAMA_GRETYPE_END: fprintf(file, "END"); break;
  182. case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
  183. case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
  184. case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
  185. case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
  186. case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
  187. case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
  188. case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
  189. }
  190. switch (elem.type) {
  191. case LLAMA_GRETYPE_END:
  192. case LLAMA_GRETYPE_ALT:
  193. case LLAMA_GRETYPE_RULE_REF:
  194. fprintf(file, "(%u) ", elem.value);
  195. break;
  196. case LLAMA_GRETYPE_CHAR:
  197. case LLAMA_GRETYPE_CHAR_NOT:
  198. case LLAMA_GRETYPE_CHAR_RNG_UPPER:
  199. case LLAMA_GRETYPE_CHAR_ALT:
  200. case LLAMA_GRETYPE_CHAR_ANY:
  201. fprintf(file, "(\"");
  202. print_grammar_char(file, elem.value);
  203. fprintf(file, "\") ");
  204. break;
  205. }
  206. }
  207. fprintf(file, "\n");
  208. }
  209. static void print_rule(
  210. FILE * file,
  211. uint32_t rule_id,
  212. const llama_grammar_rule & rule,
  213. const std::map<uint32_t, std::string> & symbol_id_names) {
  214. if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
  215. throw std::runtime_error(
  216. "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
  217. }
  218. fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
  219. for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
  220. llama_grammar_element elem = rule[i];
  221. switch (elem.type) {
  222. case LLAMA_GRETYPE_END:
  223. throw std::runtime_error(
  224. "unexpected end of rule: " + std::to_string(rule_id) + "," +
  225. std::to_string(i));
  226. case LLAMA_GRETYPE_ALT:
  227. fprintf(file, "| ");
  228. break;
  229. case LLAMA_GRETYPE_RULE_REF:
  230. fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
  231. break;
  232. case LLAMA_GRETYPE_CHAR:
  233. fprintf(file, "[");
  234. print_grammar_char(file, elem.value);
  235. break;
  236. case LLAMA_GRETYPE_CHAR_NOT:
  237. fprintf(file, "[^");
  238. print_grammar_char(file, elem.value);
  239. break;
  240. case LLAMA_GRETYPE_CHAR_RNG_UPPER:
  241. if (i == 0 || !is_char_element(rule[i - 1])) {
  242. throw std::runtime_error(
  243. "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
  244. std::to_string(rule_id) + "," + std::to_string(i));
  245. }
  246. fprintf(file, "-");
  247. print_grammar_char(file, elem.value);
  248. break;
  249. case LLAMA_GRETYPE_CHAR_ALT:
  250. if (i == 0 || !is_char_element(rule[i - 1])) {
  251. throw std::runtime_error(
  252. "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
  253. std::to_string(rule_id) + "," + std::to_string(i));
  254. }
  255. print_grammar_char(file, elem.value);
  256. break;
  257. case LLAMA_GRETYPE_CHAR_ANY:
  258. fprintf(file, ".");
  259. break;
  260. }
  261. if (is_char_element(elem)) {
  262. switch (rule[i + 1].type) {
  263. case LLAMA_GRETYPE_CHAR_ALT:
  264. case LLAMA_GRETYPE_CHAR_RNG_UPPER:
  265. case LLAMA_GRETYPE_CHAR_ANY:
  266. break;
  267. default:
  268. fprintf(file, "] ");
  269. }
  270. }
  271. }
  272. fprintf(file, "\n");
  273. }
  274. //
  275. // implementation
  276. //
  277. uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) {
  278. uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
  279. auto result = symbol_ids.emplace(std::string(src, len), next_id);
  280. return result.first->second;
  281. }
  282. uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) {
  283. uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
  284. symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
  285. return next_id;
  286. }
  287. void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) {
  288. if (rules.size() <= rule_id) {
  289. rules.resize(rule_id + 1);
  290. }
  291. rules[rule_id] = rule;
  292. }
  293. const char * llama_grammar_parser::parse_alternates(
  294. const char * src,
  295. const std::string & rule_name,
  296. uint32_t rule_id,
  297. bool is_nested) {
  298. llama_grammar_rule rule;
  299. const char * pos = parse_sequence(src, rule_name, rule, is_nested);
  300. while (*pos == '|') {
  301. rule.push_back({LLAMA_GRETYPE_ALT, 0});
  302. pos = parse_space(pos + 1, true);
  303. pos = parse_sequence(pos, rule_name, rule, is_nested);
  304. }
  305. rule.push_back({LLAMA_GRETYPE_END, 0});
  306. add_rule(rule_id, rule);
  307. return pos;
  308. }
  309. const char * llama_grammar_parser::parse_sequence(
  310. const char * src,
  311. const std::string & rule_name,
  312. llama_grammar_rule & rule,
  313. bool is_nested) {
  314. size_t last_sym_start = rule.size();
  315. const char * pos = src;
  316. auto handle_repetitions = [&](int min_times, int max_times) {
  317. if (last_sym_start == rule.size()) {
  318. throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
  319. }
  320. // apply transformation to previous symbol (last_sym_start to end) according to
  321. // the following rewrite rules:
  322. // S{m,n} --> S S S (m times) S'(n-m)
  323. // S'(x) ::= S S'(x-1) |
  324. // (... n-m definitions of these S' rules ...)
  325. // S'(1) ::= S |
  326. // S{m,} --> S S S (m times) S'
  327. // S' ::= S S' |
  328. // S* --> S{0,}
  329. // --> S' ::= S S' |
  330. // S+ --> S{1,}
  331. // --> S S'
  332. // S' ::= S S' |
  333. // S? --> S{0,1}
  334. // --> S'
  335. // S' ::= S |
  336. llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
  337. if (min_times == 0) {
  338. rule.resize(last_sym_start);
  339. } else {
  340. // Repeat the previous elements (min_times - 1) times
  341. for (int i = 1; i < min_times; i++) {
  342. rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
  343. }
  344. }
  345. uint32_t last_rec_rule_id = 0;
  346. auto n_opt = max_times < 0 ? 1 : max_times - min_times;
  347. llama_grammar_rule rec_rule(prev_rule);
  348. for (int i = 0; i < n_opt; i++) {
  349. rec_rule.resize(prev_rule.size());
  350. uint32_t rec_rule_id = generate_symbol_id( rule_name);
  351. if (i > 0 || max_times < 0) {
  352. rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
  353. }
  354. rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
  355. rec_rule.push_back({LLAMA_GRETYPE_END, 0});
  356. add_rule( rec_rule_id, rec_rule);
  357. last_rec_rule_id = rec_rule_id;
  358. }
  359. if (n_opt > 0) {
  360. rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
  361. }
  362. };
  363. while (*pos) {
  364. if (*pos == '"') { // literal string
  365. pos++;
  366. last_sym_start = rule.size();
  367. while (*pos != '"') {
  368. if (!*pos) {
  369. throw std::runtime_error("unexpected end of input");
  370. }
  371. auto char_pair = parse_char(pos);
  372. pos = char_pair.second;
  373. rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
  374. }
  375. pos = parse_space(pos + 1, is_nested);
  376. } else if (*pos == '[') { // char range(s)
  377. pos++;
  378. enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
  379. if (*pos == '^') {
  380. pos++;
  381. start_type = LLAMA_GRETYPE_CHAR_NOT;
  382. }
  383. last_sym_start = rule.size();
  384. while (*pos != ']') {
  385. if (!*pos) {
  386. throw std::runtime_error("unexpected end of input");
  387. }
  388. auto char_pair = parse_char(pos);
  389. pos = char_pair.second;
  390. enum llama_gretype type = last_sym_start < rule.size()
  391. ? LLAMA_GRETYPE_CHAR_ALT
  392. : start_type;
  393. rule.push_back({type, char_pair.first});
  394. if (pos[0] == '-' && pos[1] != ']') {
  395. if (!pos[1]) {
  396. throw std::runtime_error("unexpected end of input");
  397. }
  398. auto endchar_pair = parse_char(pos + 1);
  399. pos = endchar_pair.second;
  400. rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
  401. }
  402. }
  403. pos = parse_space(pos + 1, is_nested);
  404. } else if (is_word_char(*pos)) { // rule reference
  405. const char * name_end = parse_name(pos);
  406. uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
  407. pos = parse_space(name_end, is_nested);
  408. last_sym_start = rule.size();
  409. rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
  410. } else if (*pos == '(') { // grouping
  411. // parse nested alternates into synthesized rule
  412. pos = parse_space(pos + 1, true);
  413. uint32_t sub_rule_id = generate_symbol_id(rule_name);
  414. pos = parse_alternates(pos, rule_name, sub_rule_id, true);
  415. last_sym_start = rule.size();
  416. // output reference to synthesized rule
  417. rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
  418. if (*pos != ')') {
  419. throw std::runtime_error(std::string("expecting ')' at ") + pos);
  420. }
  421. pos = parse_space(pos + 1, is_nested);
  422. } else if (*pos == '.') { // any char
  423. last_sym_start = rule.size();
  424. rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
  425. pos = parse_space(pos + 1, is_nested);
  426. } else if (*pos == '*') {
  427. pos = parse_space(pos + 1, is_nested);
  428. handle_repetitions(0, -1);
  429. } else if (*pos == '+') {
  430. pos = parse_space(pos + 1, is_nested);
  431. handle_repetitions(1, -1);
  432. } else if (*pos == '?') {
  433. pos = parse_space(pos + 1, is_nested);
  434. handle_repetitions(0, 1);
  435. } else if (*pos == '{') {
  436. pos = parse_space(pos + 1, is_nested);
  437. if (!is_digit_char(*pos)) {
  438. throw std::runtime_error(std::string("expecting an int at ") + pos);
  439. }
  440. const char * int_end = parse_int(pos);
  441. int min_times = std::stoul(std::string(pos, int_end - pos));
  442. pos = parse_space(int_end, is_nested);
  443. int max_times = -1;
  444. if (*pos == '}') {
  445. max_times = min_times;
  446. pos = parse_space(pos + 1, is_nested);
  447. } else if (*pos == ',') {
  448. pos = parse_space(pos + 1, is_nested);
  449. if (is_digit_char(*pos)) {
  450. const char * int_end = parse_int(pos);
  451. max_times = std::stoul(std::string(pos, int_end - pos));
  452. pos = parse_space(int_end, is_nested);
  453. }
  454. if (*pos != '}') {
  455. throw std::runtime_error(std::string("expecting '}' at ") + pos);
  456. }
  457. pos = parse_space(pos + 1, is_nested);
  458. } else {
  459. throw std::runtime_error(std::string("expecting ',' at ") + pos);
  460. }
  461. handle_repetitions(min_times, max_times);
  462. } else {
  463. break;
  464. }
  465. }
  466. return pos;
  467. }
  468. const char * llama_grammar_parser::parse_rule(const char * src) {
  469. const char * name_end = parse_name(src);
  470. const char * pos = parse_space(name_end, false);
  471. size_t name_len = name_end - src;
  472. uint32_t rule_id = get_symbol_id(src, name_len);
  473. const std::string name(src, name_len);
  474. if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
  475. throw std::runtime_error(std::string("expecting ::= at ") + pos);
  476. }
  477. pos = parse_space(pos + 3, true);
  478. pos = parse_alternates(pos, name, rule_id, false);
  479. if (*pos == '\r') {
  480. pos += pos[1] == '\n' ? 2 : 1;
  481. } else if (*pos == '\n') {
  482. pos++;
  483. } else if (*pos) {
  484. throw std::runtime_error(std::string("expecting newline or end at ") + pos);
  485. }
  486. return parse_space(pos, true);
  487. }
  488. bool llama_grammar_parser::parse(const char * src) {
  489. try {
  490. const char * pos = parse_space(src, true);
  491. while (*pos) {
  492. pos = parse_rule(pos);
  493. }
  494. // Validate the state to ensure that all rules are defined
  495. for (const auto & rule : rules) {
  496. if (rule.empty()) {
  497. throw std::runtime_error("Undefined rule");
  498. }
  499. for (const auto & elem : rule) {
  500. if (elem.type == LLAMA_GRETYPE_RULE_REF) {
  501. // Ensure that the rule at that location exists
  502. if (elem.value >= rules.size() || rules[elem.value].empty()) {
  503. // Get the name of the rule that is missing
  504. for (const auto & kv : symbol_ids) {
  505. if (kv.second == elem.value) {
  506. throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
  507. }
  508. }
  509. }
  510. }
  511. }
  512. }
  513. } catch (const std::exception & err) {
  514. fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
  515. rules.clear();
  516. return false;
  517. }
  518. return true;
  519. }
  520. void llama_grammar_parser::print(FILE * file) {
  521. try {
  522. std::map<uint32_t, std::string> symbol_id_names;
  523. for (const auto & kv : symbol_ids) {
  524. symbol_id_names[kv.second] = kv.first;
  525. }
  526. for (size_t i = 0, end = rules.size(); i < end; i++) {
  527. // fprintf(file, "%zu: ", i);
  528. // print_rule_binary(file, rules[i]);
  529. print_rule(file, uint32_t(i), rules[i], symbol_id_names);
  530. // fprintf(file, "\n");
  531. }
  532. } catch (const std::exception & err) {
  533. fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
  534. }
  535. }
  536. llama_grammar_stack llama_grammar_parser::c_rules() const {
  537. llama_grammar_stack ret;
  538. ret.reserve(rules.size());
  539. for (const auto & rule : rules) {
  540. ret.push_back(rule.data());
  541. }
  542. return ret;
  543. }
  544. // returns true iff pos points to the end of one of the definitions of a rule
  545. static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
  546. switch (pos->type) {
  547. case LLAMA_GRETYPE_END: return true; // NOLINT
  548. case LLAMA_GRETYPE_ALT: return true; // NOLINT
  549. default: return false;
  550. }
  551. }
  552. // returns true iff chr satisfies the char range at pos (regular or inverse range)
  553. // asserts that pos is pointing to a char range element
  554. static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
  555. const llama_grammar_element * pos,
  556. const uint32_t chr) {
  557. bool found = false;
  558. bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
  559. GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT
  560. do {
  561. if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
  562. // inclusive range, e.g. [a-z]
  563. found = found || (pos->value <= chr && chr <= pos[1].value);
  564. pos += 2;
  565. } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
  566. // Any character matches "."
  567. found = true;
  568. pos += 1;
  569. } else {
  570. // exact char match, e.g. [a] or "a"
  571. found = found || pos->value == chr;
  572. pos += 1;
  573. }
  574. } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
  575. return std::make_pair(found == is_positive_char, pos);
  576. }
  577. // returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
  578. // range at pos (regular or inverse range)
  579. // asserts that pos is pointing to a char range element
  580. static bool llama_grammar_match_partial_char(
  581. const llama_grammar_element * pos,
  582. const llama_partial_utf8 partial_utf8) {
  583. bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
  584. GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
  585. uint32_t partial_value = partial_utf8.value;
  586. int n_remain = partial_utf8.n_remain;
  587. // invalid sequence or 7-bit char split across 2 bytes (overlong)
  588. if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
  589. return false;
  590. }
  591. // range of possible code points this partial UTF-8 sequence could complete to
  592. uint32_t low = partial_value << (n_remain * 6);
  593. uint32_t high = low | ((1 << (n_remain * 6)) - 1);
  594. if (low == 0) {
  595. if (n_remain == 2) {
  596. low = 1 << 11;
  597. } else if (n_remain == 3) {
  598. low = 1 << 16;
  599. }
  600. }
  601. do {
  602. if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
  603. // inclusive range, e.g. [a-z]
  604. if (pos->value <= high && low <= pos[1].value) {
  605. return is_positive_char;
  606. }
  607. pos += 2;
  608. } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
  609. // Any character matches "."
  610. return true;
  611. } else {
  612. // exact char match, e.g. [a] or "a"
  613. if (low <= pos->value && pos->value <= high) {
  614. return is_positive_char;
  615. }
  616. pos += 1;
  617. }
  618. } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
  619. return !is_positive_char;
  620. }
  621. // transforms a grammar pushdown stack into N possible stacks, all ending
  622. // at a character range (terminal element)
  623. static void llama_grammar_advance_stack(
  624. const llama_grammar_rules & rules,
  625. const llama_grammar_stack & stack,
  626. llama_grammar_stacks & new_stacks) {
  627. if (stack.empty()) {
  628. if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
  629. new_stacks.emplace_back(stack);
  630. }
  631. return;
  632. }
  633. const llama_grammar_element * pos = stack.back();
  634. switch (pos->type) {
  635. case LLAMA_GRETYPE_RULE_REF: {
  636. const size_t rule_id = static_cast<size_t>(pos->value);
  637. const llama_grammar_element * subpos = rules[rule_id].data();
  638. do {
  639. // init new stack without the top (pos)
  640. llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
  641. if (!llama_grammar_is_end_of_sequence(pos + 1)) {
  642. // if this rule ref is followed by another element, add that to stack
  643. new_stack.push_back(pos + 1);
  644. }
  645. if (!llama_grammar_is_end_of_sequence(subpos)) {
  646. // if alternate is nonempty, add to stack
  647. new_stack.push_back(subpos);
  648. }
  649. llama_grammar_advance_stack(rules, new_stack, new_stacks);
  650. while (!llama_grammar_is_end_of_sequence(subpos)) {
  651. // scan to end of alternate def
  652. subpos++;
  653. }
  654. if (subpos->type == LLAMA_GRETYPE_ALT) {
  655. // there's another alternate def of this rule to process
  656. subpos++;
  657. } else {
  658. break;
  659. }
  660. } while (true);
  661. break;
  662. }
  663. case LLAMA_GRETYPE_CHAR:
  664. case LLAMA_GRETYPE_CHAR_NOT:
  665. case LLAMA_GRETYPE_CHAR_ANY:
  666. if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
  667. // only add the stack if it's not a duplicate of one we already have
  668. new_stacks.emplace_back(stack);
  669. }
  670. break;
  671. default:
  672. // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
  673. // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
  674. // those
  675. GGML_ABORT("fatal error");
  676. }
  677. }
  678. static llama_grammar_candidates llama_grammar_reject_candidates(
  679. const llama_grammar_rules & rules,
  680. const llama_grammar_stacks & stacks,
  681. const llama_grammar_candidates & candidates) {
  682. GGML_ASSERT(!stacks.empty()); // REVIEW
  683. if (candidates.empty()) {
  684. return {};
  685. }
  686. auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
  687. for (size_t i = 1, size = stacks.size(); i < size; ++i) {
  688. rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
  689. }
  690. return rejects;
  691. }
  692. static bool llama_grammar_detect_left_recursion(
  693. const llama_grammar_rules & rules,
  694. size_t rule_index,
  695. std::vector<bool> * rules_visited,
  696. std::vector<bool> * rules_in_progress,
  697. std::vector<bool> * rules_may_be_empty) {
  698. if ((*rules_in_progress)[rule_index]) {
  699. return true;
  700. }
  701. (*rules_in_progress)[rule_index] = true;
  702. const llama_grammar_rule & rule = rules[rule_index];
  703. // First check if the rule might produce the empty string. This could be done combined with the second
  704. // step but it's more readable as two steps.
  705. bool at_rule_start = true;
  706. for (size_t i = 0; i < rule.size(); i++) {
  707. if (llama_grammar_is_end_of_sequence(&rule[i])) {
  708. if (at_rule_start) {
  709. (*rules_may_be_empty)[rule_index] = true;
  710. break;
  711. }
  712. at_rule_start = true;
  713. } else {
  714. at_rule_start = false;
  715. }
  716. }
  717. // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
  718. // be empty)
  719. bool recurse_into_nonterminal = true;
  720. for (size_t i = 0; i < rule.size(); i++) {
  721. if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
  722. if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
  723. return true;
  724. }
  725. if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
  726. recurse_into_nonterminal = false;
  727. }
  728. } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
  729. recurse_into_nonterminal = true;
  730. } else {
  731. recurse_into_nonterminal = false;
  732. }
  733. }
  734. (*rules_in_progress)[rule_index] = false;
  735. (*rules_visited)[rule_index] = true;
  736. return false;
  737. }
  738. const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
  739. return grammar->rules;
  740. }
  741. llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
  742. return grammar->stacks;
  743. }
  744. void llama_grammar_accept(
  745. const llama_grammar_rules & rules,
  746. const llama_grammar_stacks & stacks,
  747. const uint32_t chr,
  748. llama_grammar_stacks & stacks_new) {
  749. stacks_new.clear();
  750. stacks_new.reserve(stacks.size());
  751. for (const auto & stack : stacks) {
  752. if (stack.empty()) {
  753. continue;
  754. }
  755. auto match = llama_grammar_match_char(stack.back(), chr);
  756. if (match.first) {
  757. const llama_grammar_element * pos = match.second;
  758. // update top of stack to next element, if any
  759. llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
  760. if (!llama_grammar_is_end_of_sequence(pos)) {
  761. new_stack.push_back(pos);
  762. }
  763. llama_grammar_advance_stack(rules, new_stack, stacks_new);
  764. }
  765. }
  766. }
  767. llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
  768. const llama_grammar_rules & rules,
  769. const llama_grammar_stack & stack,
  770. const llama_grammar_candidates & candidates) {
  771. llama_grammar_candidates rejects;
  772. rejects.reserve(candidates.size());
  773. if (stack.empty()) {
  774. for (const auto & tok : candidates) {
  775. if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
  776. rejects.push_back(tok);
  777. }
  778. }
  779. return rejects;
  780. }
  781. const llama_grammar_element * stack_pos = stack.back();
  782. llama_grammar_candidates next_candidates;
  783. next_candidates.reserve(candidates.size());
  784. for (const auto & tok : candidates) {
  785. if (*tok.code_points == 0) {
  786. // reached end of full codepoints in token, reject iff it ended in a partial sequence
  787. // that cannot satisfy this position in grammar
  788. if (tok.partial_utf8.n_remain != 0 &&
  789. !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
  790. rejects.push_back(tok);
  791. }
  792. } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
  793. next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
  794. } else {
  795. rejects.push_back(tok);
  796. }
  797. }
  798. const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;
  799. // update top of stack to next element, if any
  800. llama_grammar_stack stack_after(stack.begin(), stack.end() - 1);
  801. if (!llama_grammar_is_end_of_sequence(stack_pos_after)) {
  802. stack_after.push_back(stack_pos_after);
  803. }
  804. llama_grammar_stacks next_stacks;
  805. llama_grammar_advance_stack(rules, stack_after, next_stacks);
  806. auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
  807. for (const auto & tok : next_rejects) {
  808. rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
  809. }
  810. return rejects;
  811. }
  812. ////////////////////
  813. struct llama_grammar * llama_grammar_init_impl(
  814. const struct llama_vocab * vocab,
  815. const llama_grammar_element ** rules,
  816. size_t n_rules,
  817. size_t start_rule_index) {
  818. const llama_grammar_element * pos;
  819. // copy rule definitions into vectors
  820. llama_grammar_rules vec_rules(n_rules);
  821. for (size_t i = 0; i < n_rules; i++) {
  822. for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
  823. vec_rules[i].push_back(*pos);
  824. }
  825. vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
  826. }
  827. // Check for left recursion
  828. std::vector<bool> rules_visited(n_rules);
  829. std::vector<bool> rules_in_progress(n_rules);
  830. std::vector<bool> rules_may_be_empty(n_rules);
  831. for (size_t i = 0; i < n_rules; i++) {
  832. if (rules_visited[i]) {
  833. continue;
  834. }
  835. if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
  836. LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
  837. return nullptr;
  838. }
  839. }
  840. // loop over alternates of start rule to build initial stacks
  841. llama_grammar_stacks stacks;
  842. pos = vec_rules[start_rule_index].data();
  843. do {
  844. llama_grammar_stack stack;
  845. if (!llama_grammar_is_end_of_sequence(pos)) {
  846. // if alternate is nonempty, add to stack
  847. stack.push_back(pos);
  848. }
  849. llama_grammar_advance_stack(vec_rules, stack, stacks);
  850. while (!llama_grammar_is_end_of_sequence(pos)) {
  851. // scan to end of alternate def
  852. pos++;
  853. }
  854. if (pos->type == LLAMA_GRETYPE_ALT) {
  855. // there's another alternate def of this rule to process
  856. pos++;
  857. } else {
  858. break;
  859. }
  860. } while (true);
  861. // Important: vec_rules has to be moved here, not copied, because stacks contains
  862. // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
  863. // then the pointers would be invalidated when the local vec_rules goes out of scope.
  864. return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
  865. }
  866. struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
  867. llama_grammar_parser parser;
  868. // if there is a grammar, parse it
  869. if (!parser.parse(grammar_str)) {
  870. return nullptr;
  871. }
  872. // will be empty (default) if there are parse errors
  873. if (parser.rules.empty()) {
  874. fprintf(stderr, "%s: failed to parse grammar\n", __func__);
  875. return nullptr;
  876. }
  877. // Ensure that there is a "root" node.
  878. if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) {
  879. fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
  880. return nullptr;
  881. }
  882. std::vector<const llama_grammar_element *> grammar_rules(parser.c_rules());
  883. const size_t n_rules = grammar_rules.size();
  884. const size_t start_rule_index = parser.symbol_ids.at(grammar_root);
  885. const llama_grammar_element * pos;
  886. // copy rule definitions into vectors
  887. llama_grammar_rules vec_rules(n_rules);
  888. for (size_t i = 0; i < n_rules; i++) {
  889. for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
  890. vec_rules[i].push_back(*pos);
  891. }
  892. vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
  893. }
  894. // Check for left recursion
  895. std::vector<bool> rules_visited(n_rules);
  896. std::vector<bool> rules_in_progress(n_rules);
  897. std::vector<bool> rules_may_be_empty(n_rules);
  898. for (size_t i = 0; i < n_rules; i++) {
  899. if (rules_visited[i]) {
  900. continue;
  901. }
  902. if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
  903. LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
  904. return nullptr;
  905. }
  906. }
  907. // loop over alternates of start rule to build initial stacks
  908. llama_grammar_stacks stacks;
  909. pos = vec_rules[start_rule_index].data();
  910. do {
  911. llama_grammar_stack stack;
  912. if (!llama_grammar_is_end_of_sequence(pos)) {
  913. // if alternate is nonempty, add to stack
  914. stack.push_back(pos);
  915. }
  916. llama_grammar_advance_stack(vec_rules, stack, stacks);
  917. while (!llama_grammar_is_end_of_sequence(pos)) {
  918. // scan to end of alternate def
  919. pos++;
  920. }
  921. if (pos->type == LLAMA_GRETYPE_ALT) {
  922. // there's another alternate def of this rule to process
  923. pos++;
  924. } else {
  925. break;
  926. }
  927. } while (true);
  928. // Important: vec_rules has to be moved here, not copied, because stacks contains
  929. // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
  930. // then the pointers would be invalidated when the local vec_rules goes out of scope.
  931. return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
  932. }
  933. void llama_grammar_free_impl(struct llama_grammar * grammar) {
  934. if (grammar == nullptr) {
  935. return;
  936. }
  937. delete grammar;
  938. }
  939. struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
  940. llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, };
  941. // redirect elements in stacks to point to new rules
  942. for (size_t is = 0; is < result->stacks.size(); is++) {
  943. for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
  944. for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
  945. for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
  946. if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
  947. result->stacks[is][ie] = &result->rules[ir0][ir1];
  948. }
  949. }
  950. }
  951. }
  952. }
  953. return result;
  954. }
  955. void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
  956. GGML_ASSERT(grammar.vocab != nullptr);
  957. bool allow_eog = false;
  958. for (const auto & stack : grammar.stacks) {
  959. if (stack.empty()) {
  960. allow_eog = true;
  961. break;
  962. }
  963. }
  964. std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
  965. candidates_decoded.reserve(cur_p->size);
  966. llama_grammar_candidates candidates_grammar;
  967. candidates_grammar.reserve(cur_p->size);
  968. for (size_t i = 0; i < cur_p->size; ++i) {
  969. const llama_token id = cur_p->data[i].id;
  970. const std::string & piece = grammar.vocab->cache_token_to_piece.at(id);
  971. if (llama_token_is_eog_impl(*grammar.vocab, id)) {
  972. if (!allow_eog) {
  973. cur_p->data[i].logit = -INFINITY;
  974. }
  975. } else if (piece.empty() || piece[0] == 0) {
  976. cur_p->data[i].logit = -INFINITY;
  977. } else {
  978. candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
  979. candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
  980. }
  981. }
  982. const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
  983. for (const auto & reject : rejects) {
  984. cur_p->data[reject.index].logit = -INFINITY;
  985. }
  986. }
  987. void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
  988. GGML_ASSERT(grammar.vocab != nullptr);
  989. if (llama_token_is_eog_impl(*grammar.vocab, token)) {
  990. for (const auto & stack : grammar.stacks) {
  991. if (stack.empty()) {
  992. return;
  993. }
  994. }
  995. GGML_ABORT("fatal error");
  996. }
  997. const std::string & piece = grammar.vocab->cache_token_to_piece.at(token);
  998. // Note terminating 0 in decoded string
  999. const auto decoded = decode_utf8(piece, grammar.partial_utf8);
  1000. const auto & code_points = decoded.first;
  1001. llama_grammar_stacks stacks_new;
  1002. for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
  1003. llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
  1004. grammar.stacks = std::move(stacks_new);
  1005. }
  1006. grammar.partial_utf8 = decoded.second;
  1007. GGML_ASSERT(!grammar.stacks.empty());
  1008. }