// llama-grammar.h
#pragma once

#include "llama.h"

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <map>
#include <string>
#include <vector>
  6. struct llama_vocab;
  7. // grammar element type
  8. enum llama_gretype {
  9. // end of rule definition
  10. LLAMA_GRETYPE_END = 0,
  11. // start of alternate definition for rule
  12. LLAMA_GRETYPE_ALT = 1,
  13. // non-terminal element: reference to rule
  14. LLAMA_GRETYPE_RULE_REF = 2,
  15. // terminal element: character (code point)
  16. LLAMA_GRETYPE_CHAR = 3,
  17. // inverse char(s) ([^a], [^a-b] [^abc])
  18. LLAMA_GRETYPE_CHAR_NOT = 4,
  19. // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
  20. // be an inclusive range ([a-z])
  21. LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
  22. // modifies a preceding LLAMA_GRETYPE_CHAR or
  23. // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
  24. LLAMA_GRETYPE_CHAR_ALT = 6,
  25. // any character (.)
  26. LLAMA_GRETYPE_CHAR_ANY = 7,
  27. };
  28. typedef struct llama_grammar_element {
  29. enum llama_gretype type;
  30. uint32_t value; // Unicode code point or rule ID
  31. } llama_grammar_element;
  32. struct llama_partial_utf8 {
  33. uint32_t value; // bit value so far (unshifted)
  34. int n_remain; // num bytes remaining; -1 indicates invalid sequence
  35. };
  36. struct llama_grammar_candidate {
  37. size_t index;
  38. const uint32_t * code_points;
  39. llama_partial_utf8 partial_utf8;
  40. };
  41. using llama_grammar_rule = std::vector< llama_grammar_element>;
  42. using llama_grammar_stack = std::vector<const llama_grammar_element *>;
  43. using llama_grammar_rules = std::vector<llama_grammar_rule>;
  44. using llama_grammar_stacks = std::vector<llama_grammar_stack>;
  45. using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
  46. // TODO: remove, needed for tests atm
  47. const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
  48. llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
  49. // takes a set of possible pushdown stacks on a grammar, which are required to
  50. // be positioned at a character range (see `llama_grammar_advance_stack`), and
  51. // produces the N possible stacks if the given char is accepted at those
  52. // positions
  53. void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr);
  54. std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
  55. const llama_grammar_rules & rules,
  56. const llama_grammar_stack & stack,
  57. const llama_grammar_candidates & candidates);
  58. struct llama_grammar_parser {
  59. std::map<std::string, uint32_t> symbol_ids;
  60. llama_grammar_rules rules;
  61. llama_grammar_stack c_rules() const;
  62. uint32_t get_symbol_id(const char * src, size_t len);
  63. uint32_t generate_symbol_id(const std::string & base_name);
  64. void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);
  65. const char * parse_alternates(
  66. const char * src,
  67. const std::string & rule_name,
  68. uint32_t rule_id,
  69. bool is_nested);
  70. const char * parse_sequence(
  71. const char * src,
  72. const std::string & rule_name,
  73. llama_grammar_rule & rule,
  74. bool is_nested);
  75. const char * parse_rule(const char * src);
  76. bool parse(const char * src);
  77. void print(FILE * file);
  78. };
  79. struct llama_grammar {
  80. // note: allow null vocab for testing (not great)
  81. const llama_vocab * vocab;
  82. const llama_grammar_rules rules; // TODO: shared ptr
  83. llama_grammar_stacks stacks;
  84. // buffer for partially generated UTF-8 sequence from accepted tokens
  85. llama_partial_utf8 partial_utf8;
  86. // lazy grammars wait for trigger words or tokens before constraining the sampling.
  87. // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
  88. // (useful e.g. for tool_choice=required)
  89. bool lazy = false;
  90. bool awaiting_trigger = false; // Initialized to true for lazy grammars only
  91. std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
  92. std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
  93. std::vector<std::string> trigger_words;
  94. };
//
// internal API
//
  98. // note: needed for tests (not great)
  99. struct llama_grammar * llama_grammar_init_impl(
  100. const struct llama_vocab * vocab,
  101. const llama_grammar_element ** rules,
  102. size_t n_rules,
  103. size_t start_rule_index);
  104. struct llama_grammar * llama_grammar_init_impl(
  105. const struct llama_vocab * vocab,
  106. const char * grammar_str,
  107. const char * grammar_root,
  108. bool lazy,
  109. const char ** trigger_words,
  110. size_t num_trigger_words,
  111. const llama_token * trigger_tokens,
  112. size_t num_trigger_tokens);
  113. void llama_grammar_free_impl(struct llama_grammar * grammar);
  114. struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar);
  115. // TODO: move the API below as member functions of llama_grammar
  116. void llama_grammar_apply_impl(
  117. const struct llama_grammar & grammar,
  118. llama_token_data_array * cur_p);
  119. void llama_grammar_accept_impl(
  120. struct llama_grammar & grammar,
  121. llama_token token);
  122. void llama_grammar_accept_str(
  123. struct llama_grammar & grammar,
  124. const std::string & piece);