|
|
@@ -363,30 +363,35 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
|
|
}
|
|
|
|
|
|
llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
|
|
- std::vector<llama_ubatch> ubatches;
|
|
|
+ do {
|
|
|
+ balloc.split_reset();
|
|
|
|
|
|
- while (true) {
|
|
|
- llama_ubatch ubatch;
|
|
|
+ std::vector<llama_ubatch> ubatches;
|
|
|
+ while (true) {
|
|
|
+ llama_ubatch ubatch;
|
|
|
|
|
|
- if (embd_all) {
|
|
|
- // if all tokens are output, split by sequence
|
|
|
- ubatch = balloc.split_seq(n_ubatch);
|
|
|
- } else {
|
|
|
- ubatch = balloc.split_equal(n_ubatch);
|
|
|
+ if (embd_all) {
|
|
|
+ // if all tokens are output, split by sequence
|
|
|
+ ubatch = balloc.split_seq(n_ubatch);
|
|
|
+ } else {
|
|
|
+ ubatch = balloc.split_equal(n_ubatch);
|
|
|
+ }
|
|
|
+
|
|
|
+ if (ubatch.n_tokens == 0) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+
|
|
|
+ ubatches.push_back(std::move(ubatch)); // NOLINT
|
|
|
}
|
|
|
|
|
|
- if (ubatch.n_tokens == 0) {
|
|
|
+ if (!prepare(ubatches)) {
|
|
|
break;
|
|
|
}
|
|
|
|
|
|
- ubatches.push_back(std::move(ubatch)); // NOLINT
|
|
|
- }
|
|
|
-
|
|
|
- if (!prepare(ubatches)) {
|
|
|
- return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
|
|
- }
|
|
|
+ return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
|
|
|
+ } while (false);
|
|
|
|
|
|
- return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
|
|
|
+ return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
|
|
}
|
|
|
|
|
|
llama_memory_context_ptr llama_memory_recurrent::init_full() {
|