|
@@ -23,7 +23,7 @@ public:
|
|
|
|
|
|
|
|
used.clear();
|
|
used.clear();
|
|
|
|
|
|
|
|
- for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
|
|
|
|
|
|
+ for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
|
seq_pos[s].clear();
|
|
seq_pos[s].clear();
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
@@ -240,7 +240,7 @@ public:
|
|
|
llama_seq_id seq_get(uint32_t i) const {
|
|
llama_seq_id seq_get(uint32_t i) const {
|
|
|
assert(seq[i].count() == 1);
|
|
assert(seq[i].count() == 1);
|
|
|
|
|
|
|
|
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
|
|
|
|
|
|
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
|
if (seq[i].test(s)) {
|
|
if (seq[i].test(s)) {
|
|
|
return s;
|
|
return s;
|
|
|
}
|
|
}
|
|
@@ -253,7 +253,7 @@ public:
|
|
|
// return -1 if the sequence is not present
|
|
// return -1 if the sequence is not present
|
|
|
llama_pos seq_pos_min(llama_seq_id seq_id) const {
|
|
llama_pos seq_pos_min(llama_seq_id seq_id) const {
|
|
|
assert(seq_id >= 0);
|
|
assert(seq_id >= 0);
|
|
|
- assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
|
|
|
|
|
|
|
+ assert(seq_id < LLAMA_MAX_SEQ);
|
|
|
|
|
|
|
|
if (seq_pos[seq_id].empty()) {
|
|
if (seq_pos[seq_id].empty()) {
|
|
|
return -1;
|
|
return -1;
|
|
@@ -266,7 +266,7 @@ public:
|
|
|
// return -1 if the sequence is not present
|
|
// return -1 if the sequence is not present
|
|
|
llama_pos seq_pos_max(llama_seq_id seq_id) const {
|
|
llama_pos seq_pos_max(llama_seq_id seq_id) const {
|
|
|
assert(seq_id >= 0);
|
|
assert(seq_id >= 0);
|
|
|
- assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
|
|
|
|
|
|
|
+ assert(seq_id < LLAMA_MAX_SEQ);
|
|
|
|
|
|
|
|
if (seq_pos[seq_id].empty()) {
|
|
if (seq_pos[seq_id].empty()) {
|
|
|
return -1;
|
|
return -1;
|
|
@@ -384,20 +384,20 @@ private:
|
|
|
//
|
|
//
|
|
|
std::vector<llama_pos> shift;
|
|
std::vector<llama_pos> shift;
|
|
|
|
|
|
|
|
- using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
|
|
|
|
|
|
|
+ using bits_t = std::bitset<LLAMA_MAX_SEQ>;
|
|
|
|
|
|
|
|
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
|
|
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
|
|
|
std::vector<bits_t> seq;
|
|
std::vector<bits_t> seq;
|
|
|
|
|
|
|
|
// the set seq_pos[s] tells us which positions are currently present for sequence s
|
|
// the set seq_pos[s] tells us which positions are currently present for sequence s
|
|
|
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
|
|
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
|
|
|
- std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
|
|
|
|
|
|
|
+ std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
|
|
|
|
|
|
|
|
// helper functions for updating `seq_pos`, one cell at a time:
|
|
// helper functions for updating `seq_pos`, one cell at a time:
|
|
|
|
|
|
|
|
// remove cell i
|
|
// remove cell i
|
|
|
void seq_pos_rm(uint32_t i) {
|
|
void seq_pos_rm(uint32_t i) {
|
|
|
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
|
|
|
|
|
|
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
|
if (seq[i].test(s)) {
|
|
if (seq[i].test(s)) {
|
|
|
seq_pos[s].erase(pos[i]);
|
|
seq_pos[s].erase(pos[i]);
|
|
|
}
|
|
}
|
|
@@ -406,7 +406,7 @@ private:
|
|
|
|
|
|
|
|
// add cell i
|
|
// add cell i
|
|
|
void seq_pos_add(uint32_t i) {
|
|
void seq_pos_add(uint32_t i) {
|
|
|
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
|
|
|
|
|
|
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
|
if (seq[i].test(s)) {
|
|
if (seq[i].test(s)) {
|
|
|
seq_pos[s].insert(pos[i]);
|
|
seq_pos[s].insert(pos[i]);
|
|
|
}
|
|
}
|