llama-kv-cache.h

#pragma once

#include "llama.h"

#include "ggml-cpp.h"

#include <algorithm> // std::max in llama_kv_cache::max_pos
#include <set>
#include <vector>
struct llama_kv_cell {
    llama_pos pos   = -1;
    llama_pos delta =  0;
    int32_t   src   = -1; // used by recurrent state models to copy states
    int32_t   tail  = -1;

    std::set<llama_seq_id> seq_id;

    bool has_seq_id(const llama_seq_id & id) const {
        return seq_id.find(id) != seq_id.end();
    }

    bool is_empty() const {
        return seq_id.empty();
    }

    bool is_same_seq(const llama_kv_cell & other) const {
        return seq_id == other.seq_id;
    }
};
// ring-buffer of cached KV data
struct llama_kv_cache {
    bool has_shift = false;
    bool do_defrag = false;
    bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
    bool v_trans   = true;  // the value tensor is transposed
    bool can_shift = false;

    // Note: The value of head isn't only used to optimize searching
    // for a free KV slot. llama_decode_internal also uses it, so it
    // cannot be freely changed after a slot has been allocated.
    uint32_t head = 0;
    uint32_t size = 0;
    uint32_t used = 0; // used cells (i.e. at least one seq_id)

    // computed before each graph build
    uint32_t n = 0;

    ggml_type type_k = GGML_TYPE_F16;
    ggml_type type_v = GGML_TYPE_F16;

    std::vector<llama_kv_cell> cells;

    std::vector<struct ggml_tensor *> k_l; // per layer
    std::vector<struct ggml_tensor *> v_l;

    std::vector<ggml_context_ptr>        ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    size_t total_size() const {
        size_t size = 0;
        for (const auto & buf : bufs) {
            size += ggml_backend_buffer_get_size(buf.get());
        }
        return size;
    }

    // TODO: better data structures to reduce the cost of this operation
    llama_pos max_pos() const {
        llama_pos max_pos = -1;
        for (const auto & cell : cells) {
            max_pos = std::max(max_pos, cell.pos);
        }
        return max_pos;
    }
};
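
// Usage sketch (illustrative only, not part of this header; assumes a cache
// that has already been set up with llama_kv_cache_init declared below):
//
//     llama_kv_cache cache;
//     // llama_kv_cache_init(cache, model, cparams, GGML_TYPE_F16, GGML_TYPE_F16, 4096, /*offload=*/true);
//     size_t    bytes = cache.total_size(); // sum of all backend buffer sizes
//     llama_pos p_max = cache.max_pos();    // highest position stored, -1 if the cache is empty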
// a structure that holds information about the slot found in llama_kv_cache_find_slot
struct llama_kv_cache_slot_info {
    std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
    bool found = false;                       // whether a slot was found

    explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
    llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}

    operator bool() const { return found; }
};
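
// Usage sketch (illustrative only): the operator bool lets callers test whether
// a slot was found before touching its boundaries:
//
//     const auto slot = llama_kv_cache_find_slot(cache, ubatch);
//     if (!slot) {
//         return false; // no suitable run of cells was available
//     }
//     // slot.boundaries.first .. slot.boundaries.second is the [begin, end) cell range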
// TODO: maybe not needed
uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams);

bool llama_kv_cache_init(
        struct llama_kv_cache & cache,
        const llama_model & model,
        const llama_cparams & cparams,
        ggml_type type_k,
        ggml_type type_v,
        uint32_t kv_size,
        bool offload);

// find an empty slot of size "n_tokens" in the cache
// updates the cache head
// returns a structure holding information about the slot found
// Note: On success, it's important that cache.head points
// to the first cell of the slot.
struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
        struct llama_kv_cache & cache,
        const struct llama_ubatch & batch);
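
// Typical call sequence (illustrative sketch; the model/cparams/ubatch setup is
// omitted, and init is assumed here to return false when allocation fails):
//
//     llama_kv_cache cache;
//     if (!llama_kv_cache_init(cache, model, cparams, type_k, type_v, kv_size, /*offload=*/true)) {
//         return 1; // could not allocate the per-layer K/V tensors
//     }
//     const auto slot = llama_kv_cache_find_slot(cache, ubatch);
//     if (slot) {
//         // per the note above, cache.head now points to the first cell of the slot
//     }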
// find how many cells are currently in use
uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache);

void llama_kv_cache_clear(struct llama_kv_cache & cache);

bool llama_kv_cache_seq_rm(
        struct llama_kv_cache & cache,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1);

void llama_kv_cache_seq_cp(
        struct llama_kv_cache & cache,
        llama_seq_id seq_id_src,
        llama_seq_id seq_id_dst,
        llama_pos p0,
        llama_pos p1);

void llama_kv_cache_seq_keep(
        struct llama_kv_cache & cache,
        llama_seq_id seq_id);

void llama_kv_cache_seq_add(
        struct llama_kv_cache & cache,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1,
        llama_pos delta);

void llama_kv_cache_seq_div(
        struct llama_kv_cache & cache,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1,
        int d);

llama_pos llama_kv_cache_seq_pos_max(
        struct llama_kv_cache & cache,
        llama_seq_id seq_id);

void llama_kv_cache_defrag(struct llama_kv_cache & cache);

int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv);

int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv);

bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv);
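
// Sequence-editing sketch (illustrative only). The [p0, p1) range semantics and
// the use of negative values as open-ended bounds are assumptions inferred from
// the restorer below, which passes -1 to mean "all sequences / whole range":
//
//     llama_kv_cache_seq_cp (cache, /*seq_id_src=*/0, /*seq_id_dst=*/1, -1, -1); // share seq 0's cells with seq 1
//     llama_kv_cache_seq_rm (cache, /*seq_id=*/1, 32, -1);                       // drop seq 1 cells at pos >= 32
//     llama_kv_cache_seq_add(cache, /*seq_id=*/0, 0, 32, /*delta=*/-8);          // shift seq 0 positions in [0, 32)
//     llama_kv_cache_defrag (cache);                                             // request defragmentation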
//
// kv cache view
//

struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max);

void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv);
//
// kv cache restore
//

// saves the kv_cache state for future recovery
// used to roll back changes made by llama_kv_cache_find_slot
struct llama_kv_slot_restorer {
    struct llama_kv_cache_state {
        uint32_t head = 0;
        uint32_t n    = 0;
    } old_state;

    // for non-recurrent models only
    // list of slots to restore
    std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;

    bool do_restore = false;

    explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
        old_state.head = cache.head;
        old_state.n    = cache.n;
    }

    // saves a slot's information for future restoration
    void save(const struct llama_kv_cache_slot_info & slot) {
        if (slot) {
            do_restore = true;
            if (slot.boundaries.first != slot.boundaries.second) {
                slot_boundaries.push_back(slot.boundaries);
            }
        }
    }

    // must be explicitly called to restore the kv_cache state
    // and roll back changes from all llama_kv_cache_find_slot calls
    void restore(struct llama_kv_cache & cache) {
        if (do_restore) {
            cache.head = old_state.head;
            cache.n    = old_state.n;

            if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
                llama_kv_cache_seq_rm(cache, -1, -1, -1);
            } else {
                for (auto & slot : slot_boundaries) {
                    llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
                }
            }
        }
    }
};
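
// Usage sketch (illustrative only; the surrounding decode loop is assumed, not
// defined in this header):
//
//     llama_kv_slot_restorer restorer(cache);       // snapshot head and n
//     const auto slot = llama_kv_cache_find_slot(cache, ubatch);
//     if (!slot) {
//         return false;
//     }
//     restorer.save(slot);                          // remember the slot boundaries
//     // ... build and compute the graph; if that fails, undo the slot allocation:
//     // restorer.restore(cache);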