Jelajahi Sumber

common : preallocate sampling token data vector (#8363)

`emplace_back` repeatedly-called is slower than preallocating the vector to the vocab size and directly inserting the data. Some rudimentary profiling with `chrono` improves the performance of this block of code from ~500us/op to ~40us/op.

Overall, this slightly improves the sampling performance which has a more substantial impact for the `examples/lookahead` implementation -- I am able to see a ~10% performance boost in lookahead inference.
Kevin Wang 1 tahun lalu
induk
melakukan
470939d483
1 mengubah file dengan 3 tambahan dan 3 penghapusan
  1. 3 3
      common/sampling.cpp

+ 3 - 3
common/sampling.cpp

@@ -378,7 +378,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     if (ctx_sampling->grammar != NULL && !apply_grammar) {
     if (ctx_sampling->grammar != NULL && !apply_grammar) {
         GGML_ASSERT(original_logits != NULL);
         GGML_ASSERT(original_logits != NULL);
         // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
         // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
-        *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
+        *original_logits = {logits, logits + n_vocab};
     }
     }
 
 
     // apply params.logit_bias map
     // apply params.logit_bias map
@@ -391,10 +391,10 @@ static llama_token_data_array llama_sampling_prepare_impl(
         llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
         llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
     }
     }
 
 
-    cur.clear();
+    cur.resize(n_vocab);
 
 
     for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
     for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
     }
     }
 
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };