llguidance.cpp 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. #include "sampling.h"
  2. #include "log.h"
  3. #ifdef LLAMA_USE_LLGUIDANCE
  4. # include "llguidance.h"
  5. # include <cmath>
  6. struct llama_sampler_llg {
  7. const llama_vocab * vocab;
  8. std::string grammar_kind;
  9. std::string grammar_data;
  10. LlgTokenizer * tokenizer;
  11. LlgMatcher * grammar;
  12. };
  13. static LlgMatcher * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
  14. const char * grammar_data) {
  15. LlgConstraintInit cinit;
  16. llg_constraint_init_set_defaults(&cinit, tokenizer);
  17. const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
  18. if (log_level && *log_level) {
  19. cinit.log_stderr_level = atoi(log_level);
  20. }
  21. auto c = llg_new_matcher(&cinit, grammar_kind, grammar_data);
  22. if (llg_matcher_get_error(c)) {
  23. LOG_ERR("llg error: %s\n", llg_matcher_get_error(c));
  24. llg_free_matcher(c);
  25. return nullptr;
  26. }
  27. return c;
  28. }
  29. static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
  30. return "llguidance";
  31. }
  32. static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
  33. auto * ctx = (llama_sampler_llg *) smpl->ctx;
  34. if (ctx->grammar) {
  35. llg_matcher_consume_token(ctx->grammar, token);
  36. }
  37. }
  38. static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
  39. auto * ctx = (llama_sampler_llg *) smpl->ctx;
  40. if (ctx->grammar) {
  41. const uint32_t * mask = llg_matcher_get_mask(ctx->grammar);
  42. if (mask == nullptr) {
  43. if (llg_matcher_compute_mask(ctx->grammar) == 0) {
  44. mask = llg_matcher_get_mask(ctx->grammar);
  45. } else {
  46. LOG_ERR("llg error: %s\n", llg_matcher_get_error(ctx->grammar));
  47. llg_free_matcher(ctx->grammar);
  48. ctx->grammar = nullptr;
  49. return;
  50. }
  51. }
  52. for (size_t i = 0; i < cur_p->size; ++i) {
  53. auto token = cur_p->data[i].id;
  54. if ((mask[token / 32] & (1 << (token % 32))) == 0) {
  55. cur_p->data[i].logit = -INFINITY;
  56. }
  57. }
  58. }
  59. }
  60. static void llama_sampler_llg_reset(llama_sampler * smpl) {
  61. auto * ctx = (llama_sampler_llg *) smpl->ctx;
  62. if (ctx->grammar) {
  63. llg_matcher_reset(ctx->grammar);
  64. }
  65. }
  66. static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
  67. const auto * ctx = (const llama_sampler_llg *) smpl->ctx;
  68. auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);
  69. // copy the state
  70. {
  71. auto * result_ctx = (llama_sampler_llg *) result->ctx;
  72. if (ctx->grammar) {
  73. result_ctx->grammar_kind = ctx->grammar_kind;
  74. result_ctx->grammar_data = ctx->grammar_data;
  75. result_ctx->grammar = llg_clone_matcher(ctx->grammar);
  76. result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
  77. }
  78. }
  79. return result;
  80. }
  81. static void llama_sampler_llg_free(llama_sampler * smpl) {
  82. const auto * ctx = (llama_sampler_llg *) smpl->ctx;
  83. if (ctx->grammar) {
  84. llg_free_matcher(ctx->grammar);
  85. llg_free_tokenizer(ctx->tokenizer);
  86. }
  87. delete ctx;
  88. }
  89. static llama_sampler_i llama_sampler_llg_i = {
  90. /* .name = */ llama_sampler_llg_name,
  91. /* .accept = */ llama_sampler_llg_accept_impl,
  92. /* .apply = */ llama_sampler_llg_apply,
  93. /* .reset = */ llama_sampler_llg_reset,
  94. /* .clone = */ llama_sampler_llg_clone,
  95. /* .free = */ llama_sampler_llg_free,
  96. };
  97. static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
  98. uint32_t * output_tokens, size_t output_tokens_len) {
  99. const llama_vocab * vocab = (const llama_vocab *) user_data;
  100. int r = 0;
  101. try {
  102. r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
  103. true);
  104. } catch (const std::exception & e) {
  105. GGML_ABORT("llama_tokenize failed: %s\n", e.what());
  106. }
  107. if (r < 0) {
  108. return -r;
  109. }
  110. return r;
  111. }
  112. static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
  113. // TODO store the tokenizer in the vocab somehow
  114. static const llama_vocab * vocab_cache;
  115. static LlgTokenizer * tokenizer_cache;
  116. if (vocab_cache == vocab) {
  117. return llg_clone_tokenizer(tokenizer_cache);
  118. }
  119. auto tok_eos = llama_vocab_eot(vocab);
  120. if (tok_eos == LLAMA_TOKEN_NULL) {
  121. tok_eos = llama_vocab_eos(vocab);
  122. }
  123. size_t vocab_size = llama_vocab_n_tokens(vocab);
  124. auto token_lens = new uint32_t[vocab_size];
  125. // we typically have ~7 bytes per token; let's go on the safe side here
  126. auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
  127. auto token_bytes = new uint8_t[token_bytes_size];
  128. size_t offset = 0;
  129. for (size_t i = 0; i < vocab_size; i++) {
  130. size_t max_token = 1024;
  131. if (token_bytes_size - offset < max_token) {
  132. GGML_ABORT("token_bytes buffer too small\n");
  133. }
  134. llama_token token = i;
  135. auto dp = (char *) token_bytes + offset;
  136. auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
  137. if (size < 0) {
  138. GGML_ABORT("llama_detokenize failed\n");
  139. }
  140. if (size == 0) {
  141. size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
  142. if (size < 0) {
  143. GGML_ABORT("llama_detokenize failed\n");
  144. }
  145. if (size != 0) {
  146. *dp = '\xff'; // special token prefix marker
  147. size += 1;
  148. }
  149. }
  150. token_lens[i] = size;
  151. offset += size;
  152. }
  153. LlgTokenizerInit tinit = {
  154. /* .vocab_size = */ (uint32_t) vocab_size,
  155. /* .tok_eos = */ (uint32_t) tok_eos,
  156. /* .token_lens = */ token_lens,
  157. /* .token_bytes = */ token_bytes,
  158. /* .tokenizer_json = */ nullptr,
  159. /* .tokenize_assumes_string = */ true,
  160. /* .tokenize_fn = */ llama_sampler_llg_tokenize_fn,
  161. /* .use_approximate_greedy_tokenize_fn = */ false,
  162. /* .tokenize_user_data = */ vocab,
  163. };
  164. char error_buffer[1024];
  165. LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));
  166. delete[] token_bytes;
  167. delete[] token_lens;
  168. if (tokenizer == nullptr) {
  169. LOG_ERR("llg tokenizer error: %s\n", error_buffer);
  170. return tokenizer;
  171. }
  172. if (tokenizer_cache) {
  173. llg_free_tokenizer(tokenizer_cache);
  174. }
  175. vocab_cache = vocab;
  176. tokenizer_cache = tokenizer;
  177. return llg_clone_tokenizer(tokenizer_cache);
  178. }
  179. llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
  180. const char * grammar_data) {
  181. auto * ctx = new llama_sampler_llg;
  182. if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
  183. auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
  184. *ctx = {
  185. /* .vocab = */ vocab,
  186. /* .grammar_kind = */ grammar_kind,
  187. /* .grammar_data = */ grammar_data,
  188. /* .tokenizer = */ tokenizer,
  189. /* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
  190. };
  191. if (ctx->grammar) {
  192. GGML_ASSERT(((size_t) llama_vocab_n_tokens(vocab) + 31) / 32 * 4 ==
  193. llg_matcher_get_mask_byte_size(ctx->grammar));
  194. }
  195. } else {
  196. *ctx = {
  197. /* .vocab = */ vocab,
  198. /* .grammar_kind = */ {},
  199. /* .grammar_data = */ {},
  200. /* .tokenizer = */ nullptr,
  201. /* .grammar = */ nullptr,
  202. };
  203. }
  204. return llama_sampler_init(
  205. /* .iface = */ &llama_sampler_llg_i,
  206. /* .ctx = */ ctx);
  207. }
  208. #else
  209. llama_sampler * llama_sampler_init_llg(const llama_vocab *, const char *, const char *) {
  210. LOG_WRN("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
  211. return nullptr;
  212. }
  213. #endif // LLAMA_USE_LLGUIDANCE