llguidance.cpp

#include "sampling.h"
#include "log.h"

#ifdef LLAMA_USE_LLGUIDANCE

#    include "llguidance.h"
#    include <cmath>
#    include <cstdlib>  // getenv, atoi

struct llama_sampler_llg {
    const llama_vocab * vocab;
    std::string         grammar_kind;
    std::string         grammar_data;
    LlgTokenizer *      tokenizer;
    LlgMatcher *        grammar;
};

// Create an llguidance matcher for the given grammar; returns nullptr (and logs the error) on failure.
static LlgMatcher * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
                                          const char * grammar_data) {
    LlgConstraintInit cinit;
    llg_constraint_init_set_defaults(&cinit, tokenizer);
    const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
    if (log_level && *log_level) {
        cinit.log_stderr_level = atoi(log_level);
    }

    auto c = llg_new_matcher(&cinit, grammar_kind, grammar_data);
    if (llg_matcher_get_error(c)) {
        LOG_ERR("llg error: %s\n", llg_matcher_get_error(c));
        llg_free_matcher(c);
        return nullptr;
    }

    return c;
}

static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
    return "llguidance";
}

// Advance the matcher state by the token that was actually sampled.
static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (ctx->grammar) {
        llg_matcher_consume_token(ctx->grammar, token);
    }
}
// Set the logit of every candidate token the grammar does not allow next to -INFINITY.
static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (ctx->grammar) {
        const uint32_t * mask = llg_matcher_get_mask(ctx->grammar);
        if (mask == nullptr) {
            // the mask is computed lazily; if that fails, drop the grammar and sample unconstrained
            if (llg_matcher_compute_mask(ctx->grammar) == 0) {
                mask = llg_matcher_get_mask(ctx->grammar);
            } else {
                LOG_ERR("llg error: %s\n", llg_matcher_get_error(ctx->grammar));
                llg_free_matcher(ctx->grammar);
                ctx->grammar = nullptr;
                return;
            }
        }

        for (size_t i = 0; i < cur_p->size; ++i) {
            auto token = cur_p->data[i].id;
            if ((mask[token / 32] & (1 << (token % 32))) == 0) {
                cur_p->data[i].logit = -INFINITY;
            }
        }
    }
}
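
// Note on the mask layout (an illustrative sketch, not part of this file): the matcher
// exposes one bit per vocabulary entry, packed into uint32_t words, so token t is
// allowed iff bit (t % 32) of word mask[t / 32] is set. A standalone helper under
// that assumption would be:
//
//     static bool llg_mask_allows(const uint32_t * mask, llama_token t) {
//         return (mask[t / 32] & (1u << (t % 32))) != 0;
//     }
//
// which is exactly the test the loop above performs inline for each candidate.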
static void llama_sampler_llg_reset(llama_sampler * smpl) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (ctx->grammar) {
        llg_matcher_reset(ctx->grammar);
    }
}

static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_llg *) smpl->ctx;

    auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);

    // copy the state
    {
        auto * result_ctx = (llama_sampler_llg *) result->ctx;

        if (ctx->grammar) {
            result_ctx->grammar_kind = ctx->grammar_kind;
            result_ctx->grammar_data = ctx->grammar_data;

            result_ctx->grammar   = llg_clone_matcher(ctx->grammar);
            result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
        }
    }

    return result;
}

static void llama_sampler_llg_free(llama_sampler * smpl) {
    const auto * ctx = (llama_sampler_llg *) smpl->ctx;

    if (ctx->grammar) {
        llg_free_matcher(ctx->grammar);
        llg_free_tokenizer(ctx->tokenizer);
    }

    delete ctx;
}

static llama_sampler_i llama_sampler_llg_i = {
    /* .name              = */ llama_sampler_llg_name,
    /* .accept            = */ llama_sampler_llg_accept_impl,
    /* .apply             = */ llama_sampler_llg_apply,
    /* .reset             = */ llama_sampler_llg_reset,
    /* .clone             = */ llama_sampler_llg_clone,
    /* .free              = */ llama_sampler_llg_free,
    /* .backend_init      = */ NULL,
    /* .backend_accept    = */ NULL,
    /* .backend_apply     = */ NULL,
    /* .backend_set_input = */ NULL,
};
// Tokenize callback handed to llguidance; llama_tokenize returns the negated number of
// tokens required when the output buffer is too small, which we pass on as the size hint.
static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
                                            uint32_t * output_tokens, size_t output_tokens_len) {
    const llama_vocab * vocab = (const llama_vocab *) user_data;
    int r = 0;
    try {
        r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
                           true);
    } catch (const std::exception & e) {
        GGML_ABORT("llama_tokenize failed: %s\n", e.what());
    }
    if (r < 0) {
        return -r;
    }
    return r;
}
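
// Calling-convention sketch (an assumption about how llguidance drives the callback,
// not something this file defines): the callback may first be invoked with a
// too-small (even empty) output buffer to learn the required size, then again
// with a buffer of that size:
//
//     size_t need = llama_sampler_llg_tokenize_fn(vocab, bytes, len, nullptr, 0);
//     std::vector<uint32_t> toks(need);
//     llama_sampler_llg_tokenize_fn(vocab, bytes, len, toks.data(), toks.size());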
static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
    // TODO store the tokenizer in the vocab somehow
    static const llama_vocab * vocab_cache;
    static LlgTokenizer *      tokenizer_cache;

    if (vocab_cache == vocab) {
        return llg_clone_tokenizer(tokenizer_cache);
    }

    auto tok_eos = llama_vocab_eot(vocab);
    if (tok_eos == LLAMA_TOKEN_NULL) {
        tok_eos = llama_vocab_eos(vocab);
    }

    size_t vocab_size = llama_vocab_n_tokens(vocab);

    auto token_lens = new uint32_t[vocab_size];
    // we typically have ~7 bytes per token; let's go on the safe side here
    auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
    auto token_bytes      = new uint8_t[token_bytes_size];

    size_t offset = 0;
    for (size_t i = 0; i < vocab_size; i++) {
        size_t max_token = 1024;
        if (token_bytes_size - offset < max_token) {
            GGML_ABORT("token_bytes buffer too small\n");
        }

        llama_token token = i;
        auto        dp    = (char *) token_bytes + offset;
        auto        size  = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
        if (size < 0) {
            GGML_ABORT("llama_detokenize failed\n");
        }

        if (size == 0) {
            // detokenizing without special tokens yielded nothing; re-render with
            // special tokens enabled and mark the entry with a 0xFF prefix byte
            size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
            if (size < 0) {
                GGML_ABORT("llama_detokenize failed\n");
            }
            if (size != 0) {
                *dp = '\xff';  // special token prefix marker
                size += 1;
            }
        }

        token_lens[i] = size;
        offset += size;
    }

    LlgTokenizerInit tinit = {
        /* .vocab_size                         = */ (uint32_t) vocab_size,
        /* .tok_eos                            = */ (uint32_t) tok_eos,
        /* .token_lens                         = */ token_lens,
        /* .token_bytes                        = */ token_bytes,
        /* .tokenizer_json                     = */ nullptr,
        /* .tokenize_assumes_string            = */ true,
        /* .tokenize_fn                        = */ llama_sampler_llg_tokenize_fn,
        /* .use_approximate_greedy_tokenize_fn = */ false,
        /* .tokenize_user_data                 = */ vocab,
        /* .slices                             = */ nullptr,
    };

    char           error_buffer[1024];
    LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));

    delete[] token_bytes;
    delete[] token_lens;

    if (tokenizer == nullptr) {
        LOG_ERR("llg tokenizer error: %s\n", error_buffer);
        return tokenizer;
    }

    if (tokenizer_cache) {
        llg_free_tokenizer(tokenizer_cache);
    }
    vocab_cache     = vocab;
    tokenizer_cache = tokenizer;

    return llg_clone_tokenizer(tokenizer_cache);
}
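
// Layout note (an illustrative sketch, not part of this file): token_bytes holds the
// byte rendering of every token back to back, with token_lens giving each entry's
// length, so token i starts at the sum of token_lens[0..i). Under that assumption,
// recovering one token's bytes would look like:
//
//     static std::string llg_token_at(const uint32_t * lens, const uint8_t * bytes, size_t i) {
//         size_t off = 0;
//         for (size_t j = 0; j < i; j++) {
//             off += lens[j];
//         }
//         return std::string((const char *) bytes + off, lens[i]);
//     }
//
// Entries whose first byte is the 0xFF marker are special tokens with no plain-text form.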
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
                                       const char * grammar_data) {
    auto * ctx = new llama_sampler_llg;

    if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
        auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
        *ctx = {
            /* .vocab        = */ vocab,
            /* .grammar_kind = */ grammar_kind,
            /* .grammar_data = */ grammar_data,
            /* .tokenizer    = */ tokenizer,
            /* .grammar      = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
        };
        if (ctx->grammar) {
            // the mask must hold one bit per vocabulary entry, rounded up to whole 32-bit words
            GGML_ASSERT(((size_t) llama_vocab_n_tokens(vocab) + 31) / 32 * 4 ==
                        llg_matcher_get_mask_byte_size(ctx->grammar));
        }
    } else {
        *ctx = {
            /* .vocab        = */ vocab,
            /* .grammar_kind = */ {},
            /* .grammar_data = */ {},
            /* .tokenizer    = */ nullptr,
            /* .grammar      = */ nullptr,
        };
    }

    return llama_sampler_init(
        /* .iface = */ &llama_sampler_llg_i,
        /* .ctx   = */ ctx);
}
#else

llama_sampler * llama_sampler_init_llg(const llama_vocab *, const char *, const char *) {
    LOG_WRN("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
    return nullptr;
}

#endif  // LLAMA_USE_LLGUIDANCE
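
// Usage sketch (a hypothetical caller, not part of this file), assuming the common
// llama.cpp sampler-chain API: attach the llguidance sampler to a chain so that
// sampling is constrained, e.g. to a lark-style grammar that only allows digits:
//
//     llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
//     llama_sampler_chain_add(chain, llama_sampler_init_llg(vocab, "lark", "start: /[0-9]+/"));
//     llama_sampler_chain_add(chain, llama_sampler_init_greedy());
//     // sampling from the chain now only ever yields tokens the grammar permits
//
// grammar_kind selects the dialect llguidance parses grammar_data as; passing an
// empty kind (as llama_sampler_llg_clone does) yields an unconstrained pass-through sampler.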