
perplexity: avoid unnecessary allocations and logit copies (#5035)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Kawrakow 2 years ago
parent commit 993fba8180
1 changed file with 15 additions and 7 deletions

examples/perplexity/perplexity.cpp  +15 -7

@@ -325,6 +325,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     double nll = 0.0;
     double nll2 = 0.0;

+    const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+
+    std::vector<float> logits;
+    if (num_batches > 1) {
+        logits.reserve((size_t)n_ctx * n_vocab);
+    }
+
     fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);

     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
@@ -333,10 +340,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         const int start =     i * n_ctx;
         const int end   = start + n_ctx;

-        const int num_batches = (n_ctx + n_batch - 1) / n_batch;
-
-        std::vector<float> logits;
-
         const auto t_start = std::chrono::high_resolution_clock::now();

         // clear the KV cache
@@ -362,8 +365,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             // restore the original token in case it was set to BOS
             tokens[batch_start] = token_org;

-            const auto * batch_logits = llama_get_logits(ctx);
-            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+            if (num_batches > 1) {
+                const auto * batch_logits = llama_get_logits(ctx);
+                logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+            }
         }

         const auto t_end = std::chrono::high_resolution_clock::now();
@@ -392,7 +397,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         // last 256 tokens.  Then, we split the input up into context window size chunks to
         // process the entire prompt.
         const int first = n_ctx/2;
-        process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+        const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
+        process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
                        workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
         count += n_ctx - first - 1;

@@ -406,6 +412,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             printf("%8d  %.4lf  %4lf  %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
         }
         fflush(stdout);
+
+        logits.clear();
     }
     printf("\n");
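
The gist of the change: the per-chunk evaluation loop used to allocate a fresh std::vector<float> and copy every batch's logits into it, even when the whole chunk fits in a single batch and the logits could be read straight out of the buffer the llama context already owns via llama_get_logits. After the change, the vector is hoisted out of the chunk loop, reserved once, and filled only on the multi-batch path. Below is a minimal standalone sketch of that pattern; compute_batch_logits is a hypothetical stand-in for llama_get_logits, and the loop is heavily simplified from the real perplexity() function.

#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the buffer owned by the llama context; in
// llama.cpp, llama_get_logits(ctx) returns a pointer to logits the context
// already holds, overwritten on each decode.
static std::vector<float> g_backend_logits;

static const float * compute_batch_logits(int batch_size, int n_vocab) {
    g_backend_logits.assign((size_t)batch_size * n_vocab, 0.5f); // fake inference
    return g_backend_logits.data();
}

static void evaluate_chunk(int n_ctx, int n_batch, int n_vocab) {
    const int num_batches = (n_ctx + n_batch - 1) / n_batch;

    // Only the multi-batch path needs its own storage, because each batch
    // overwrites the backend buffer; reserve once so the copies never
    // trigger a reallocation.
    std::vector<float> logits;
    if (num_batches > 1) {
        logits.reserve((size_t)n_ctx * n_vocab);
    }

    const float * last_batch = nullptr;
    for (int j = 0; j < num_batches; ++j) {
        const int batch_start = j * n_batch;
        const int batch_size  = std::min(n_ctx - batch_start, n_batch);

        last_batch = compute_batch_logits(batch_size, n_vocab);
        if (num_batches > 1) {
            logits.insert(logits.end(), last_batch, last_batch + (size_t)batch_size * n_vocab);
        }
    }

    // Single-batch chunks read the backend buffer directly: no allocation,
    // no copy.
    const float * all_logits = num_batches > 1 ? logits.data() : last_batch;
    printf("first logit of chunk: %.2f\n", all_logits[0]);

    logits.clear(); // drops elements but keeps capacity for the next chunk
}

int main() {
    evaluate_chunk(/*n_ctx=*/8, /*n_batch=*/8, /*n_vocab=*/4); // single batch: direct read
    evaluate_chunk(/*n_ctx=*/8, /*n_batch=*/4, /*n_vocab=*/4); // two batches: copy path
    return 0;
}

Note that logits.clear() at the end of each chunk, mirroring the commit, drops the elements but not the capacity, so the reserve cost is paid once per run rather than once per chunk.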