llama-kv-cache.h
#pragma once

#include "llama.h"
#include "llama-io.h"
#include "llama-memory.h"

#include "ggml-cpp.h"

#include <functional>
#include <set>
#include <vector>

struct llama_cparams;
struct llama_hparams;
struct llama_ubatch;
struct llama_kv_cache : public llama_memory_i {
    using llama_memory_i::llama_memory_i;

    virtual int32_t  get_n_tokens()   const = 0;
    virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too specific to the unified cache

    virtual bool get_can_shift() const = 0;

    bool get_can_edit() const override { return get_can_shift(); }
};
struct llama_kv_cell {
    llama_pos pos   = -1;
    llama_pos delta =  0;
    int32_t   src   = -1; // used by recurrent state models to copy states
    int32_t   tail  = -1;

    std::set<llama_seq_id> seq_id;

    bool has_seq_id(const llama_seq_id & id) const {
        return seq_id.find(id) != seq_id.end();
    }

    bool is_empty() const {
        return seq_id.empty();
    }

    bool is_same_seq(const llama_kv_cell & other) const {
        return seq_id == other.seq_id;
    }
};
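
// Illustrative sketch (not in the source): because seq_id is a set, a single
// cell can belong to several sequences at once, e.g. after a sequence copy:
//
//   llama_kv_cell cell;
//   cell.pos = 5;
//   cell.seq_id.insert(0);
//   cell.seq_id.insert(1);
//   // cell.has_seq_id(1) == true, cell.is_empty() == false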
// a structure that holds information about a slot found by find_slot
struct llama_kv_cache_slot_info {
    std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
    bool found = false;                       // whether the slot was found

    explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
    llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}

    operator bool() const { return found; }
};
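
// Illustrative sketch (not in the source): the implicit bool conversion makes
// the result of a slot search easy to test at the call site:
//
//   const llama_kv_cache_slot_info slot(16, 48); // cells [16, 48) were reserved
//   if (slot) {
//       const uint32_t n_cells = slot.boundaries.second - slot.boundaries.first; // 32
//   }
//
//   const llama_kv_cache_slot_info failed(false); // no slot was found
//   // a check like `if (failed) { ... }` is not taken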
// ring-buffer of cached KV data
// TODO: pimpl
// TODO: add notion of max sequences
class llama_kv_cache_unified : public llama_kv_cache {
public:
    // can be used to query data from the model if needed
    struct callbacks {
        std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
    };
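
    // Illustrative sketch (hypothetical names, not in the source): a caller
    // that owns per-layer rope factor tensors could wire the callback up like:
    //
    //   callbacks cbs;
    //   cbs.get_rope_factors = [&](uint32_t n_ctx_per_seq, int il) {
    //       // rope_factors_for(...) is a hypothetical helper on the caller's side
    //       return rope_factors_for(n_ctx_per_seq, il);
    //   };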
    llama_kv_cache_unified(
            const llama_hparams & hparams,
            callbacks cbs);

    virtual ~llama_kv_cache_unified() = default;

    // TODO: become constructor
    bool init(
            const llama_model & model, // TODO: do not reference the model
            const llama_cparams & cparams,
            ggml_type type_k,
            ggml_type type_v,
            uint32_t kv_size,
            bool offload);
    int32_t  get_n_tokens()   const override;
    uint32_t get_used_cells() const override;

    size_t total_size() const;

    // TODO: better data structures to reduce the cost of this operation
    llama_pos pos_max() const;

    void clear() override;
    void defrag() override;

    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_max(llama_seq_id seq_id) override;

    bool get_can_shift() const override;

    // find an empty slot of size "n_tokens" in the cache
    // updates the cache head
    // returns a structure holding information about the slot found
    // note: on success, it is important that cache.head points to the first cell of the slot
    llama_kv_cache_slot_info find_slot(const llama_ubatch & batch);
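
    // Illustrative sketch (not in the source): a decode path reserves a slot
    // for the incoming ubatch and can rely on head pointing at its first cell:
    //
    //   const auto slot = kv.find_slot(ubatch);
    //   if (slot) {
    //       // on success, slot.boundaries.first == kv.head
    //       // ... build the graph against cells [slot.boundaries.first, slot.boundaries.second)
    //   }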
    // TODO: maybe not needed
    uint32_t get_padding(const llama_cparams & cparams) const;

    // find how many cells are currently in use
    uint32_t cell_max() const;

    size_t size_k_bytes() const;
    size_t size_v_bytes() const;

    // defrag

    struct {
        std::vector<uint32_t> ids;
    } defrag_info;

    // return true if cells have been moved
    bool defrag_prepare(int32_t n_max_nodes);

    // state save/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1);
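
    // Illustrative sketch (hypothetical I/O types, not in the source):
    // persisting one sequence and loading it back, assuming my_writer and
    // my_reader implement llama_io_write_i / llama_io_read_i:
    //
    //   my_writer wr(path);
    //   kv.state_write(wr, seq_id); // the default seq_id == -1 covers the whole cache
    //   ...
    //   my_reader rd(path);
    //   kv.state_read(rd, seq_id);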
    // members

    const llama_hparams & hparams;

    callbacks cbs;

    bool has_shift = false;
    bool do_defrag = false;

    // TODO: remove this and implement llama_kv_cache_recurrent instead
    bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token

    bool v_trans   = true;  // the value tensor is transposed
    bool can_shift = false;

    // note: the value of head isn't only used to optimize searching for a free KV slot;
    // llama_decode_impl also uses it, so it cannot be freely changed after a slot has been allocated
    uint32_t head = 0;
    uint32_t size = 0;
    uint32_t used = 0; // used cells (i.e. at least one seq_id)

    // computed before each graph build
    uint32_t n = 0;

    std::vector<llama_kv_cell> cells;

    std::vector<ggml_tensor *> k_l; // per layer
    std::vector<ggml_tensor *> v_l;

private:
    ggml_type type_k = GGML_TYPE_F16;
    ggml_type type_v = GGML_TYPE_F16;

    std::vector<ggml_context_ptr>        ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;

    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
// TODO: temporarily reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
//class llama_kv_cache_recurrent : public llama_kv_cache_unified {
//public:
//    using llama_kv_cache_unified::llama_kv_cache_unified;
//};

//
// kv cache restore
//

// saves the kv_cache state for future recovery;
// used to roll back changes made by find_slot
struct llama_kv_slot_restorer {
    struct llama_kv_cache_state {
        uint32_t head = 0;
        uint32_t n    = 0;
    } old_state;

    // for non-recurrent models only
    // list of slots to restore
    std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;

    bool do_restore = false;

    llama_kv_cache_unified & cache;

    explicit llama_kv_slot_restorer(llama_kv_cache_unified & cache) : cache(cache) {
        old_state.head = cache.head;
        old_state.n    = cache.n;
    }

    // saves slot information for future restoration
    void save(const llama_kv_cache_slot_info & slot) {
        if (slot) {
            do_restore = true;
            if (slot.boundaries.first != slot.boundaries.second) {
                slot_boundaries.push_back(slot.boundaries);
            }
        }
    }

    // must be explicitly called to restore the kv_cache state
    // and roll back changes from all find_slot calls
    void restore() {
        if (do_restore) {
            cache.head = old_state.head;
            cache.n    = old_state.n;

            if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
                cache.seq_rm(-1, -1, -1);
            } else {
                for (auto & slot : slot_boundaries) {
                    cache.seq_rm(-1, slot.first, slot.second);
                }
            }
        }
    }
};
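
// Illustrative sketch (not in the source): a decode loop can pair the restorer
// with find_slot so that a mid-batch failure rolls the cache back cleanly:
//
//   llama_kv_slot_restorer restorer(kv);
//   for (/* each ubatch in the batch */) {
//       const auto slot = kv.find_slot(ubatch);
//       if (!slot) {
//           restorer.restore(); // roll back slots reserved for earlier ubatches
//           break;
//       }
//       restorer.save(slot);
//       // ... run the graph; on compute failure, call restorer.restore() as well
//   }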
// TODO: maybe become part of the public llama_kv_cache in the future
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);

int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv);

void llama_kv_cache_clear(llama_kv_cache * kv);

bool llama_kv_cache_seq_rm(
        llama_kv_cache * kv,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1);

void llama_kv_cache_seq_cp(
        llama_kv_cache * kv,
        llama_seq_id seq_id_src,
        llama_seq_id seq_id_dst,
        llama_pos p0,
        llama_pos p1);

void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id);

void llama_kv_cache_seq_add(
        llama_kv_cache * kv,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1,
        llama_pos delta);

void llama_kv_cache_seq_div(
        llama_kv_cache * kv,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1,
        int d);

llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id);

void llama_kv_cache_defrag(llama_kv_cache * kv);

bool llama_kv_cache_can_shift(const llama_kv_cache * kv);
//
// kv cache view
//

llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max);

void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv);
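
// Illustrative sketch (not in the source; assumes llama_kv_cache_view_free and
// the view's field names from the public llama.h API):
//
//   llama_kv_cache_view view = llama_kv_cache_view_init(kv, /*n_seq_max=*/4);
//   llama_kv_cache_view_update(&view, &kv);
//   // ... inspect view.n_cells, view.used_cells, view.cells, ...
//   llama_kv_cache_view_free(&view);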