|
|
@@ -405,6 +405,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
|
|
|
return n_outputs;
|
|
|
}
|
|
|
|
|
|
+uint32_t llama_batch_allocr::get_n_used() const {
|
|
|
+ return n_used;
|
|
|
+}
|
|
|
+
|
|
|
std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
|
|
|
return out_ids;
|
|
|
}
|
|
|
@@ -420,6 +424,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
|
|
|
void llama_batch_allocr::split_reset() {
|
|
|
out_ids.clear();
|
|
|
|
|
|
+ n_used = 0;
|
|
|
+
|
|
|
used.clear();
|
|
|
used.resize(get_n_tokens(), false);
|
|
|
|
|
|
@@ -444,6 +450,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
|
|
|
idxs.push_back(cur_idx);
|
|
|
|
|
|
used[cur_idx] = true;
|
|
|
+ ++n_used;
|
|
|
|
|
|
++cur_idx;
|
|
|
|
|
|
@@ -529,6 +536,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
|
|
|
idxs_per_seq[s].push_back(idx);
|
|
|
|
|
|
used[idx] = true;
|
|
|
+ ++n_used;
|
|
|
|
|
|
++cur_idx[s];
|
|
|
}
|
|
|
@@ -570,6 +578,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
|
|
|
idxs.push_back(cur_idx);
|
|
|
|
|
|
used[cur_idx] = true;
|
|
|
+ ++n_used;
|
|
|
|
|
|
if (idxs.size() >= n_ubatch) {
|
|
|
break;
|