1
0
Эх сурвалжийг харах

memory : use sequential equal splits for recurrent modules (#16442)

Georgi Gerganov 3 сар өмнө
parent
commit
0123ff38f5

+ 3 - 1
src/llama-memory-hybrid.cpp

@@ -73,7 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
                 // if all tokens are output, split by sequence
                 ubatch = balloc.split_seq(n_ubatch);
             } else {
-                ubatch = balloc.split_equal(n_ubatch, false);
+                // TODO: non-sequential equal split can be done if using unified KV cache
+                //       for simplicity, we always use sequential equal split for now
+                ubatch = balloc.split_equal(n_ubatch, true);
             }
 
             if (ubatch.n_tokens == 0) {

+ 3 - 1
src/llama-memory-recurrent.cpp

@@ -382,7 +382,9 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
                 // if all tokens are output, split by sequence
                 ubatch = balloc.split_seq(n_ubatch);
             } else {
-                ubatch = balloc.split_equal(n_ubatch, false);
+                // TODO: non-sequential equal split can be done if using unified KV cache
+                //       for simplicity, we always use sequential equal split for now
+                ubatch = balloc.split_equal(n_ubatch, true);
             }
 
             if (ubatch.n_tokens == 0) {