Browse Source

sampling: reuse token data buffer in llama_sampler_sample (#18365)

* sampling: reuse token data buffer in llama_sampler_sample

* move cur buffer before timing section, after samplers

* minor : fix build

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Jay Zenith 4 weeks ago
parent
commit
c32fa21db8
2 changed files with 47 additions and 33 deletions
  1. 44 33
      src/llama-sampling.cpp
  2. 3 0
      src/llama-sampling.h

+ 44 - 33
src/llama-sampling.cpp

@@ -421,39 +421,6 @@ void llama_sampler_free(struct llama_sampler * smpl) {
     delete smpl;
 }
 
-llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
-    const auto * logits = llama_get_logits_ith(ctx, idx);
-
-    const llama_model * model = llama_get_model(ctx);
-    const llama_vocab * vocab = llama_model_get_vocab(model);
-
-    const int n_vocab = llama_vocab_n_tokens(vocab);
-
-    // TODO: do not allocate each time
-    std::vector<llama_token_data> cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-    }
-
-    llama_token_data_array cur_p = {
-        /* .data       = */ cur.data(),
-        /* .size       = */ cur.size(),
-        /* .selected   = */ -1,
-        /* .sorted     = */ false,
-    };
-
-    llama_sampler_apply(smpl, &cur_p);
-
-    GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
-
-    auto token = cur_p.data[cur_p.selected].id;
-
-    llama_sampler_accept(smpl, token);
-
-    return token;
-}
-
 // sampler chain
 
 static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
@@ -527,12 +494,56 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
         /* .ctx   = */ new llama_sampler_chain {
             /* .params      = */ params,
             /* .samplers    = */ {},
+            /* .cur         = */ {},
             /* .t_sample_us = */ 0,
             /* .n_sample    = */ 0,
         }
     );
 }
 
+llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
+    const auto * logits = llama_get_logits_ith(ctx, idx);
+
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const int n_vocab = llama_vocab_n_tokens(vocab);
+
+    // use pre-allocated buffer from chain if available, otherwise allocate locally
+    std::vector<llama_token_data> * cur_ptr;
+    std::vector<llama_token_data> cur_local;
+
+    if (smpl->iface == &llama_sampler_chain_i) {
+        auto * chain = (llama_sampler_chain *) smpl->ctx;
+        cur_ptr = &chain->cur;
+    } else {
+        cur_ptr = &cur_local;
+    }
+
+    auto & cur = *cur_ptr;
+    cur.resize(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+    }
+
+    llama_token_data_array cur_p = {
+        /* .data       = */ cur.data(),
+        /* .size       = */ cur.size(),
+        /* .selected   = */ -1,
+        /* .sorted     = */ false,
+    };
+
+    llama_sampler_apply(smpl, &cur_p);
+
+    GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
+
+    auto token = cur_p.data[cur_p.selected].id;
+
+    llama_sampler_accept(smpl, token);
+
+    return token;
+}
+
 void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
     auto * p = (llama_sampler_chain *) chain->ctx;
     p->samplers.push_back(smpl);

+ 3 - 0
src/llama-sampling.h

@@ -16,6 +16,9 @@ struct llama_sampler_chain {
 
     std::vector<struct llama_sampler *> samplers;
 
+    // pre-allocated buffer for llama_sampler_sample to avoid repeated allocations
+    std::vector<llama_token_data> cur;
+
     // timing
 
     mutable int64_t t_sample_us;