// llama-kv-cells.h
  1. #pragma once
  2. #include "llama.h"
  3. #include "llama-cparams.h"
  4. #include <bitset>
  5. #include <cassert>
  6. #include <cstring>
  7. #include <map>
  8. #include <set>
  9. #include <vector>
  10. struct llama_kv_cell_ext {
  11. // 2D spatial positions, typically used for M-RoPE
  12. llama_pos x = 0;
  13. llama_pos y = 0;
  14. // return true if the current 2D spatial position is greater than other
  15. bool is_2d_gt(llama_pos ox, llama_pos oy) const {
  16. return (y > oy) || (y == oy && x > ox);
  17. }
  18. void reset() {
  19. static_assert(std::is_trivially_copyable_v<llama_kv_cell_ext>);
  20. memset(this, 0, sizeof(*this));
  21. }
  22. };
  23. // meta information about KV cells that can be part of multiple sequences at the same time
  24. // TODO: add unit tests
  25. class llama_kv_cells {
  26. public:
  27. void reset() {
  28. for (uint32_t i = 0; i < pos.size(); ++i) {
  29. pos[i] = -1;
  30. ext[i].reset();
  31. shift[i] = 0;
  32. seq[i].reset();
  33. }
  34. has_shift = false;
  35. used.clear();
  36. for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
  37. seq_pos[s].clear();
  38. }
  39. }
  40. void reset_shift() {
  41. has_shift = false;
  42. for (uint32_t i = 0; i < shift.size(); ++i) {
  43. shift[i] = 0;
  44. }
  45. }
  46. uint32_t size() const {
  47. return pos.size();
  48. }
  49. void resize(uint32_t n) {
  50. pos.resize(n);
  51. ext.resize(n);
  52. shift.resize(n);
  53. seq.resize(n);
  54. reset();
  55. }
  56. bool is_empty(uint32_t i) const {
  57. assert(i < pos.size());
  58. assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0);
  59. return pos[i] == -1;
  60. }
  61. uint32_t get_used() const {
  62. return used.size();
  63. }
  64. // the index of the first cell that is used
  65. // return 0 if no cells are used
  66. uint32_t used_min() const {
  67. return used.empty() ? 0 : *used.begin();
  68. }
  69. // the index of the last cell that is used + 1
  70. // return 0 if no cells are used
  71. uint32_t used_max_p1() const {
  72. return used.empty() ? 0 : *used.rbegin() + 1;
  73. }
  74. bool get_has_shift() const {
  75. return has_shift;
  76. }
  77. // move cell isrc to idst (used during defrag)
  78. //void mv(uint32_t isrc, uint32_t idst) {
  79. // assert(isrc < pos.size());
  80. // assert(idst < pos.size());
  81. // assert(pos[idst] == -1);
  82. // assert(pos[isrc] != -1);
  83. // pos [idst] = pos [isrc];
  84. // shift[idst] = shift[isrc];
  85. // seq [idst] = seq [isrc];
  86. // pos [isrc] = -1;
  87. // shift[isrc] = 0;
  88. // seq [isrc].reset();
  89. // used.erase (isrc);
  90. // used.insert(idst);
  91. //}
  92. // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
  93. llama_kv_cells cp(uint32_t i, uint32_t n) const {
  94. assert(i + n <= pos.size());
  95. llama_kv_cells res;
  96. res.resize(n);
  97. for (uint32_t j = 0; j < n; ++j) {
  98. const auto idx = i + j;
  99. res.pos[j] = pos[idx];
  100. res.ext[j] = ext[idx];
  101. res.seq[j] = seq[idx];
  102. assert(shift[idx] == 0);
  103. }
  104. return res;
  105. }
  106. // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
  107. llama_kv_cells cp(const std::vector<uint32_t> & idxs) const {
  108. llama_kv_cells res;
  109. res.resize(idxs.size());
  110. for (uint32_t j = 0; j < idxs.size(); ++j) {
  111. const auto idx = idxs[j];
  112. res.pos[j] = pos[idx];
  113. res.ext[j] = ext[idx];
  114. res.seq[j] = seq[idx];
  115. assert(shift[idx] == 0);
  116. }
  117. return res;
  118. }
  119. // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
  120. void set(uint32_t i, const llama_kv_cells & other) {
  121. assert(i + other.pos.size() <= pos.size());
  122. for (uint32_t j = 0; j < other.pos.size(); ++j) {
  123. const auto idx = i + j;
  124. if (pos[idx] == -1 && other.pos[j] != -1) {
  125. used.insert(i + j);
  126. }
  127. if (pos[idx] != -1 && other.pos[j] == -1) {
  128. used.erase(i + j);
  129. }
  130. if (pos[idx] != -1) {
  131. seq_pos_rm(i + j);
  132. }
  133. pos[idx] = other.pos[j];
  134. ext[idx] = other.ext[j];
  135. seq[idx] = other.seq[j];
  136. if (pos[idx] != -1) {
  137. seq_pos_add(i + j);
  138. }
  139. assert(shift[idx] == 0);
  140. }
  141. }
  142. // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
  143. void set(const std::vector<uint32_t> & idxs, const llama_kv_cells & other) {
  144. assert(idxs.size() == other.pos.size());
  145. for (uint32_t j = 0; j < other.pos.size(); ++j) {
  146. const auto idx = idxs[j];
  147. if (pos[idx] == -1 && other.pos[j] != -1) {
  148. used.insert(idx);
  149. }
  150. if (pos[idx] != -1 && other.pos[j] == -1) {
  151. used.erase(idx);
  152. }
  153. if (pos[idx] != -1) {
  154. seq_pos_rm(idx);
  155. }
  156. pos[idx] = other.pos[j];
  157. ext[idx] = other.ext[j];
  158. seq[idx] = other.seq[j];
  159. if (pos[idx] != -1) {
  160. seq_pos_add(idx);
  161. }
  162. assert(shift[idx] == 0);
  163. }
  164. }
  165. // clear a non-empty cell
  166. void rm(uint32_t i) {
  167. assert(i < pos.size());
  168. assert(pos[i] != -1);
  169. seq_pos_rm(i);
  170. seq[i].reset();
  171. pos[i] = -1;
  172. ext[i].reset();
  173. shift[i] = 0;
  174. used.erase(i);
  175. }
  176. // note: call only if the cell has seq_id
  177. // return true if the cell becomes empty
  178. bool seq_rm(uint32_t i, llama_seq_id seq_id) {
  179. assert(i < pos.size());
  180. assert(seq[i].test(seq_id));
  181. assert(pos[i] != -1);
  182. assert(seq_id >= 0);
  183. seq[i].reset(seq_id);
  184. seq_pos_dec(seq_id, pos[i]);
  185. if (seq[i].none()) {
  186. pos[i] = -1;
  187. ext[i].reset();
  188. shift[i] = 0;
  189. used.erase(i);
  190. return true;
  191. }
  192. return false;
  193. }
  194. // return true if the cell becomes empty (i.e. it did not contain seq_id before the call)
  195. bool seq_keep(uint32_t i, llama_seq_id seq_id) {
  196. assert(i < pos.size());
  197. if (seq[i].test(seq_id)) {
  198. seq_pos_rm(i);
  199. seq[i].reset();
  200. seq[i].set(seq_id);
  201. seq_pos_inc(seq_id, pos[i]);
  202. return false;
  203. }
  204. if (seq[i].any()) {
  205. seq_pos_rm(i);
  206. seq[i].reset();
  207. pos[i] = -1;
  208. ext[i].reset();
  209. shift[i] = 0;
  210. used.erase(i);
  211. return true;
  212. }
  213. assert(pos[i] == -1);
  214. return false;
  215. }
  216. // number of different sequences in the cell
  217. int seq_count(uint32_t i) const {
  218. assert(i < pos.size());
  219. assert(pos[i] != -1);
  220. return seq[i].count();
  221. }
  222. // check if the cell contains seq_id
  223. bool seq_has(uint32_t i, llama_seq_id seq_id) const {
  224. assert(i < pos.size());
  225. assert(seq_id >= 0);
  226. return seq[i].test(seq_id);
  227. }
  228. // note: call only if the cell is not empty and the seq_id is not in the cell
  229. void seq_add(uint32_t i, llama_seq_id seq_id) {
  230. assert(i < pos.size());
  231. assert(pos[i] != -1);
  232. assert(!seq[i].test(seq_id));
  233. seq[i].set(seq_id);
  234. seq_pos_inc(seq_id, pos[i]);
  235. }
  236. // return the sequence id of this cell
  237. // note: call only for cells with exactly one sequence
  238. llama_seq_id seq_get(uint32_t i) const {
  239. assert(seq[i].count() == 1);
  240. for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
  241. if (seq[i].test(s)) {
  242. return s;
  243. }
  244. }
  245. return -1;
  246. }
  247. // the minimum position of sequence seq_id currently present in any of the cells
  248. // return -1 if the sequence is not present
  249. llama_pos seq_pos_min(llama_seq_id seq_id) const {
  250. assert(seq_id >= 0);
  251. assert(seq_id < LLAMA_MAX_SEQ);
  252. if (seq_pos[seq_id].empty()) {
  253. return -1;
  254. }
  255. assert(seq_pos[seq_id].begin()->second > 0);
  256. return seq_pos[seq_id].begin()->first;
  257. }
  258. // the maximum position of sequence seq_id currently present in any of the cells
  259. // return -1 if the sequence is not present
  260. llama_pos seq_pos_max(llama_seq_id seq_id) const {
  261. assert(seq_id >= 0);
  262. assert(seq_id < LLAMA_MAX_SEQ);
  263. if (seq_pos[seq_id].empty()) {
  264. return -1;
  265. }
  266. assert(seq_pos[seq_id].rbegin()->second > 0);
  267. return seq_pos[seq_id].rbegin()->first;
  268. }
  269. // note: call only if the cell is not empty
  270. llama_pos pos_get(uint32_t i) const {
  271. assert(i < pos.size());
  272. assert(pos[i] != -1);
  273. return pos[i];
  274. }
  275. const llama_kv_cell_ext & ext_get(uint32_t i) const {
  276. assert(i < pos.size());
  277. assert(pos[i] != -1);
  278. return ext[i];
  279. }
  280. // note: call only if the cell is not empty
  281. llama_pos get_shift(uint32_t i) const {
  282. assert(i < pos.size());
  283. assert(pos[i] != -1);
  284. return shift[i];
  285. }
  286. // check if a cell is not empty and its position is within [p0, p1)
  287. bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const {
  288. assert(i < pos.size());
  289. return pos[i] >= p0 && pos[i] < p1;
  290. }
  291. // set the position of an empty cell
  292. // does not modify "has_shift"
  293. // note: call only if the cell is empty
  294. void pos_set(uint32_t i, llama_pos p) {
  295. assert(i < pos.size());
  296. assert(pos[i] == -1);
  297. assert(seq[i].none());
  298. pos[i] = p;
  299. used.insert(i);
  300. }
  301. void ext_set(uint32_t i, llama_kv_cell_ext p) {
  302. assert(i < ext.size());
  303. ext[i] = p;
  304. }
  305. // pos[i] = pos[i] + d
  306. // sets "has_shift" to true
  307. // note: call only if the cell is not empty
  308. bool pos_add(uint32_t i, llama_pos d) {
  309. assert(i < pos.size());
  310. assert(pos[i] != -1);
  311. seq_pos_rm(i);
  312. pos[i] += d;
  313. shift[i] += d;
  314. has_shift = true;
  315. if (pos[i] < 0) {
  316. seq[i].reset();
  317. pos[i] = -1;
  318. shift[i] = 0;
  319. used.erase(i);
  320. return true;
  321. }
  322. seq_pos_add(i);
  323. return false;
  324. }
  325. // pos[i] = pos[i] / d
  326. // sets "has_shift" to true
  327. // note: call only if the cell is not empty
  328. void pos_div(uint32_t i, int d) {
  329. assert(i < pos.size());
  330. assert(pos[i] != -1);
  331. const llama_pos p_old = pos[i];
  332. seq_pos_rm(i);
  333. pos[i] /= d;
  334. shift[i] += p_old - pos[i];
  335. seq_pos_add(i);
  336. has_shift = true;
  337. }
  338. private:
  339. bool has_shift = false;
  340. // set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
  341. std::set<uint32_t> used;
  342. std::vector<llama_pos> pos;
  343. // stores extra info per cell
  344. std::vector<llama_kv_cell_ext> ext;
  345. // this array accumulates any applied shifts to the pos array since the last reset_shift() call
  346. // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
  347. //
  348. // cells.pos_add(x, shift_x);
  349. // cells.pos_div(y, shift_y);
  350. // ...
  351. //
  352. // if (cells.has_shift()) {
  353. // for (int i = 0; i < n; ++i) {
  354. // auto shift_i = cells.get_shift(i);
  355. // ...
  356. // }
  357. // cells.reset_shift();
  358. // }
  359. //
  360. std::vector<llama_pos> shift;
  361. using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
  362. // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
  363. std::vector<seq_set_t> seq;
  364. // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
  365. // if the position p is not present, seq_pos[s][p] is not set
  366. // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
  367. //
  368. // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
  369. // - during performing a cache reuse via (rm + add)
  370. // - some vision models have input embeddings with repeating positions
  371. //
  372. std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
  373. // helper functions for updating `seq_pos`, once cell at a time:
  374. void seq_pos_dec(llama_seq_id s, llama_pos p) {
  375. auto it = seq_pos[s].find(p);
  376. assert(it != seq_pos[s].end());
  377. if (--it->second == 0) {
  378. seq_pos[s].erase(it);
  379. }
  380. }
  381. void seq_pos_inc(llama_seq_id s, llama_pos p) {
  382. seq_pos[s][p]++;
  383. }
  384. // remove cell i
  385. void seq_pos_rm(uint32_t i) {
  386. for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
  387. if (seq[i].test(s)) {
  388. seq_pos_dec(s, pos[i]);
  389. }
  390. }
  391. }
  392. // add cell i
  393. void seq_pos_add(uint32_t i) {
  394. for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
  395. if (seq[i].test(s)) {
  396. seq_pos_inc(s, pos[i]);
  397. }
  398. }
  399. }
  400. };