llama-grammar.h 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. #pragma once
  2. #include "llama.h"
  3. #include <map>
  4. #include <regex>
  5. #include <string>
  6. #include <vector>
  // Forward declaration only: this header refers to the vocabulary type solely
  // through `const llama_vocab *`, so the full definition is not required here.
  struct llama_vocab;
  // Kinds of element that may appear inside a grammar rule.
  // Every enumerator is given an explicit numeric value.
  enum llama_gretype {
      LLAMA_GRETYPE_END            = 0, // end of rule definition
      LLAMA_GRETYPE_ALT            = 1, // start of alternate definition for rule
      LLAMA_GRETYPE_RULE_REF       = 2, // non-terminal element: reference to rule
      LLAMA_GRETYPE_CHAR           = 3, // terminal element: character (code point)
      LLAMA_GRETYPE_CHAR_NOT       = 4, // inverse char(s): [^a], [^a-b], [^abc]

      // turns a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT into an
      // inclusive range ([a-z])
      LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,

      // adds an alternate char to match after a preceding LLAMA_GRETYPE_CHAR or
      // LLAMA_GRETYPE_CHAR_RNG_UPPER ([ab], [a-zA])
      LLAMA_GRETYPE_CHAR_ALT       = 6,

      LLAMA_GRETYPE_CHAR_ANY       = 7, // any character (.)
  };
  29. typedef struct llama_grammar_element {
  30. enum llama_gretype type;
  31. uint32_t value; // Unicode code point or rule ID
  32. } llama_grammar_element;
  // State of a partially decoded UTF-8 sequence.
  //
  // Both members default to zero so that a default-constructed object never
  // carries indeterminate values. The struct is still an aggregate (C++14 and
  // later), so `llama_partial_utf8{ value, n_remain }` and `{}` keep working.
  struct llama_partial_utf8 {
      uint32_t value    = 0; // bit value so far (unshifted)
      int      n_remain = 0; // num bytes remaining; -1 indicates invalid sequence
  };
  37. struct llama_grammar_candidate {
  38. size_t index;
  39. const uint32_t * code_points;
  40. llama_partial_utf8 partial_utf8;
  41. };
  42. using llama_grammar_rule = std::vector< llama_grammar_element>;
  43. using llama_grammar_stack = std::vector<const llama_grammar_element *>;
  44. using llama_grammar_rules = std::vector<llama_grammar_rule>;
  45. using llama_grammar_stacks = std::vector<llama_grammar_stack>;
  46. using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
  47. // TODO: remove, needed for tests atm
  48. const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
  49. llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
  // Accepts the character `chr` on `grammar`.
  //
  // Operates on the set of possible pushdown stacks of the grammar; each stack
  // is required to be positioned at a character range (see
  // `llama_grammar_advance_stack`). Produces the N possible stacks that result
  // if the given char is accepted at those positions.
  void llama_grammar_accept(
          struct llama_grammar * grammar,
          uint32_t               chr);
  55. std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
  56. const llama_grammar_rules & rules,
  57. const llama_grammar_stack & stack,
  58. const llama_grammar_candidates & candidates);
  59. struct llama_grammar_parser {
  60. std::map<std::string, uint32_t> symbol_ids;
  61. llama_grammar_rules rules;
  62. llama_grammar_stack c_rules() const;
  63. uint32_t get_symbol_id(const char * src, size_t len);
  64. uint32_t generate_symbol_id(const std::string & base_name);
  65. void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);
  66. const char * parse_alternates(
  67. const char * src,
  68. const std::string & rule_name,
  69. uint32_t rule_id,
  70. bool is_nested);
  71. const char * parse_sequence(
  72. const char * src,
  73. const std::string & rule_name,
  74. llama_grammar_rule & rule,
  75. bool is_nested);
  76. const char * parse_rule(const char * src);
  77. bool parse(const char * src);
  78. void print(FILE * file);
  79. };
  // A trigger pattern: the pattern text stored next to a std::regex object.
  struct llama_grammar_trigger_pattern {
      std::string pattern; // pattern text
      std::regex  regex;   // regex object; presumably built from `pattern` — confirm where it is constructed
  };
  84. struct llama_grammar {
  85. // note: allow null vocab for testing (not great)
  86. const llama_vocab * vocab;
  87. const llama_grammar_rules rules; // TODO: shared ptr
  88. llama_grammar_stacks stacks;
  89. // buffer for partially generated UTF-8 sequence from accepted tokens
  90. llama_partial_utf8 partial_utf8;
  91. // lazy grammars wait for trigger words or tokens before constraining the sampling.
  92. // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
  93. // (useful e.g. for tool_choice=required)
  94. bool lazy = false;
  95. bool awaiting_trigger = false; // Initialized to true for lazy grammars only
  96. std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
  97. std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
  98. std::vector<llama_grammar_trigger_pattern>
  99. trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated
  100. // string, and the grammar will be given the string from the first match group onwards.
  101. };
  102. //
  103. // internal API
  104. //
  105. // note: needed for tests (not great)
  106. struct llama_grammar * llama_grammar_init_impl(
  107. const struct llama_vocab * vocab,
  108. const llama_grammar_element ** rules,
  109. size_t n_rules,
  110. size_t start_rule_index);
  111. struct llama_grammar * llama_grammar_init_impl(
  112. const struct llama_vocab * vocab,
  113. const char * grammar_str,
  114. const char * grammar_root,
  115. bool lazy,
  116. const char ** trigger_patterns,
  117. size_t num_trigger_patterns,
  118. const llama_token * trigger_tokens,
  119. size_t num_trigger_tokens);
  // Frees `grammar`.
  // NOTE(review): presumably the counterpart of `llama_grammar_init_impl` and
  // `llama_grammar_clone_impl`; whether a null pointer is accepted is not
  // visible here — confirm.
  void llama_grammar_free_impl(
          struct llama_grammar * grammar);
  // Clones `grammar` and returns a pointer to the new object.
  // NOTE(review): ownership of the returned pointer presumably passes to the
  // caller (to be released via `llama_grammar_free_impl`) — confirm.
  struct llama_grammar * llama_grammar_clone_impl(
          const struct llama_grammar & grammar);
  122. // TODO: move the API below as member functions of llama_grammar
  123. void llama_grammar_apply_impl(
  124. const struct llama_grammar & grammar,
  125. llama_token_data_array * cur_p);
  126. void llama_grammar_accept_impl(
  127. struct llama_grammar & grammar,
  128. llama_token token);
  // Accepts the text `piece` on `grammar`; the grammar is passed by mutable reference.
  void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece);