|
|
@@ -7,6 +7,7 @@
|
|
|
#include <cassert>
|
|
|
#include <vector>
|
|
|
#include <set>
|
|
|
+#include <map>
|
|
|
|
|
|
// meta information about KV cells that can be part of multiple sequences at the same time
|
|
|
// TODO: add unit tests
|
|
|
@@ -164,7 +165,7 @@ public:
|
|
|
assert(seq_id >= 0);
|
|
|
|
|
|
seq[i].reset(seq_id);
|
|
|
- seq_pos[seq_id].erase(pos[i]);
|
|
|
+ seq_pos_dec(seq_id, pos[i]);
|
|
|
|
|
|
if (seq[i].none()) {
|
|
|
pos[i] = -1;
|
|
|
@@ -187,7 +188,7 @@ public:
|
|
|
seq[i].reset();
|
|
|
|
|
|
seq[i].set(seq_id);
|
|
|
- seq_pos[seq_id].insert(pos[i]);
|
|
|
+ seq_pos_inc(seq_id, pos[i]);
|
|
|
|
|
|
return false;
|
|
|
}
|
|
|
@@ -232,7 +233,7 @@ public:
|
|
|
assert(!seq[i].test(seq_id));
|
|
|
|
|
|
seq[i].set(seq_id);
|
|
|
- seq_pos[seq_id].insert(pos[i]);
|
|
|
+ seq_pos_inc(seq_id, pos[i]);
|
|
|
}
|
|
|
|
|
|
// return the sequence id of this cell
|
|
|
@@ -259,7 +260,9 @@ public:
|
|
|
return -1;
|
|
|
}
|
|
|
|
|
|
- return *seq_pos[seq_id].begin();
|
|
|
+ assert(seq_pos[seq_id].begin()->second > 0);
|
|
|
+
|
|
|
+ return seq_pos[seq_id].begin()->first;
|
|
|
}
|
|
|
|
|
|
// the maximum position of sequence seq_id currently present in any of the cells
|
|
|
@@ -272,7 +275,9 @@ public:
|
|
|
return -1;
|
|
|
}
|
|
|
|
|
|
- return *seq_pos[seq_id].rbegin();
|
|
|
+ assert(seq_pos[seq_id].rbegin()->second > 0);
|
|
|
+
|
|
|
+ return seq_pos[seq_id].rbegin()->first;
|
|
|
}
|
|
|
|
|
|
// note: call only if the cell is not empty
|
|
|
@@ -389,17 +394,36 @@ private:
|
|
|
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
|
|
|
std::vector<seq_set_t> seq;
|
|
|
|
|
|
- // the set seq_pos[s] tells us which positions are currently present for sequence s
|
|
|
+ // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
|
|
|
+ // if the position p is not present, seq_pos[s][p] is not set
|
|
|
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
|
|
|
- std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
|
|
|
+ //
|
|
|
+ // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
|
|
|
+ // - during performing a cache reuse via (rm + add)
|
|
|
+ // - some vision models have input embeddings with repeating positions
|
|
|
+ //
|
|
|
+ std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
|
|
|
|
|
|
// helper functions for updating `seq_pos`, once cell at a time:
|
|
|
|
|
|
+ void seq_pos_dec(llama_seq_id s, llama_pos p) {
|
|
|
+ auto it = seq_pos[s].find(p);
|
|
|
+ assert(it != seq_pos[s].end());
|
|
|
+
|
|
|
+ if (--it->second == 0) {
|
|
|
+ seq_pos[s].erase(it);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ void seq_pos_inc(llama_seq_id s, llama_pos p) {
|
|
|
+ seq_pos[s][p]++;
|
|
|
+ }
|
|
|
+
|
|
|
// remove cell i
|
|
|
void seq_pos_rm(uint32_t i) {
|
|
|
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
|
if (seq[i].test(s)) {
|
|
|
- seq_pos[s].erase(pos[i]);
|
|
|
+ seq_pos_dec(s, pos[i]);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -408,7 +432,7 @@ private:
|
|
|
void seq_pos_add(uint32_t i) {
|
|
|
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
|
if (seq[i].test(s)) {
|
|
|
- seq_pos[s].insert(pos[i]);
|
|
|
+ seq_pos_inc(s, pos[i]);
|
|
|
}
|
|
|
}
|
|
|
}
|