| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254 |
- #include "sampling.h"
- #include "log.h"
- #ifdef LLAMA_USE_LLGUIDANCE
- # include "llguidance.h"
- # include <cmath>
- struct llama_sampler_llg {
- const llama_vocab * vocab;
- std::string grammar_kind;
- std::string grammar_data;
- LlgTokenizer * tokenizer;
- LlgMatcher * grammar;
- };
- static LlgMatcher * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
- const char * grammar_data) {
- LlgConstraintInit cinit;
- llg_constraint_init_set_defaults(&cinit, tokenizer);
- const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
- if (log_level && *log_level) {
- cinit.log_stderr_level = atoi(log_level);
- }
- auto c = llg_new_matcher(&cinit, grammar_kind, grammar_data);
- if (llg_matcher_get_error(c)) {
- LOG_ERR("llg error: %s\n", llg_matcher_get_error(c));
- llg_free_matcher(c);
- return nullptr;
- }
- return c;
- }
- static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
- return "llguidance";
- }
- static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
- auto * ctx = (llama_sampler_llg *) smpl->ctx;
- if (ctx->grammar) {
- llg_matcher_consume_token(ctx->grammar, token);
- }
- }
- static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
- auto * ctx = (llama_sampler_llg *) smpl->ctx;
- if (ctx->grammar) {
- const uint32_t * mask = llg_matcher_get_mask(ctx->grammar);
- if (mask == nullptr) {
- if (llg_matcher_compute_mask(ctx->grammar) == 0) {
- mask = llg_matcher_get_mask(ctx->grammar);
- } else {
- LOG_ERR("llg error: %s\n", llg_matcher_get_error(ctx->grammar));
- llg_free_matcher(ctx->grammar);
- ctx->grammar = nullptr;
- return;
- }
- }
- for (size_t i = 0; i < cur_p->size; ++i) {
- auto token = cur_p->data[i].id;
- if ((mask[token / 32] & (1 << (token % 32))) == 0) {
- cur_p->data[i].logit = -INFINITY;
- }
- }
- }
- }
- static void llama_sampler_llg_reset(llama_sampler * smpl) {
- auto * ctx = (llama_sampler_llg *) smpl->ctx;
- if (ctx->grammar) {
- llg_matcher_reset(ctx->grammar);
- }
- }
- static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
- const auto * ctx = (const llama_sampler_llg *) smpl->ctx;
- auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);
- // copy the state
- {
- auto * result_ctx = (llama_sampler_llg *) result->ctx;
- if (ctx->grammar) {
- result_ctx->grammar_kind = ctx->grammar_kind;
- result_ctx->grammar_data = ctx->grammar_data;
- result_ctx->grammar = llg_clone_matcher(ctx->grammar);
- result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
- }
- }
- return result;
- }
- static void llama_sampler_llg_free(llama_sampler * smpl) {
- const auto * ctx = (llama_sampler_llg *) smpl->ctx;
- if (ctx->grammar) {
- llg_free_matcher(ctx->grammar);
- llg_free_tokenizer(ctx->tokenizer);
- }
- delete ctx;
- }
- static llama_sampler_i llama_sampler_llg_i = {
- /* .name = */ llama_sampler_llg_name,
- /* .accept = */ llama_sampler_llg_accept_impl,
- /* .apply = */ llama_sampler_llg_apply,
- /* .reset = */ llama_sampler_llg_reset,
- /* .clone = */ llama_sampler_llg_clone,
- /* .free = */ llama_sampler_llg_free,
- };
- static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
- uint32_t * output_tokens, size_t output_tokens_len) {
- const llama_vocab * vocab = (const llama_vocab *) user_data;
- int r = 0;
- try {
- r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
- true);
- } catch (const std::exception & e) {
- GGML_ABORT("llama_tokenize failed: %s\n", e.what());
- }
- if (r < 0) {
- return -r;
- }
- return r;
- }
- static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
- // TODO store the tokenizer in the vocab somehow
- static const llama_vocab * vocab_cache;
- static LlgTokenizer * tokenizer_cache;
- if (vocab_cache == vocab) {
- return llg_clone_tokenizer(tokenizer_cache);
- }
- auto tok_eos = llama_vocab_eot(vocab);
- if (tok_eos == LLAMA_TOKEN_NULL) {
- tok_eos = llama_vocab_eos(vocab);
- }
- size_t vocab_size = llama_vocab_n_tokens(vocab);
- auto token_lens = new uint32_t[vocab_size];
- // we typically have ~7 bytes per token; let's go on the safe side here
- auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
- auto token_bytes = new uint8_t[token_bytes_size];
- size_t offset = 0;
- for (size_t i = 0; i < vocab_size; i++) {
- size_t max_token = 1024;
- if (token_bytes_size - offset < max_token) {
- GGML_ABORT("token_bytes buffer too small\n");
- }
- llama_token token = i;
- auto dp = (char *) token_bytes + offset;
- auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
- if (size < 0) {
- GGML_ABORT("llama_detokenize failed\n");
- }
- if (size == 0) {
- size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
- if (size < 0) {
- GGML_ABORT("llama_detokenize failed\n");
- }
- if (size != 0) {
- *dp = '\xff'; // special token prefix marker
- size += 1;
- }
- }
- token_lens[i] = size;
- offset += size;
- }
- LlgTokenizerInit tinit = {
- /* .vocab_size = */ (uint32_t) vocab_size,
- /* .tok_eos = */ (uint32_t) tok_eos,
- /* .token_lens = */ token_lens,
- /* .token_bytes = */ token_bytes,
- /* .tokenizer_json = */ nullptr,
- /* .tokenize_assumes_string = */ true,
- /* .tokenize_fn = */ llama_sampler_llg_tokenize_fn,
- /* .use_approximate_greedy_tokenize_fn = */ false,
- /* .tokenize_user_data = */ vocab,
- /* .slices = */ nullptr,
- };
- char error_buffer[1024];
- LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));
- delete[] token_bytes;
- delete[] token_lens;
- if (tokenizer == nullptr) {
- LOG_ERR("llg tokenizer error: %s\n", error_buffer);
- return tokenizer;
- }
- if (tokenizer_cache) {
- llg_free_tokenizer(tokenizer_cache);
- }
- vocab_cache = vocab;
- tokenizer_cache = tokenizer;
- return llg_clone_tokenizer(tokenizer_cache);
- }
- llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
- const char * grammar_data) {
- auto * ctx = new llama_sampler_llg;
- if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
- auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
- *ctx = {
- /* .vocab = */ vocab,
- /* .grammar_kind = */ grammar_kind,
- /* .grammar_data = */ grammar_data,
- /* .tokenizer = */ tokenizer,
- /* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
- };
- if (ctx->grammar) {
- GGML_ASSERT(((size_t) llama_vocab_n_tokens(vocab) + 31) / 32 * 4 ==
- llg_matcher_get_mask_byte_size(ctx->grammar));
- }
- } else {
- *ctx = {
- /* .vocab = */ vocab,
- /* .grammar_kind = */ {},
- /* .grammar_data = */ {},
- /* .tokenizer = */ nullptr,
- /* .grammar = */ nullptr,
- };
- }
- return llama_sampler_init(
- /* .iface = */ &llama_sampler_llg_i,
- /* .ctx = */ ctx);
- }
- #else
- llama_sampler * llama_sampler_init_llg(const llama_vocab *, const char *, const char *) {
- LOG_WRN("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
- return nullptr;
- }
- #endif // LLAMA_USE_LLGUIDANCE
|