| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533 |
- #pragma once
- #include "llama.h"
- #include "llama-cparams.h"
- #include <bitset>
- #include <cassert>
- #include <cstring>
- #include <map>
- #include <set>
- #include <vector>
- struct llama_kv_cell_ext {
- // 2D spatial positions, typically used for M-RoPE
- llama_pos x = 0;
- llama_pos y = 0;
- // return true if the current 2D spatial position is greater than other
- bool is_2d_gt(llama_pos ox, llama_pos oy) const {
- return (y > oy) || (y == oy && x > ox);
- }
- void reset() {
- static_assert(std::is_trivially_copyable_v<llama_kv_cell_ext>);
- memset(this, 0, sizeof(*this));
- }
- };
- // meta information about KV cells that can be part of multiple sequences at the same time
- // TODO: add unit tests
- class llama_kv_cells {
- public:
- void reset() {
- for (uint32_t i = 0; i < pos.size(); ++i) {
- pos[i] = -1;
- ext[i].reset();
- shift[i] = 0;
- seq[i].reset();
- }
- has_shift = false;
- used.clear();
- for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
- seq_pos[s].clear();
- }
- }
- void reset_shift() {
- has_shift = false;
- for (uint32_t i = 0; i < shift.size(); ++i) {
- shift[i] = 0;
- }
- }
- uint32_t size() const {
- return pos.size();
- }
- void resize(uint32_t n) {
- pos.resize(n);
- ext.resize(n);
- shift.resize(n);
- seq.resize(n);
- reset();
- }
- bool is_empty(uint32_t i) const {
- assert(i < pos.size());
- assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0);
- return pos[i] == -1;
- }
- uint32_t get_used() const {
- return used.size();
- }
- // the index of the first cell that is used
- // return 0 if no cells are used
- uint32_t used_min() const {
- return used.empty() ? 0 : *used.begin();
- }
- // the index of the last cell that is used + 1
- // return 0 if no cells are used
- uint32_t used_max_p1() const {
- return used.empty() ? 0 : *used.rbegin() + 1;
- }
- bool get_has_shift() const {
- return has_shift;
- }
- // move cell isrc to idst (used during defrag)
- //void mv(uint32_t isrc, uint32_t idst) {
- // assert(isrc < pos.size());
- // assert(idst < pos.size());
- // assert(pos[idst] == -1);
- // assert(pos[isrc] != -1);
- // pos [idst] = pos [isrc];
- // shift[idst] = shift[isrc];
- // seq [idst] = seq [isrc];
- // pos [isrc] = -1;
- // shift[isrc] = 0;
- // seq [isrc].reset();
- // used.erase (isrc);
- // used.insert(idst);
- //}
- // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
- llama_kv_cells cp(uint32_t i, uint32_t n) const {
- assert(i + n <= pos.size());
- llama_kv_cells res;
- res.resize(n);
- for (uint32_t j = 0; j < n; ++j) {
- const auto idx = i + j;
- res.pos[j] = pos[idx];
- res.ext[j] = ext[idx];
- res.seq[j] = seq[idx];
- assert(shift[idx] == 0);
- }
- return res;
- }
- // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
- llama_kv_cells cp(const std::vector<uint32_t> & idxs) const {
- llama_kv_cells res;
- res.resize(idxs.size());
- for (uint32_t j = 0; j < idxs.size(); ++j) {
- const auto idx = idxs[j];
- res.pos[j] = pos[idx];
- res.ext[j] = ext[idx];
- res.seq[j] = seq[idx];
- assert(shift[idx] == 0);
- }
- return res;
- }
- // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
- void set(uint32_t i, const llama_kv_cells & other) {
- assert(i + other.pos.size() <= pos.size());
- for (uint32_t j = 0; j < other.pos.size(); ++j) {
- const auto idx = i + j;
- if (pos[idx] == -1 && other.pos[j] != -1) {
- used.insert(i + j);
- }
- if (pos[idx] != -1 && other.pos[j] == -1) {
- used.erase(i + j);
- }
- if (pos[idx] != -1) {
- seq_pos_rm(i + j);
- }
- pos[idx] = other.pos[j];
- ext[idx] = other.ext[j];
- seq[idx] = other.seq[j];
- if (pos[idx] != -1) {
- seq_pos_add(i + j);
- }
- assert(shift[idx] == 0);
- }
- }
- // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
- void set(const std::vector<uint32_t> & idxs, const llama_kv_cells & other) {
- assert(idxs.size() == other.pos.size());
- for (uint32_t j = 0; j < other.pos.size(); ++j) {
- const auto idx = idxs[j];
- if (pos[idx] == -1 && other.pos[j] != -1) {
- used.insert(idx);
- }
- if (pos[idx] != -1 && other.pos[j] == -1) {
- used.erase(idx);
- }
- if (pos[idx] != -1) {
- seq_pos_rm(idx);
- }
- pos[idx] = other.pos[j];
- ext[idx] = other.ext[j];
- seq[idx] = other.seq[j];
- if (pos[idx] != -1) {
- seq_pos_add(idx);
- }
- assert(shift[idx] == 0);
- }
- }
- // clear a non-empty cell
- void rm(uint32_t i) {
- assert(i < pos.size());
- assert(pos[i] != -1);
- seq_pos_rm(i);
- seq[i].reset();
- pos[i] = -1;
- ext[i].reset();
- shift[i] = 0;
- used.erase(i);
- }
- // note: call only if the cell has seq_id
- // return true if the cell becomes empty
- bool seq_rm(uint32_t i, llama_seq_id seq_id) {
- assert(i < pos.size());
- assert(seq[i].test(seq_id));
- assert(pos[i] != -1);
- assert(seq_id >= 0);
- seq[i].reset(seq_id);
- seq_pos_dec(seq_id, pos[i]);
- if (seq[i].none()) {
- pos[i] = -1;
- ext[i].reset();
- shift[i] = 0;
- used.erase(i);
- return true;
- }
- return false;
- }
- // return true if the cell becomes empty (i.e. it did not contain seq_id before the call)
- bool seq_keep(uint32_t i, llama_seq_id seq_id) {
- assert(i < pos.size());
- if (seq[i].test(seq_id)) {
- seq_pos_rm(i);
- seq[i].reset();
- seq[i].set(seq_id);
- seq_pos_inc(seq_id, pos[i]);
- return false;
- }
- if (seq[i].any()) {
- seq_pos_rm(i);
- seq[i].reset();
- pos[i] = -1;
- ext[i].reset();
- shift[i] = 0;
- used.erase(i);
- return true;
- }
- assert(pos[i] == -1);
- return false;
- }
- // number of different sequences in the cell
- int seq_count(uint32_t i) const {
- assert(i < pos.size());
- assert(pos[i] != -1);
- return seq[i].count();
- }
- // check if the cell contains seq_id
- bool seq_has(uint32_t i, llama_seq_id seq_id) const {
- assert(i < pos.size());
- assert(seq_id >= 0);
- return seq[i].test(seq_id);
- }
- // note: call only if the cell is not empty and the seq_id is not in the cell
- void seq_add(uint32_t i, llama_seq_id seq_id) {
- assert(i < pos.size());
- assert(pos[i] != -1);
- assert(!seq[i].test(seq_id));
- seq[i].set(seq_id);
- seq_pos_inc(seq_id, pos[i]);
- }
- // return the sequence id of this cell
- // note: call only for cells with exactly one sequence
- llama_seq_id seq_get(uint32_t i) const {
- assert(seq[i].count() == 1);
- for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
- if (seq[i].test(s)) {
- return s;
- }
- }
- return -1;
- }
- // the minimum position of sequence seq_id currently present in any of the cells
- // return -1 if the sequence is not present
- llama_pos seq_pos_min(llama_seq_id seq_id) const {
- assert(seq_id >= 0);
- assert(seq_id < LLAMA_MAX_SEQ);
- if (seq_pos[seq_id].empty()) {
- return -1;
- }
- assert(seq_pos[seq_id].begin()->second > 0);
- return seq_pos[seq_id].begin()->first;
- }
- // the maximum position of sequence seq_id currently present in any of the cells
- // return -1 if the sequence is not present
- llama_pos seq_pos_max(llama_seq_id seq_id) const {
- assert(seq_id >= 0);
- assert(seq_id < LLAMA_MAX_SEQ);
- if (seq_pos[seq_id].empty()) {
- return -1;
- }
- assert(seq_pos[seq_id].rbegin()->second > 0);
- return seq_pos[seq_id].rbegin()->first;
- }
- // note: call only if the cell is not empty
- llama_pos pos_get(uint32_t i) const {
- assert(i < pos.size());
- assert(pos[i] != -1);
- return pos[i];
- }
- const llama_kv_cell_ext & ext_get(uint32_t i) const {
- assert(i < pos.size());
- assert(pos[i] != -1);
- return ext[i];
- }
- // note: call only if the cell is not empty
- llama_pos get_shift(uint32_t i) const {
- assert(i < pos.size());
- assert(pos[i] != -1);
- return shift[i];
- }
- // check if a cell is not empty and its position is within [p0, p1)
- bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const {
- assert(i < pos.size());
- return pos[i] >= p0 && pos[i] < p1;
- }
- // set the position of an empty cell
- // does not modify "has_shift"
- // note: call only if the cell is empty
- void pos_set(uint32_t i, llama_pos p) {
- assert(i < pos.size());
- assert(pos[i] == -1);
- assert(seq[i].none());
- pos[i] = p;
- used.insert(i);
- }
- void ext_set(uint32_t i, llama_kv_cell_ext p) {
- assert(i < ext.size());
- ext[i] = p;
- }
- // pos[i] = pos[i] + d
- // sets "has_shift" to true
- // note: call only if the cell is not empty
- bool pos_add(uint32_t i, llama_pos d) {
- assert(i < pos.size());
- assert(pos[i] != -1);
- seq_pos_rm(i);
- pos[i] += d;
- shift[i] += d;
- has_shift = true;
- if (pos[i] < 0) {
- seq[i].reset();
- pos[i] = -1;
- shift[i] = 0;
- used.erase(i);
- return true;
- }
- seq_pos_add(i);
- return false;
- }
- // pos[i] = pos[i] / d
- // sets "has_shift" to true
- // note: call only if the cell is not empty
- void pos_div(uint32_t i, int d) {
- assert(i < pos.size());
- assert(pos[i] != -1);
- const llama_pos p_old = pos[i];
- seq_pos_rm(i);
- pos[i] /= d;
- shift[i] += p_old - pos[i];
- seq_pos_add(i);
- has_shift = true;
- }
- private:
- bool has_shift = false;
- // set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
- std::set<uint32_t> used;
- std::vector<llama_pos> pos;
- // stores extra info per cell
- std::vector<llama_kv_cell_ext> ext;
- // this array accumulates any applied shifts to the pos array since the last reset_shift() call
- // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
- //
- // cells.pos_add(x, shift_x);
- // cells.pos_div(y, shift_y);
- // ...
- //
- // if (cells.has_shift()) {
- // for (int i = 0; i < n; ++i) {
- // auto shift_i = cells.get_shift(i);
- // ...
- // }
- // cells.reset_shift();
- // }
- //
- std::vector<llama_pos> shift;
- using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
- // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
- std::vector<seq_set_t> seq;
- // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
- // if the position p is not present, seq_pos[s][p] is not set
- // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
- //
- // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
- // - during performing a cache reuse via (rm + add)
- // - some vision models have input embeddings with repeating positions
- //
- std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
- // helper functions for updating `seq_pos`, once cell at a time:
- void seq_pos_dec(llama_seq_id s, llama_pos p) {
- auto it = seq_pos[s].find(p);
- assert(it != seq_pos[s].end());
- if (--it->second == 0) {
- seq_pos[s].erase(it);
- }
- }
- void seq_pos_inc(llama_seq_id s, llama_pos p) {
- seq_pos[s][p]++;
- }
- // remove cell i
- void seq_pos_rm(uint32_t i) {
- for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
- if (seq[i].test(s)) {
- seq_pos_dec(s, pos[i]);
- }
- }
- }
- // add cell i
- void seq_pos_add(uint32_t i) {
- for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
- if (seq[i].test(s)) {
- seq_pos_inc(s, pos[i]);
- }
- }
- }
- };
|