llama-kv-cells.h 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. #pragma once
  2. #include "llama.h"
  3. #include "llama-cparams.h"
  4. #include <bitset>
  5. #include <cassert>
  6. #include <vector>
  7. #include <set>
  8. // meta information about KV cells that can be part of multiple sequences at the same time
  9. // TODO: add unit tests
  10. class llama_kv_cells_unified {
  11. public:
  12. void reset() {
  13. for (uint32_t i = 0; i < pos.size(); ++i) {
  14. pos[i] = -1;
  15. shift[i] = 0;
  16. seq[i].reset();
  17. }
  18. has_shift = false;
  19. used.clear();
  20. for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
  21. seq_pos[s].clear();
  22. }
  23. }
  24. void reset_shift() {
  25. has_shift = false;
  26. for (uint32_t i = 0; i < shift.size(); ++i) {
  27. shift[i] = 0;
  28. }
  29. }
  30. uint32_t size() const {
  31. return pos.size();
  32. }
  33. void resize(uint32_t n) {
  34. pos.resize(n);
  35. shift.resize(n);
  36. seq.resize(n);
  37. reset();
  38. }
  39. bool is_empty(uint32_t i) const {
  40. assert(i < pos.size());
  41. assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0);
  42. return pos[i] == -1;
  43. }
  44. uint32_t get_used() const {
  45. return used.size();
  46. }
  47. // the index of the first cell that is used
  48. // return 0 if no cells are used
  49. uint32_t used_min() const {
  50. return used.empty() ? 0 : *used.begin();
  51. }
  52. // the index of the last cell that is used + 1
  53. // return 0 if no cells are used
  54. uint32_t used_max_p1() const {
  55. return used.empty() ? 0 : *used.rbegin() + 1;
  56. }
  57. bool get_has_shift() const {
  58. return has_shift;
  59. }
  60. // move cell isrc to idst (used during defrag)
  61. void mv(uint32_t isrc, uint32_t idst) {
  62. assert(isrc < pos.size());
  63. assert(idst < pos.size());
  64. pos [idst] = pos [isrc];
  65. shift[idst] = shift[isrc];
  66. seq [idst] = seq [isrc];
  67. pos [isrc] = -1;
  68. shift[isrc] = 0;
  69. seq [isrc].reset();
  70. used.erase (isrc);
  71. used.insert(idst);
  72. }
  73. // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
  74. llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
  75. assert(i + n <= pos.size());
  76. llama_kv_cells_unified res;
  77. res.resize(n);
  78. for (uint32_t j = 0; j < n; ++j) {
  79. res.pos[j] = pos[i + j];
  80. res.seq[j] = seq[i + j];
  81. assert(shift[i + j] == 0);
  82. }
  83. return res;
  84. }
  85. // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
  86. void set(uint32_t i, const llama_kv_cells_unified & other) {
  87. assert(i + other.pos.size() <= pos.size());
  88. for (uint32_t j = 0; j < other.pos.size(); ++j) {
  89. if (pos[i + j] == -1 && other.pos[j] != -1) {
  90. used.insert(i + j);
  91. }
  92. if (pos[i + j] != -1 && other.pos[j] == -1) {
  93. used.erase(i + j);
  94. }
  95. if (pos[i + j] != -1) {
  96. seq_pos_rm(i + j);
  97. }
  98. pos[i + j] = other.pos[j];
  99. seq[i + j] = other.seq[j];
  100. if (pos[i + j] != -1) {
  101. seq_pos_add(i + j);
  102. }
  103. assert(shift[i + j] == 0);
  104. }
  105. }
  106. // clear a non-empty cell
  107. void rm(uint32_t i) {
  108. assert(i < pos.size());
  109. assert(pos[i] != -1);
  110. seq_pos_rm(i);
  111. pos[i] = -1;
  112. seq[i].reset();
  113. used.erase(i);
  114. }
  115. // note: call only if the cell has seq_id
  116. // return true if the cell becomes empty
  117. bool seq_rm(uint32_t i, llama_seq_id seq_id) {
  118. assert(i < pos.size());
  119. assert(seq[i].test(seq_id));
  120. assert(pos[i] != -1);
  121. assert(seq_id >= 0);
  122. seq[i].reset(seq_id);
  123. seq_pos[seq_id].erase(pos[i]);
  124. if (seq[i].none()) {
  125. pos[i] = -1;
  126. used.erase(i);
  127. return true;
  128. }
  129. return false;
  130. }
  131. // return true if the cell becomes empty (i.e. it did not contain seq_id before the call)
  132. bool seq_keep(uint32_t i, llama_seq_id seq_id) {
  133. assert(i < pos.size());
  134. if (seq[i].test(seq_id)) {
  135. seq_pos_rm(i);
  136. seq[i].reset();
  137. seq[i].set(seq_id);
  138. seq_pos[seq_id].insert(pos[i]);
  139. return false;
  140. }
  141. if (seq[i].any()) {
  142. seq_pos_rm(i);
  143. seq[i].reset();
  144. pos[i] = -1;
  145. used.erase(i);
  146. return true;
  147. }
  148. assert(pos[i] == -1);
  149. return false;
  150. }
  151. // number of different sequences in the cell
  152. int seq_count(uint32_t i) const {
  153. assert(i < pos.size());
  154. assert(pos[i] != -1);
  155. return seq[i].count();
  156. }
  157. // check if the cell contains seq_id
  158. bool seq_has(uint32_t i, llama_seq_id seq_id) const {
  159. assert(i < pos.size());
  160. assert(seq_id >= 0);
  161. return seq[i].test(seq_id);
  162. }
  163. // note: call only if the cell is not empty and the seq_id is not in the cell
  164. void seq_add(uint32_t i, llama_seq_id seq_id) {
  165. assert(i < pos.size());
  166. assert(pos[i] != -1);
  167. assert(!seq[i].test(seq_id));
  168. seq[i].set(seq_id);
  169. seq_pos[seq_id].insert(pos[i]);
  170. }
  171. // return the sequence id of this cell
  172. // note: call only for cells with exactly one sequence
  173. llama_seq_id seq_get(uint32_t i) const {
  174. assert(seq[i].count() == 1);
  175. for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
  176. if (seq[i].test(s)) {
  177. return s;
  178. }
  179. }
  180. return -1;
  181. }
  182. // the minimum position of sequence seq_id currently present in any of the cells
  183. // return -1 if the sequence is not present
  184. llama_pos seq_pos_min(llama_seq_id seq_id) const {
  185. assert(seq_id >= 0);
  186. assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
  187. if (seq_pos[seq_id].empty()) {
  188. return -1;
  189. }
  190. return *seq_pos[seq_id].begin();
  191. }
  192. // the maximum position of sequence seq_id currently present in any of the cells
  193. // return -1 if the sequence is not present
  194. llama_pos seq_pos_max(llama_seq_id seq_id) const {
  195. assert(seq_id >= 0);
  196. assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
  197. if (seq_pos[seq_id].empty()) {
  198. return -1;
  199. }
  200. return *seq_pos[seq_id].rbegin();
  201. }
  202. // note: call only if the cell is not empty
  203. llama_pos pos_get(uint32_t i) const {
  204. assert(i < pos.size());
  205. assert(pos[i] != -1);
  206. return pos[i];
  207. }
  208. // note: call only if the cell is not empty
  209. llama_pos get_shift(uint32_t i) const {
  210. assert(i < pos.size());
  211. assert(pos[i] != -1);
  212. return shift[i];
  213. }
  214. // check if a cell is not empty and its position is within [p0, p1)
  215. bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const {
  216. assert(i < pos.size());
  217. return pos[i] >= p0 && pos[i] < p1;
  218. }
  219. // set the position of an empty cell
  220. // does not modify "has_shift"
  221. // note: call only if the cell is empty
  222. void pos_set(uint32_t i, llama_pos p) {
  223. assert(i < pos.size());
  224. assert(pos[i] == -1);
  225. assert(seq[i].none());
  226. pos[i] = p;
  227. used.insert(i);
  228. }
  229. // pos[i] = pos[i] + d
  230. // sets "has_shift" to true
  231. // note: call only if the cell is not empty
  232. bool pos_add(uint32_t i, llama_pos d) {
  233. assert(i < pos.size());
  234. assert(pos[i] != -1);
  235. seq_pos_rm(i);
  236. pos[i] += d;
  237. shift[i] += d;
  238. seq_pos_add(i);
  239. has_shift = true;
  240. if (pos[i] < 0) {
  241. seq_pos_rm(i);
  242. seq[i].reset();
  243. pos[i] = -1;
  244. used.erase(i);
  245. return true;
  246. }
  247. return false;
  248. }
  249. // pos[i] = pos[i] / d
  250. // sets "has_shift" to true
  251. // note: call only if the cell is not empty
  252. void pos_div(uint32_t i, int d) {
  253. assert(i < pos.size());
  254. assert(pos[i] != -1);
  255. const llama_pos p_old = pos[i];
  256. seq_pos_rm(i);
  257. pos[i] /= d;
  258. shift[i] += p_old - pos[i];
  259. seq_pos_add(i);
  260. has_shift = true;
  261. }
  262. private:
  263. bool has_shift = false;
  264. // set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
  265. std::set<uint32_t> used;
  266. std::vector<llama_pos> pos;
  267. // this array accumulates any applied shifts to the pos array since the last reset_shift() call
  268. // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
  269. //
  270. // cells.pos_add(x, shift_x);
  271. // cells.pos_div(y, shift_y);
  272. // ...
  273. //
  274. // if (cells.has_shift()) {
  275. // for (int i = 0; i < n; ++i) {
  276. // auto shift_i = cells.get_shift(i);
  277. // ...
  278. // }
  279. // cells.reset_shift();
  280. // }
  281. //
  282. std::vector<llama_pos> shift;
  283. using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
  284. // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
  285. std::vector<bits_t> seq;
  286. // the set seq_pos[s] tells us which positions are currently present for sequence s
  287. // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
  288. std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
  289. // helper functions for updating `seq_pos`, once cell at a time:
  290. // remove cell i
  291. void seq_pos_rm(uint32_t i) {
  292. for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
  293. if (seq[i].test(s)) {
  294. seq_pos[s].erase(pos[i]);
  295. }
  296. }
  297. }
  298. // add cell i
  299. void seq_pos_add(uint32_t i) {
  300. for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
  301. if (seq[i].test(s)) {
  302. seq_pos[s].insert(pos[i]);
  303. }
  304. }
  305. }
  306. };