Просмотр исходного кода

batch : fix sequence id ownership (#17915)

* batch : fix sequence id ownage

* cont : reduce allocations
Georgi Gerganov 1 месяц назад
Родитель
Сommit
d9f8f60618
2 измененных файлов с 16 добавлено и 4 удалено
  1. 12 2
      src/llama-batch.cpp
  2. 4 2
      src/llama-batch.h

+ 12 - 2
src/llama-batch.cpp

@@ -695,6 +695,8 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
     udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
     udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
     udata->output    .resize(n_tokens);
     udata->output    .resize(n_tokens);
 
 
+    udata->seq_id_data.reserve(n_tokens);
+
     seq_set_t seq_set_unq;
     seq_set_t seq_set_unq;
 
 
     for (size_t i = 0; i < idxs.size(); ++i) {
     for (size_t i = 0; i < idxs.size(); ++i) {
@@ -716,11 +718,13 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
         }
         }
 
 
         udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
         udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
-        udata->seq_id[i]   = batch.seq_id[idxs[i]];
         udata->output[i]   = batch.logits[idxs[i]];
         udata->output[i]   = batch.logits[idxs[i]];
 
 
         for (int s = 0; s < udata->n_seq_id[i]; ++s) {
         for (int s = 0; s < udata->n_seq_id[i]; ++s) {
-            seq_set_unq.set(udata->seq_id[i][s]);
+            const llama_seq_id seq_id = batch.seq_id[idxs[i]][s];
+
+            udata->seq_id_data.push_back(seq_id);
+            seq_set_unq.set(seq_id);
         }
         }
 
 
         if (udata->output[i]) {
         if (udata->output[i]) {
@@ -728,6 +732,12 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
         }
         }
     }
     }
 
 
+    llama_seq_id * seq_id_ptr = udata->seq_id_data.data();
+    for (size_t i = 0; i < idxs.size(); ++i) {
+        udata->seq_id[i] = seq_id_ptr;
+        seq_id_ptr += udata->n_seq_id[i];
+    }
+
     for (uint32_t s = 0; s < n_seq_max; ++s) {
     for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_set_unq.test(s)) {
         if (seq_set_unq.test(s)) {
             udata->seq_idx[s] = udata->seq_id_unq.size();
             udata->seq_idx[s] = udata->seq_id_unq.size();

+ 4 - 2
src/llama-batch.h

@@ -56,13 +56,15 @@ struct llama_ubatch {
         std::vector<float>          embd;
         std::vector<float>          embd;
         std::vector<llama_pos>      pos;
         std::vector<llama_pos>      pos;
         std::vector<int32_t>        n_seq_id;
         std::vector<int32_t>        n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id *> seq_id;      // these point into the seq_id_data below
         std::vector<llama_seq_id>   seq_id_unq;
         std::vector<llama_seq_id>   seq_id_unq;
         std::vector<int32_t>        seq_idx;
         std::vector<int32_t>        seq_idx;
         std::vector<int8_t>         output;
         std::vector<int8_t>         output;
+
+        std::vector<llama_seq_id> seq_id_data;
     };
     };
 
 
-    // the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
+    // the llama_ubatch pointers above point to this data if set. otherwise - point to external non-owning data
     std::shared_ptr<data_t> data;
     std::shared_ptr<data_t> data;
 };
 };