Sfoglia il codice sorgente

llama : speedup tokenization (#2831)

* Speedup tokenization

On current master it takes ~3.2 seconds to tokenize
Wikitext. With this change it becomes ~525 ms.

* Fixit: it was missing the piece after the last found occurence

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Kawrakow 2 anni fa
parent
commit
463173a6c0
2 ha cambiato i file con 14 aggiunte e 5 eliminazioni
  1. 4 0
      examples/perplexity/perplexity.cpp
  2. 10 5
      llama.cpp

+ 4 - 0
examples/perplexity/perplexity.cpp

@@ -190,10 +190,14 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
     const bool add_bos = is_spm;
 
+    auto tim1 = std::chrono::high_resolution_clock::now();
     fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
 
     auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
 
+    auto tim2 = std::chrono::high_resolution_clock::now();
+    fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
+
     const int n_chunk_max = tokens.size() / params.n_ctx;
 
     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);

+ 10 - 5
llama.cpp

@@ -114,12 +114,17 @@ static size_t utf8_len(char src) {
 }
 
 void replace_all(std::string & s, const std::string & search, const std::string & replace) {
-    for (size_t pos = 0; ; pos += replace.length()) {
-        pos = s.find(search, pos);
-        if (pos == std::string::npos) break;
-        s.erase(pos, search.length());
-        s.insert(pos, replace);
+    std::string result;
+    for (size_t pos = 0; ; pos += search.length()) {
+        auto new_pos = s.find(search, pos);
+        if (new_pos == std::string::npos) {
+            result += s.substr(pos, s.size() - pos);
+            break;
+        }
+        result += s.substr(pos, new_pos - pos) + replace;
+        pos = new_pos;
     }
+    s = std::move(result);
 }
 
 static void zeros(std::ofstream & file, size_t n) {