llama-kv-cache.h
#pragma once

#include "llama.h"
#include "llama-io.h"
#include "llama-memory.h"

#include "ggml-cpp.h"

#include <functional>
#include <set>
#include <vector>

struct llama_cparams;
struct llama_hparams;
struct llama_ubatch;

struct llama_kv_cache : public llama_memory_i {
    using llama_memory_i::llama_memory_i;

    virtual void restore() = 0; // call if batch processing fails - restores the cache state
    virtual void commit()  = 0; // call after successful batch processing - clears any pending state

    virtual int32_t get_n_tokens()   const = 0;
    virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too specific to the unified cache

    virtual bool get_can_shift() const = 0;

    bool get_can_edit() const override { return get_can_shift(); }
};
struct llama_kv_cache_guard {
    llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}

    ~llama_kv_cache_guard() {
        kv->restore();
    }

    void commit() {
        kv->commit();
    }

private:
    llama_kv_cache * kv;
};
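
// Usage sketch (illustrative only, not part of this header; decode_batch() below is a
// hypothetical caller-side routine): the guard rolls back uncommitted cells on every
// exit path, while commit() clears the pending state so the destructor's restore()
// has nothing left to undo for this batch.
//
//   void example_decode(llama_kv_cache & kv, const llama_ubatch & ubatch) {
//       llama_kv_cache_guard guard(&kv);
//
//       if (!decode_batch(ubatch)) {
//           return; // guard destructor calls kv.restore() -> pending slots are released
//       }
//
//       guard.commit(); // keep the newly written cells
//   }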
struct llama_kv_cell {
    llama_pos pos   = -1;
    llama_pos delta =  0;
    int32_t   src   = -1; // used by recurrent state models to copy states
    int32_t   tail  = -1;

    std::set<llama_seq_id> seq_id;

    bool has_seq_id(const llama_seq_id & id) const {
        return seq_id.find(id) != seq_id.end();
    }

    bool is_empty() const {
        return seq_id.empty();
    }

    bool is_same_seq(const llama_kv_cell & other) const {
        return seq_id == other.seq_id;
    }
};
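
// Example (sketch): a single cell can belong to several sequences at once, e.g. after a
// sequence copy the destination id is added to the same cells. The values below are made
// up for illustration:
//
//   llama_kv_cell cell;
//   cell.pos = 5;
//   cell.seq_id.insert(0);
//   cell.seq_id.insert(1);
//
//   cell.has_seq_id(0); // true
//   cell.has_seq_id(2); // false
//   cell.is_empty();    // false - a cell counts as used while its seq_id set is non-empty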
// ring-buffer of cached KV data
// TODO: pimpl
// TODO: add notion of max sequences
class llama_kv_cache_unified : public llama_kv_cache {
public:
    // can be used to query data from the model if needed
    struct callbacks {
        std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
    };

    llama_kv_cache_unified(
            const llama_hparams & hparams,
                      callbacks   cbs);

    virtual ~llama_kv_cache_unified() = default;

    // TODO: become constructor
    bool init(
            const llama_model   & model, // TODO: do not reference the model
            const llama_cparams & cparams,
                  ggml_type       type_k,
                  ggml_type       type_v,
                  uint32_t        kv_size,
                  bool            offload);

    int32_t get_n_tokens()   const override;
    int32_t get_used_cells() const override;

    size_t total_size() const;

    // TODO: better data structures to reduce the cost of this operation
    llama_pos pos_max() const;

    void clear() override;
    void defrag() override;

    virtual void restore() override;
    virtual void commit()  override;

    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    bool get_can_shift() const override;

    // find an empty slot of size "n_tokens" in the cache
    // updates the cache head
    // Note: On success, it's important that cache.head points
    // to the first cell of the slot.
    bool find_slot(const llama_ubatch & batch);

    // TODO: maybe not needed
    uint32_t get_padding(const llama_cparams & cparams) const;

    // find how many cells are currently in use
    uint32_t cell_max() const;

    size_t size_k_bytes() const;
    size_t size_v_bytes() const;

    // defrag

    struct {
        std::vector<uint32_t> ids;
    } defrag_info;

    // return true if cells have been moved
    bool defrag_prepare(int32_t n_max_nodes);

    // commit/restore cache

    struct slot_range {
        uint32_t c0 = 0; // note: these are cell indices, not sequence positions
        uint32_t c1 = 0;
    };

    // pending cell updates that are not yet committed
    struct {
        std::vector<slot_range> ranges;
    } pending;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1);

    // members

    const llama_hparams & hparams;

    callbacks cbs;

    bool has_shift = false;
    bool do_defrag = false;

    // TODO: remove this and implement llama_kv_cache_recurrent instead
    bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token

    bool v_trans   = true;  // the value tensor is transposed
    bool can_shift = false;

    // Note: The value of head isn't only used to optimize searching
    // for a free KV slot. llama_decode_impl also uses it, so it
    // cannot be freely changed after a slot has been allocated.
    uint32_t head = 0;
    uint32_t size = 0;
    uint32_t used = 0; // used cells (i.e. at least one seq_id)

    // computed before each graph build
    uint32_t n = 0;

    std::vector<llama_kv_cell> cells;

    std::vector<ggml_tensor *> k_l; // per layer
    std::vector<ggml_tensor *> v_l;

private:
    ggml_type type_k = GGML_TYPE_F16;
    ggml_type type_v = GGML_TYPE_F16;

    std::vector<ggml_context_ptr>        ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;

    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
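
// Rough lifecycle sketch (illustrative; the authoritative call sequence lives in the
// llama_context / llama_decode_impl code, and "model", "cparams", "cbs" and "ubatch"
// below stand in for objects the caller already has):
//
//   llama_kv_cache_unified kv(model.hparams, std::move(cbs));
//
//   if (!kv.init(model, cparams, GGML_TYPE_F16, GGML_TYPE_F16, /*kv_size=*/4096, /*offload=*/true)) {
//       // failed to allocate the per-layer k_l/v_l tensors
//   }
//
//   llama_kv_cache_guard guard(&kv);
//   if (!kv.find_slot(ubatch)) {
//       // not enough free cells for the ubatch - the caller may defrag() and retry
//   }
//   // ... build and run the graph using kv.head, kv.n, kv.k_l, kv.v_l ...
//   guard.commit();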
// TODO: temporarily reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
//class llama_kv_cache_recurrent : public llama_kv_cache_unified {
//public:
//    using llama_kv_cache_unified::llama_kv_cache_unified;
//};

//
// kv cache view
//

llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max);

void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv);
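
// Usage sketch for the view helpers (llama_kv_cache_view itself, its fields and its
// free function are declared in the public llama.h API; check that header for the
// exact struct layout):
//
//   llama_kv_cache_view view = llama_kv_cache_view_init(kv, /*n_seq_max=*/4);
//   llama_kv_cache_view_update(&view, &kv);
//   // ... inspect the per-cell snapshot (token counts, used cells, sequence ids) ...
//   llama_kv_cache_view_free(&view);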