Просмотр исходного кода

recurrent : call balloc split_reset() in init_batch() (#14414)

ggml-ci
Georgi Gerganov 6 месяцев назад
Родитель
Сommit
43678060c1
1 измененных файлов с 21 добавлено и 16 удалено
  1. 21 16
      src/llama-memory-recurrent.cpp

+ 21 - 16
src/llama-memory-recurrent.cpp

@@ -363,30 +363,35 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
-    std::vector<llama_ubatch> ubatches;
+    do {
+        balloc.split_reset();
 
-    while (true) {
-        llama_ubatch ubatch;
+        std::vector<llama_ubatch> ubatches;
+        while (true) {
+            llama_ubatch ubatch;
 
-        if (embd_all) {
-            // if all tokens are output, split by sequence
-            ubatch = balloc.split_seq(n_ubatch);
-        } else {
-            ubatch = balloc.split_equal(n_ubatch);
+            if (embd_all) {
+                // if all tokens are output, split by sequence
+                ubatch = balloc.split_seq(n_ubatch);
+            } else {
+                ubatch = balloc.split_equal(n_ubatch);
+            }
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        if (ubatch.n_tokens == 0) {
+        if (!prepare(ubatches)) {
             break;
         }
 
-        ubatches.push_back(std::move(ubatch)); // NOLINT
-    }
-
-    if (!prepare(ubatches)) {
-        return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
+    } while (false);
 
-    return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
+    return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
 llama_memory_context_ptr llama_memory_recurrent::init_full() {