// llguidance.cpp

#include "sampling.h"
#include "log.h"

#ifdef LLAMA_USE_LLGUIDANCE

#    include "llguidance.h"

#    include <cmath>
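
// Sampler state: the grammar kind/data strings are kept around so the
// constraint can be rebuilt from scratch on reset(); llg_res caches the
// token mask computed by llg_compute_mask() until the next accepted token
// invalidates it.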
struct llama_sampler_llg {
    const llama_vocab * vocab;
    std::string         grammar_kind;
    std::string         grammar_data;
    LlgTokenizer *      tokenizer;
    LlgConstraint *     grammar;
    LlgMaskResult       llg_res;
    bool                has_llg_res;
};
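
// Build a new llguidance constraint from a grammar of the given kind
// (e.g. a regex, JSON schema, or Lark grammar, per the llguidance API);
// logs and returns nullptr on failure. LLGUIDANCE_LOG_LEVEL can be set in
// the environment to get llguidance diagnostics on stderr.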
static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
                                             const char * grammar_data) {
    LlgConstraintInit cinit;
    llg_constraint_init_set_defaults(&cinit, tokenizer);
    const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
    if (log_level && *log_level) {
        cinit.log_stderr_level = atoi(log_level);
    }

    auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
    if (llg_get_error(c)) {
        LOG_ERR("llg error: %s\n", llg_get_error(c));
        llg_free_constraint(c);
        return nullptr;
    }

    return c;
}
static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
    return "llguidance";
}
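
// Advance the constraint by one sampled token and invalidate the cached
// mask; a fresh mask is computed lazily on the next apply() call.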
static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (ctx->grammar) {
        LlgCommitResult res;
        llg_commit_token(ctx->grammar, token, &res);
        ctx->has_llg_res = false;
    }
}
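
// Mask out tokens the grammar does not allow at the current position.
// sample_mask is a bitmask with one bit per token id (bit id%32 of 32-bit
// word id/32); when the constraint reports is_stop, everything except
// end-of-generation tokens is forced to -INFINITY.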
static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (ctx->grammar) {
        if (!ctx->has_llg_res) {
            if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
                ctx->has_llg_res = true;
            } else {
                LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar));
                llg_free_constraint(ctx->grammar);
                ctx->grammar = nullptr;
            }
        }
        if (ctx->has_llg_res) {
            if (ctx->llg_res.is_stop) {
                for (size_t i = 0; i < cur_p->size; ++i) {
                    if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) {
                        cur_p->data[i].logit = -INFINITY;
                    }
                }
            } else {
                const uint32_t * mask = ctx->llg_res.sample_mask;
                for (size_t i = 0; i < cur_p->size; ++i) {
                    auto token = cur_p->data[i].id;
                    if ((mask[token / 32] & (1 << (token % 32))) == 0) {
                        cur_p->data[i].logit = -INFINITY;
                    }
                }
            }
        }
    }
}
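
// Return to the initial parser state by rebuilding the constraint from the
// stored grammar strings.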
static void llama_sampler_llg_reset(llama_sampler * smpl) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (!ctx->grammar) {
        return;
    }

    auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
    llg_free_constraint(ctx->grammar);
    ctx->grammar     = grammar_new;
    ctx->has_llg_res = false;
}
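
// Deep-copy the sampler: the constraint and tokenizer are cloned through
// the llguidance API, so the copy can advance independently of the original.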
static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_llg *) smpl->ctx;

    auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);

    // copy the state
    {
        auto * result_ctx = (llama_sampler_llg *) result->ctx;

        if (ctx->grammar) {
            result_ctx->grammar_kind = ctx->grammar_kind;
            result_ctx->grammar_data = ctx->grammar_data;

            result_ctx->grammar   = llg_clone_constraint(ctx->grammar);
            result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
        }
    }

    return result;
}
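
// Release the constraint and its tokenizer (both exist only when a grammar
// was supplied), then the context itself.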
static void llama_sampler_llg_free(llama_sampler * smpl) {
    const auto * ctx = (llama_sampler_llg *) smpl->ctx;

    if (ctx->grammar) {
        llg_free_constraint(ctx->grammar);
        llg_free_tokenizer(ctx->tokenizer);
    }

    delete ctx;
}

static llama_sampler_i llama_sampler_llg_i = {
    /* .name   = */ llama_sampler_llg_name,
    /* .accept = */ llama_sampler_llg_accept_impl,
    /* .apply  = */ llama_sampler_llg_apply,
    /* .reset  = */ llama_sampler_llg_reset,
    /* .clone  = */ llama_sampler_llg_clone,
    /* .free   = */ llama_sampler_llg_free,
};
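
// Tokenizer callback handed to llguidance: tokenizes a byte buffer with
// llama_tokenize() and returns the number of tokens written. When the
// output buffer is too small, llama_tokenize() returns the negated required
// count, which is mapped back to a positive "needed size" here.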
static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
                                            uint32_t * output_tokens, size_t output_tokens_len) {
    const llama_vocab * vocab = (const llama_vocab *) user_data;
    int                 r     = 0;
    try {
        r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
                           true);
    } catch (const std::exception & e) {
        GGML_ABORT("llama_tokenize failed: %s\n", e.what());
    }
    if (r < 0) {
        return -r;
    }
    return r;
}
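
// Construct an LlgTokenizer mirroring the llama vocabulary: every token is
// detokenized into one flat byte buffer (special tokens get a 0xff prefix
// so llguidance can tell them apart from regular text). The result is
// cached per vocab, and a clone is handed back to the caller.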
static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
    // TODO store the tokenizer in the vocab somehow
    static const llama_vocab * vocab_cache;
    static LlgTokenizer *      tokenizer_cache;

    if (vocab_cache == vocab) {
        return llg_clone_tokenizer(tokenizer_cache);
    }

    auto tok_eos = llama_vocab_eot(vocab);
    if (tok_eos == LLAMA_TOKEN_NULL) {
        tok_eos = llama_vocab_eos(vocab);
    }

    size_t vocab_size = llama_vocab_n_tokens(vocab);

    auto token_lens = new uint32_t[vocab_size];
    // we typically have ~7 bytes per token; let's go on the safe side here
    auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
    auto token_bytes      = new uint8_t[token_bytes_size];

    size_t offset = 0;
    for (size_t i = 0; i < vocab_size; i++) {
        size_t max_token = 1024;
        if (token_bytes_size - offset < max_token) {
            GGML_ABORT("token_bytes buffer too small\n");
        }

        llama_token token = i;
        auto        dp    = (char *) token_bytes + offset;
        auto        size  = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
        if (size < 0) {
            GGML_ABORT("llama_detokenize failed\n");
        }
        if (size == 0) {
            size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
            if (size < 0) {
                GGML_ABORT("llama_detokenize failed\n");
            }
            if (size != 0) {
                *dp = '\xff';  // special token prefix marker
                size += 1;
            }
        }

        token_lens[i] = size;
        offset += size;
    }

    LlgTokenizerInit tinit = {
        /* .vocab_size                         = */ (uint32_t) vocab_size,
        /* .tok_eos                            = */ (uint32_t) tok_eos,
        /* .token_lens                         = */ token_lens,
        /* .token_bytes                        = */ token_bytes,
        /* .tokenizer_json                     = */ nullptr,
        /* .tokenize_assumes_string            = */ true,
        /* .tokenize_fn                        = */ llama_sampler_llg_tokenize_fn,
        /* .use_approximate_greedy_tokenize_fn = */ false,
        /* .tokenize_user_data                 = */ vocab,
    };

    char           error_buffer[1024];
    LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));

    delete[] token_bytes;
    delete[] token_lens;

    if (tokenizer == nullptr) {
        LOG_ERR("llg tokenizer error: %s\n", error_buffer);
        return tokenizer;
    }

    if (tokenizer_cache) {
        llg_free_tokenizer(tokenizer_cache);
    }
    vocab_cache     = vocab;
    tokenizer_cache = tokenizer;

    return llg_clone_tokenizer(tokenizer_cache);
}
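
// Public entry point: create an llguidance-backed sampler. With a null or
// empty grammar_kind the sampler is an inert shell, which clone() uses as a
// blank target before copying state into it.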
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
                                       const char * grammar_data) {
    auto * ctx = new llama_sampler_llg;

    if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
        auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
        *ctx           = {
            /* .vocab        = */ vocab,
            /* .grammar_kind = */ grammar_kind,
            /* .grammar_data = */ grammar_data,
            /* .tokenizer    = */ tokenizer,
            /* .grammar      = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
            /* .llg_res      = */ {},
            /* .has_llg_res  = */ false,
        };
    } else {
        *ctx = {
            /* .vocab        = */ vocab,
            /* .grammar_kind = */ {},
            /* .grammar_data = */ {},
            /* .tokenizer    = */ nullptr,
            /* .grammar      = */ nullptr,
            /* .llg_res      = */ {},
            /* .has_llg_res  = */ false,
        };
    }

    return llama_sampler_init(
        /* .iface = */ &llama_sampler_llg_i,
        /* .ctx   = */ ctx);
}
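
// Usage sketch: one way this sampler might be wired into a llama.h sampler
// chain. The "lark" grammar kind is an assumption here (llguidance also
// accepts kinds such as "regex" and "json_schema"); the chain layout is an
// illustration, not something this file prescribes.
//
//     llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
//     llama_sampler_chain_add(chain, llama_sampler_init_llg(vocab, "lark", grammar_text));
//     llama_sampler_chain_add(chain, llama_sampler_init_greedy());
//     llama_token tok = llama_sampler_sample(chain, lctx, -1);  // applies the mask, then samples
//     llama_sampler_free(chain);  // frees every sampler in the chain, including this one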

#else

llama_sampler * llama_sampler_init_llg(const llama_vocab *, const char *, const char *) {
    LOG_WRN("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
    return nullptr;
}

#endif  // LLAMA_USE_LLGUIDANCE