Przeglądaj źródła

memory : use sequential equal splits for recurrent modules (#16442)

Georgi Gerganov 3 miesięcy temu
rodzic
commit
0123ff38f5
2 zmienionych plików z 6 dodań i 2 usunięć
  1. 3 1
      src/llama-memory-hybrid.cpp
  2. 3 1
      src/llama-memory-recurrent.cpp

+ 3 - 1
src/llama-memory-hybrid.cpp

@@ -73,7 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
                 // if all tokens are output, split by sequence
                 ubatch = balloc.split_seq(n_ubatch);
             } else {
-                ubatch = balloc.split_equal(n_ubatch, false);
+                // TODO: non-sequential equal split can be done if using unified KV cache
+                //       for simplicity, we always use sequential equal split for now
+                ubatch = balloc.split_equal(n_ubatch, true);
             }
 
             if (ubatch.n_tokens == 0) {

+ 3 - 1
src/llama-memory-recurrent.cpp

@@ -382,7 +382,9 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
                 // if all tokens are output, split by sequence
                 ubatch = balloc.split_seq(n_ubatch);
             } else {
-                ubatch = balloc.split_equal(n_ubatch, false);
+                // TODO: non-sequential equal split can be done if using unified KV cache
+                //       for simplicity, we always use sequential equal split for now
+                ubatch = balloc.split_equal(n_ubatch, true);
             }
 
             if (ubatch.n_tokens == 0) {