llama-grammar.h 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. #pragma once
  2. #include "llama-impl.h"
  3. #include <map>
  4. struct llama_vocab;
  5. // grammar element type
  6. enum llama_gretype {
  7. // end of rule definition
  8. LLAMA_GRETYPE_END = 0,
  9. // start of alternate definition for rule
  10. LLAMA_GRETYPE_ALT = 1,
  11. // non-terminal element: reference to rule
  12. LLAMA_GRETYPE_RULE_REF = 2,
  13. // terminal element: character (code point)
  14. LLAMA_GRETYPE_CHAR = 3,
  15. // inverse char(s) ([^a], [^a-b] [^abc])
  16. LLAMA_GRETYPE_CHAR_NOT = 4,
  17. // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
  18. // be an inclusive range ([a-z])
  19. LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
  20. // modifies a preceding LLAMA_GRETYPE_CHAR or
  21. // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
  22. LLAMA_GRETYPE_CHAR_ALT = 6,
  23. // any character (.)
  24. LLAMA_GRETYPE_CHAR_ANY = 7,
  25. };
  26. typedef struct llama_grammar_element {
  27. enum llama_gretype type;
  28. uint32_t value; // Unicode code point or rule ID
  29. } llama_grammar_element;
  30. struct llama_partial_utf8 {
  31. uint32_t value; // bit value so far (unshifted)
  32. int n_remain; // num bytes remaining; -1 indicates invalid sequence
  33. };
  34. struct llama_grammar_candidate {
  35. size_t index;
  36. const uint32_t * code_points;
  37. llama_partial_utf8 partial_utf8;
  38. };
  39. using llama_grammar_rule = std::vector< llama_grammar_element>;
  40. using llama_grammar_stack = std::vector<const llama_grammar_element *>;
  41. using llama_grammar_rules = std::vector<llama_grammar_rule>;
  42. using llama_grammar_stacks = std::vector<llama_grammar_stack>;
  43. using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
  44. const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
  45. llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
  46. // takes a set of possible pushdown stacks on a grammar, which are required to
  47. // be positioned at a character range (see `llama_grammar_advance_stack`), and
  48. // produces the N possible stacks if the given char is accepted at those
  49. // positions
  50. void llama_grammar_accept(
  51. const llama_grammar_rules & rules,
  52. const llama_grammar_stacks & stacks,
  53. uint32_t chr,
  54. llama_grammar_stacks & stacks_new);
  55. std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
  56. const llama_grammar_rules & rules,
  57. const llama_grammar_stack & stack,
  58. const llama_grammar_candidates & candidates);
  59. struct llama_grammar_parser {
  60. std::map<std::string, uint32_t> symbol_ids;
  61. llama_grammar_rules rules;
  62. llama_grammar_stack c_rules() const;
  63. uint32_t get_symbol_id(const char * src, size_t len);
  64. uint32_t generate_symbol_id(const std::string & base_name);
  65. void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);
  66. const char * parse_alternates(
  67. const char * src,
  68. const std::string & rule_name,
  69. uint32_t rule_id,
  70. bool is_nested);
  71. const char * parse_sequence(
  72. const char * src,
  73. const std::string & rule_name,
  74. llama_grammar_rule & rule,
  75. bool is_nested);
  76. const char * parse_rule(const char * src);
  77. bool parse(const char * src);
  78. void print(FILE * file);
  79. };
  80. struct llama_grammar {
  81. // note: allow null vocab for testing (not great)
  82. const llama_vocab * vocab;
  83. const llama_grammar_rules rules; // TODO: shared ptr
  84. llama_grammar_stacks stacks;
  85. // buffer for partially generated UTF-8 sequence from accepted tokens
  86. llama_partial_utf8 partial_utf8;
  87. };
  88. //
  89. // internal API
  90. //
  91. // note: needed for tests (not great)
  92. struct llama_grammar * llama_grammar_init_impl(
  93. const struct llama_vocab * vocab,
  94. const llama_grammar_element ** rules,
  95. size_t n_rules,
  96. size_t start_rule_index);
  97. struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);
  98. void llama_grammar_free_impl(struct llama_grammar * grammar);
  99. struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar);
  100. // TODO: move the API below as member functions of llama_grammar
  101. void llama_grammar_apply_impl(
  102. const struct llama_grammar & grammar,
  103. llama_token_data_array * cur_p);
  104. void llama_grammar_accept_impl(
  105. struct llama_grammar & grammar,
  106. llama_token token);