Просмотр исходного кода

llama : infill sampling handle very long tokens (#9924)

* llama : infill sampling handle very long tokens

ggml-ci

* cont : better indices

ggml-ci
Georgi Gerganov 1 год назад
Родитель
Сommit
99bd4ac28c
4 измененных файлов с 35 добавлено и 43 удалено
  1. 0 6
      include/llama.h
  2. 35 13
      src/llama-sampling.cpp
  3. 0 17
      src/llama-vocab.cpp
  4. 0 7
      src/llama.cpp

+ 0 - 6
include/llama.h

@@ -953,12 +953,6 @@ extern "C" {
                                int32_t   lstrip,
                                int32_t   lstrip,
                                   bool   special);
                                   bool   special);
 
 
-    // check if token0 is contained as a prefix in token1
-    LLAMA_API bool llama_token_is_prefix(
-              const struct llama_model * model,
-                           llama_token   token0,
-                           llama_token   token1);
-
     /// @details Convert the provided tokens into text (inverse of llama_tokenize()).
     /// @details Convert the provided tokens into text (inverse of llama_tokenize()).
     /// @param text The char pointer must be large enough to hold the resulting text.
     /// @param text The char pointer must be large enough to hold the resulting text.
     /// @return Returns the number of chars/bytes on success, no more than text_len_max.
     /// @return Returns the number of chars/bytes on success, no more than text_len_max.

+ 35 - 13
src/llama-sampling.cpp

@@ -1745,6 +1745,9 @@ struct llama_sampler * llama_sampler_init_logit_bias(
 
 
 struct llama_sampler_infill {
 struct llama_sampler_infill {
     const struct llama_vocab * vocab;
     const struct llama_vocab * vocab;
+
+    std::vector<char> buf0;
+    std::vector<char> buf1;
 };
 };
 
 
 static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
 static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
@@ -1810,27 +1813,44 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     size_t n_combined = 0; GGML_UNUSED(n_combined);
     size_t n_combined = 0; GGML_UNUSED(n_combined);
 
 
     // combine tokens with common prefix
     // combine tokens with common prefix
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        for (size_t j = 0; j < cur_p->size; ++j) {
-            if (cur_p->data[i].logit == -INFINITY) {
+    for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
+        for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
+            if (cur_p->data[i0].logit == -INFINITY) {
                 break;
                 break;
             }
             }
 
 
-            if (i == j || cur_p->data[j].logit == -INFINITY) {
+            if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
                 continue;
                 continue;
             }
             }
 
 
-            if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
-                if (cur_p->data[i].p >  cur_p->data[j].p) {
-                    cur_p->data[i].p += cur_p->data[j].p;
-                    cur_p->data[j].logit = -INFINITY;
-                    cur_p->data[j].p     = 0.0f;
-                } else {
-                    cur_p->data[j].p += cur_p->data[i].p;
-                    cur_p->data[i].logit = -INFINITY;
-                    cur_p->data[i].p     = 0.0f;
+            int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
+            if (len0 < 0) {
+                ctx->buf0.resize(len0);
+                len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
+                assert(len0 > 0);
+            }
+
+            int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
+            if (len1 < 0) {
+                ctx->buf1.resize(len1);
+                len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
+                assert(len1 > 0);
+            }
+
+            // token i0 is a prefix of token i1
+            if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
+                int dst = i0;
+                int src = i1;
+
+                // merge into the token with higher probability
+                if (cur_p->data[i1].p > cur_p->data[i0].p) {
+                    std::swap(dst, src);
                 }
                 }
 
 
+                cur_p->data[dst].p += cur_p->data[src].p;
+                cur_p->data[src].logit = -INFINITY;
+                cur_p->data[src].p     = 0.0f;
+
                 n_combined++;
                 n_combined++;
             }
             }
         }
         }
@@ -1936,6 +1956,8 @@ struct llama_sampler * llama_sampler_init_infill_impl(
         /* .iface = */ &llama_sampler_infill_i,
         /* .iface = */ &llama_sampler_infill_i,
         /* .ctx   = */ new llama_sampler_infill {
         /* .ctx   = */ new llama_sampler_infill {
             /* .vocab = */ &vocab,
             /* .vocab = */ &vocab,
+            /* .buf0 = */ std::vector<char>(512),
+            /* .buf1 = */ std::vector<char>(512),
         },
         },
     };
     };
 }
 }

+ 0 - 17
src/llama-vocab.cpp

@@ -1858,23 +1858,6 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
     return 0;
     return 0;
 }
 }
 
 
-bool llama_token_is_prefix_impl(
-        const struct llama_vocab & vocab,
-                     llama_token   token0,
-                     llama_token   token1) {
-    char text_buf_0[128];
-    char text_buf_1[128];
-
-    const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, sizeof(text_buf_0) - 1, 0, false);
-    const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, sizeof(text_buf_1) - 1, 0, false);
-
-    if (len0 <= 0 || len1 <= 0) {
-        return false;
-    }
-
-    return len0 <= len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
-}
-
 int32_t llama_detokenize_impl(
 int32_t llama_detokenize_impl(
         const struct llama_vocab & vocab,
         const struct llama_vocab & vocab,
                const llama_token * tokens,
                const llama_token * tokens,

+ 0 - 7
src/llama.cpp

@@ -21466,13 +21466,6 @@ int32_t llama_token_to_piece(
     return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
     return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
 }
 }
 
 
-bool llama_token_is_prefix(
-    const struct llama_model * model,
-                 llama_token   token0,
-                 llama_token   token1) {
-    return llama_token_is_prefix_impl(model->vocab, token0, token1);
-}
-
 int32_t llama_detokenize(
 int32_t llama_detokenize(
     const struct llama_model * model,
     const struct llama_model * model,
            const llama_token * tokens,
            const llama_token * tokens,