|
@@ -695,6 +695,8 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
|
|
udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
|
|
udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
|
|
|
udata->output .resize(n_tokens);
|
|
udata->output .resize(n_tokens);
|
|
|
|
|
|
|
|
|
|
+ udata->seq_id_data.reserve(n_tokens);
|
|
|
|
|
+
|
|
|
seq_set_t seq_set_unq;
|
|
seq_set_t seq_set_unq;
|
|
|
|
|
|
|
|
for (size_t i = 0; i < idxs.size(); ++i) {
|
|
for (size_t i = 0; i < idxs.size(); ++i) {
|
|
@@ -716,11 +718,13 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
|
|
udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
|
|
|
- udata->seq_id[i] = batch.seq_id[idxs[i]];
|
|
|
|
|
udata->output[i] = batch.logits[idxs[i]];
|
|
udata->output[i] = batch.logits[idxs[i]];
|
|
|
|
|
|
|
|
for (int s = 0; s < udata->n_seq_id[i]; ++s) {
|
|
for (int s = 0; s < udata->n_seq_id[i]; ++s) {
|
|
|
- seq_set_unq.set(udata->seq_id[i][s]);
|
|
|
|
|
|
|
+ const llama_seq_id seq_id = batch.seq_id[idxs[i]][s];
|
|
|
|
|
+
|
|
|
|
|
+ udata->seq_id_data.push_back(seq_id);
|
|
|
|
|
+ seq_set_unq.set(seq_id);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if (udata->output[i]) {
|
|
if (udata->output[i]) {
|
|
@@ -728,6 +732,12 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ llama_seq_id * seq_id_ptr = udata->seq_id_data.data();
|
|
|
|
|
+ for (size_t i = 0; i < idxs.size(); ++i) {
|
|
|
|
|
+ udata->seq_id[i] = seq_id_ptr;
|
|
|
|
|
+ seq_id_ptr += udata->n_seq_id[i];
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
|
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
|
|
if (seq_set_unq.test(s)) {
|
|
if (seq_set_unq.test(s)) {
|
|
|
udata->seq_idx[s] = udata->seq_id_unq.size();
|
|
udata->seq_idx[s] = udata->seq_id_unq.size();
|