llama-grammar.h 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. #pragma once
  2. #include "llama-impl.h"
  3. #include <map>
  4. struct llama_vocab;
  5. // grammar element type
  6. enum llama_gretype {
  7. // end of rule definition
  8. LLAMA_GRETYPE_END = 0,
  9. // start of alternate definition for rule
  10. LLAMA_GRETYPE_ALT = 1,
  11. // non-terminal element: reference to rule
  12. LLAMA_GRETYPE_RULE_REF = 2,
  13. // terminal element: character (code point)
  14. LLAMA_GRETYPE_CHAR = 3,
  15. // inverse char(s) ([^a], [^a-b] [^abc])
  16. LLAMA_GRETYPE_CHAR_NOT = 4,
  17. // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
  18. // be an inclusive range ([a-z])
  19. LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
  20. // modifies a preceding LLAMA_GRETYPE_CHAR or
  21. // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
  22. LLAMA_GRETYPE_CHAR_ALT = 6,
  23. // any character (.)
  24. LLAMA_GRETYPE_CHAR_ANY = 7,
  25. };
  26. typedef struct llama_grammar_element {
  27. enum llama_gretype type;
  28. uint32_t value; // Unicode code point or rule ID
  29. } llama_grammar_element;
  30. struct llama_partial_utf8 {
  31. uint32_t value; // bit value so far (unshifted)
  32. int n_remain; // num bytes remaining; -1 indicates invalid sequence
  33. };
  34. struct llama_grammar_candidate {
  35. size_t index;
  36. const uint32_t * code_points;
  37. llama_partial_utf8 partial_utf8;
  38. };
  39. using llama_grammar_rule = std::vector< llama_grammar_element>;
  40. using llama_grammar_stack = std::vector<const llama_grammar_element *>;
  41. using llama_grammar_rules = std::vector<llama_grammar_rule>;
  42. using llama_grammar_stacks = std::vector<llama_grammar_stack>;
  43. using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
  44. // TODO: remove, needed for tests atm
  45. const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
  46. llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
  47. // takes a set of possible pushdown stacks on a grammar, which are required to
  48. // be positioned at a character range (see `llama_grammar_advance_stack`), and
  49. // produces the N possible stacks if the given char is accepted at those
  50. // positions
  51. void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr);
  52. std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
  53. const llama_grammar_rules & rules,
  54. const llama_grammar_stack & stack,
  55. const llama_grammar_candidates & candidates);
  56. struct llama_grammar_parser {
  57. std::map<std::string, uint32_t> symbol_ids;
  58. llama_grammar_rules rules;
  59. llama_grammar_stack c_rules() const;
  60. uint32_t get_symbol_id(const char * src, size_t len);
  61. uint32_t generate_symbol_id(const std::string & base_name);
  62. void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);
  63. const char * parse_alternates(
  64. const char * src,
  65. const std::string & rule_name,
  66. uint32_t rule_id,
  67. bool is_nested);
  68. const char * parse_sequence(
  69. const char * src,
  70. const std::string & rule_name,
  71. llama_grammar_rule & rule,
  72. bool is_nested);
  73. const char * parse_rule(const char * src);
  74. bool parse(const char * src);
  75. void print(FILE * file);
  76. };
  77. struct llama_grammar {
  78. // note: allow null vocab for testing (not great)
  79. const llama_vocab * vocab;
  80. const llama_grammar_rules rules; // TODO: shared ptr
  81. llama_grammar_stacks stacks;
  82. // buffer for partially generated UTF-8 sequence from accepted tokens
  83. llama_partial_utf8 partial_utf8;
  84. };
  85. //
  86. // internal API
  87. //
  88. // note: needed for tests (not great)
  89. struct llama_grammar * llama_grammar_init_impl(
  90. const struct llama_vocab * vocab,
  91. const llama_grammar_element ** rules,
  92. size_t n_rules,
  93. size_t start_rule_index);
  94. struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);
  95. void llama_grammar_free_impl(struct llama_grammar * grammar);
  96. struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar);
  97. // TODO: move the API below as member functions of llama_grammar
  98. void llama_grammar_apply_impl(
  99. const struct llama_grammar & grammar,
  100. llama_token_data_array * cur_p);
  101. void llama_grammar_accept_impl(
  102. struct llama_grammar & grammar,
  103. llama_token token);