|
|
@@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
|
|
|
|
|
|
// note: tracking the other way around is not necessary for now
|
|
|
//seq_cpl[s0][s1] = true;
|
|
|
+
|
|
|
+ has_cpl = true;
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -466,9 +468,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
|
|
|
return ubatch_add(idxs, idxs.size(), false);
|
|
|
}
|
|
|
|
|
|
-llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
|
|
|
+llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
|
|
|
+ if (sequential && has_cpl) {
|
|
|
+ LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
|
|
|
+
|
|
|
+ return {};
|
|
|
+ }
|
|
|
+
|
|
|
std::vector<seq_set_t> cur_seq_set;
|
|
|
|
|
|
+ llama_seq_id last_seq_id = -1;
|
|
|
+
|
|
|
// determine the non-overlapping sequence sets participating in this ubatch
|
|
|
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
|
if (used[i]) {
|
|
|
@@ -485,9 +495,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ // accept only increasing sequence ids
|
|
|
+ if (sequential) {
|
|
|
+ add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
|
|
|
+ }
|
|
|
+
|
|
|
if (add) {
|
|
|
cur_seq_set.push_back(seq_set[i]);
|
|
|
|
|
|
+ last_seq_id = batch.seq_id[i][0];
|
|
|
+
|
|
|
if (cur_seq_set.size() > n_ubatch) {
|
|
|
break;
|
|
|
}
|