
memory : fix broken batch splits for recurrent cache (#14575)

Splits producing more than one ubatch per batch for recurrent models
were broken by #14512.

This fixes it by moving the completeness check after the ubatch split loop.
compilade, 6 months ago
Parent commit bb4f7a9e4e
1 changed file with 6 additions and 2 deletions
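
Below is a minimal standalone sketch of the control flow the diff changes. It assumes nothing beyond the commit message: ToyBalloc, ToyUbatch and this split_equal are simplified stand-ins for llama_batch_allocr and its splitters (not the real llama.cpp API), used only to show why the completeness check belongs after the split loop rather than inside it.

// Minimal standalone sketch (not the real llama.cpp types): ToyBalloc and
// ToyUbatch are toy stand-ins used only to illustrate the control flow.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

struct ToyUbatch {
    uint32_t n_tokens = 0;
};

struct ToyBalloc {
    uint32_t n_tokens = 0; // total tokens in the batch
    uint32_t n_used   = 0; // tokens handed out so far

    // hand out up to n_ubatch tokens per call; returns an empty ubatch
    // once the (toy) batch is exhausted
    ToyUbatch split_equal(uint32_t n_ubatch) {
        ToyUbatch ub;
        ub.n_tokens = std::min(n_ubatch, n_tokens - n_used);
        n_used += ub.n_tokens;
        return ub;
    }
};

int main() {
    ToyBalloc balloc;
    balloc.n_tokens = 10;          // batch of 10 tokens
    const uint32_t n_ubatch = 4;   // -> needs 3 ubatches

    std::vector<ToyUbatch> ubatches;

    // split loop: keep pulling ubatches until the splitter runs dry
    while (true) {
        ToyUbatch ubatch = balloc.split_equal(n_ubatch);
        if (ubatch.n_tokens == 0) {
            break;
        }
        ubatches.push_back(ubatch);
    }

    // completeness check *after* the loop: only now decide whether the whole
    // batch was consumed. In the real code the splitter can legitimately leave
    // tokens unused, which is what this check catches; doing it inside the
    // loop would bail out after the first ubatch whenever more than one is
    // needed, which is the breakage described in the commit message.
    if (balloc.n_used < balloc.n_tokens) {
        std::printf("failed to find a suitable split\n");
        return 1;
    }

    std::printf("split %u tokens into %zu ubatches\n",
                (unsigned) balloc.n_tokens, ubatches.size());
    return 0;
}

Run as-is, the sketch prints "split 10 tokens into 3 ubatches"; with the check inside the loop, it would stop after the first ubatch whenever more than one is needed.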

+ 6 - 2
src/llama-memory-recurrent.cpp

@@ -377,14 +377,18 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
                 ubatch = balloc.split_equal(n_ubatch, false);
             }
 
-            if (balloc.get_n_used() < balloc.get_n_tokens()) {
-                // failed to find a suitable split
+            if (ubatch.n_tokens == 0) {
                 break;
             }
 
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
         if (!prepare(ubatches)) {
             break;
         }