#pragma once

#include "llama.h"
#include "llama-io.h"
#include "llama-graph.h"
#include "llama-memory.h"

#include "ggml-cpp.h"

#include <functional>
#include <memory>
#include <set>
#include <unordered_map>
#include <vector>

struct llama_cparams;
struct llama_hparams;
struct llama_ubatch;
struct llama_sbatch;
struct llama_model;
struct llama_context;

struct llama_kv_cache : public llama_memory_i {
    virtual ~llama_kv_cache() = default;

    // call if batch processing fails - restores the cache state
    virtual void restore() = 0;

    // call after successful batch processing - clears any pending state
    virtual void commit() = 0;

    // process any pending defrag/shift/etc. operations
    // optionally call once before processing a new batch
    virtual bool update(llama_context & lctx) = 0;

    // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
    virtual void defrag_sched(float thold) = 0;

    // simulate full cache, used for allocating worst-case compute buffers
    virtual void set_full() = 0;

    //
    // batch processing
    //

    // =============================================================================================================
    // TODO: refactor and simplify this

    virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;

    // different KV caches require different batch splitting strategies
    virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;

    // find an empty slot of size "n_tokens" in the cache
    virtual bool find_slot(const llama_ubatch & batch) = 0;

    // =============================================================================================================
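
    // Illustrative sketch (not part of this header) of how a caller might drive the
    // batch-processing interface above; the exact variable names and the surrounding
    // decode logic are assumptions:
    //
    //   llama_sbatch sbatch = kv->sbatch_init(batch, logits_all);
    //   while (sbatch.n_tokens > 0) {
    //       llama_ubatch ubatch = kv->ubatch_next(sbatch, n_ubatch, embd_pooled);
    //       if (!kv->find_slot(ubatch)) {
    //           kv->restore(); // roll back any cells claimed by earlier ubatches
    //           return 1;      // abort the whole batch
    //       }
    //       // ... build and compute the graph for this ubatch ...
    //   }
    //   kv->commit(); // all ubatches processed - make the cell updates permanent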

    // getters
    virtual int32_t get_n_tokens()   const = 0;
    virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too specific to the unified cache

    virtual llama_pos get_pos_max() const = 0;

    virtual bool get_can_shift() const = 0;

    bool get_can_edit() const override { return get_can_shift(); }

    //
    // state write/read
    //

    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
    virtual void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) = 0;
};

//
// llama_kv_cache_guard
//

struct llama_kv_cache_guard {
    llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}

    ~llama_kv_cache_guard() {
        kv->restore();
    }

    void commit() {
        kv->commit();
    }

private:
    llama_kv_cache * kv;
};
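
// Illustrative sketch (not part of this header) of the intended guard usage; the
// surrounding decode function and its error handling are assumptions:
//
//   llama_kv_cache_guard kv_guard(kv);
//
//   if (!kv->find_slot(ubatch)) {
//       return 1; // the guard destructor calls restore(), rolling back pending changes
//   }
//
//   // ... process the batch successfully ...
//
//   kv_guard.commit(); // commit clears the pending state, so the destructor's restore() becomes a no-op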

//
// llama_kv_cache_unified
//

class llama_kv_cache_unified : public llama_kv_cache {
public:
    static uint32_t get_padding(const llama_cparams & cparams);

    // this callback is used to filter out layers that should not be included in the cache
    using layer_filter_cb = std::function<bool(int32_t il)>;

    llama_kv_cache_unified(
            const llama_model & model,
            layer_filter_cb  && filter,
            ggml_type           type_k,
            ggml_type           type_v,
            bool                v_trans,
            bool                offload,
            uint32_t            kv_size,
            uint32_t            padding,
            uint32_t            n_swa,
            llama_swa_type      swa_type);

    ~llama_kv_cache_unified() = default;

    //
    // llama_memory_i
    //

    void clear() override;

    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    //
    // llama_kv_cache
    //

    void restore() override;
    void commit()  override;

    bool update(llama_context & ctx) override;

    void defrag_sched(float thold) override;

    void set_full() override;

    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;

    // updates the cache head
    // Note: On success, it's important that cache.head points
    // to the first cell of the slot.
    bool find_slot(const llama_ubatch & batch) override;

    int32_t get_n_tokens()   const override;
    int32_t get_used_cells() const override;

    // TODO: better data structures to reduce the cost of this operation
    llama_pos get_pos_max() const override;

    bool get_can_shift() const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;

    //
    // llama_kv_cache_unified specific API
    //

    uint32_t get_n()    const;
    uint32_t get_size() const;

    // get views of the current state of the cache
    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;

    // store k_cur and v_cur in the cache based on the current head location
    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;

    void prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax);

    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_k_shift   (ggml_tensor * dst) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
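
    // Illustrative sketch (not part of this header) of how an attention graph build
    // might combine the getters above; the tensor names and the exact ggml call
    // sequence are assumptions, not the actual implementation:
    //
    //   ggml_tensor * k = kv->get_k(ctx, il); // view of the cached keys for layer il
    //   ggml_tensor * v = kv->get_v(ctx, il); // view of the cached values for layer il
    //
    //   ggml_build_forward_expand(gf, kv->cpy_k(ctx, k_cur, il)); // append new keys at the head
    //   ggml_build_forward_expand(gf, kv->cpy_v(ctx, v_cur, il)); // append new values at the head
    //
    //   // before graph evaluation, fill the input tensors from the current cache state:
    //   kv->set_input_kq_mask(kq_mask, &ubatch, cparams.causal_attn);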

private:
    const llama_model   & model;
    const llama_hparams & hparams;

    struct kv_cell {
        llama_pos pos   = -1;
        llama_pos delta =  0;

        // TODO: replace with bitset uint64_t
        std::set<llama_seq_id> seq_id;

        bool has_seq_id(const llama_seq_id & id) const {
            return seq_id.find(id) != seq_id.end();
        }

        bool is_empty() const {
            return seq_id.empty();
        }

        bool is_same_seq(const kv_cell & other) const {
            return seq_id == other.seq_id;
        }
    };

    struct kv_layer {
        // layer index in the model
        // note: can be different from the layer index in the KV cache
        uint32_t il;

        ggml_tensor * k;
        ggml_tensor * v;
    };

    bool has_shift = false;
    bool do_defrag = false;

    bool v_trans = true; // the value tensor is transposed

    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
    uint32_t size = 0; // total number of cells, shared across all sequences
    uint32_t used = 0; // used cells (i.e. at least one seq_id) (TODO: add `struct kv_cells` and keep track automatically)

    // computed before each graph build
    uint32_t n = 0;

    // required padding
    uint32_t padding = 1;

    ggml_type type_k = GGML_TYPE_F16;
    ggml_type type_v = GGML_TYPE_F16;

    // SWA
    uint32_t n_swa = 0;

    llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

    std::vector<ggml_context_ptr>        ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    std::vector<kv_cell>  cells;  // TODO: replace with `struct kv_cells`
    std::vector<kv_layer> layers;

    // model layer id -> KV cache layer id
    std::unordered_map<int32_t, int32_t> map_layer_ids;

    // recovery information used to restore the KV cells to their original state in case of a failure
    struct {
        void clear() {
            cells.clear();
        }

        std::unordered_map<uint32_t, kv_cell> cells;
    } recovery;
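
    // Illustrative sketch (not part of this header) of how restore() could use the
    // recovery map above; this is an assumption about the implementation, included
    // only to document the intent of the structure:
    //
    //   void llama_kv_cache_unified::restore() {
    //       for (const auto & [id, cell] : recovery.cells) {
    //           cells[id] = cell; // put each modified cell back to its pre-batch state
    //       }
    //       recovery.clear();
    //   }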

    // defrag
    struct {
        std::vector<uint32_t> ids;
    } defrag_info;

    // return true if cells have been moved
    bool defrag_prepare(int32_t n_max_nodes);

    // find how many cells are currently in use
    uint32_t cell_max() const;

    size_t total_size() const;

    size_t size_k_bytes() const;
    size_t size_v_bytes() const;

    bool is_masked_swa(llama_pos p0, llama_pos p1) const;

    ggml_tensor * build_rope_shift(
            const llama_cparams & cparams,
            ggml_context * ctx,
            ggml_tensor * cur,
            ggml_tensor * shift,
            ggml_tensor * factors,
            float freq_base,
            float freq_scale) const;

    llm_graph_result_ptr build_graph_shift(
            const llama_cparams & cparams,
            ggml_context * ctx,
            ggml_cgraph * gf) const;

    llm_graph_result_ptr build_graph_defrag(
            const llama_cparams & cparams,
            ggml_context * ctx,
            ggml_cgraph * gf) const;

    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;

    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};

//
// llama_kv_cache_unified_iswa
//

// utilizes two instances of llama_kv_cache_unified
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
// upon successful commit, the SWA cache removes old tokens outside the n_swa window
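
// Illustrative sketch (not part of this header) of the commit-time pruning described
// above; the exact bookkeeping is an assumption based on the `pending` member below:
//
//   void llama_kv_cache_unified_iswa::commit() {
//       kv_base->commit();
//       kv_swa ->commit();
//
//       // drop tokens that have fallen outside the SWA window of each sequence
//       // (possibly conditional on the do_prune flag)
//       for (const auto & [seq_id, entry] : pending.pos) {
//           kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax);
//       }
//       pending.clear();
//   }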

class llama_kv_cache_unified_iswa : public llama_kv_cache {
public:
    llama_kv_cache_unified_iswa(
            const llama_model & model,
            ggml_type           type_k,
            ggml_type           type_v,
            bool                v_trans,
            bool                offload,
            uint32_t            kv_size,
            bool                swa_full,
            uint32_t            n_seq_max,
            uint32_t            n_batch,
            uint32_t            padding);

    ~llama_kv_cache_unified_iswa() = default;

    //
    // llama_memory_i
    //

    void clear() override;

    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    //
    // llama_kv_cache
    //

    void restore() override;
    void commit()  override;

    bool update(llama_context & ctx) override;

    void defrag_sched(float thold) override;

    void set_full() override;

    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;

    bool find_slot(const llama_ubatch & batch) override;

    int32_t get_n_tokens()   const override;
    int32_t get_used_cells() const override;

    // TODO: better data structures to reduce the cost of this operation
    llama_pos get_pos_max() const override;

    bool get_can_shift() const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;

    //
    // llama_kv_cache_unified_iswa specific API
    //

    llama_kv_cache_unified * get_kv_base() const;
    llama_kv_cache_unified * get_kv_swa () const;

private:
    const llama_hparams & hparams;

    bool do_prune = true;

    struct {
        struct entry {
            llama_pos pmin;
            llama_pos pmax;
        };

        void clear() {
            pos.clear();
        }

        // used to perform SWA pruning of old tokens
        std::unordered_map<llama_seq_id, entry> pos;
    } pending;

    std::unique_ptr<llama_kv_cache_unified> kv_base;
    std::unique_ptr<llama_kv_cache_unified> kv_swa;
};

//
// llama_kv_cache_recurrent
//

class llama_kv_cache_recurrent : public llama_kv_cache {
public:
    struct kv_cell {
        llama_pos pos  = -1;
        int32_t   src  = -1; // used to copy states
        int32_t   tail = -1;

        std::set<llama_seq_id> seq_id;

        bool has_seq_id(const llama_seq_id & id) const {
            return seq_id.find(id) != seq_id.end();
        }

        bool is_empty() const {
            return seq_id.empty();
        }

        bool is_same_seq(const kv_cell & other) const {
            return seq_id == other.seq_id;
        }
    };

    llama_kv_cache_recurrent(
            const llama_model & model,
            ggml_type           type_k,
            ggml_type           type_v,
            bool                offload,
            uint32_t            kv_size);

    ~llama_kv_cache_recurrent() = default;

    //
    // llama_memory_i
    //

    void clear() override;

    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    //
    // llama_kv_cache
    //

    void restore() override;
    void commit()  override;

    bool update(llama_context & lctx) override;

    void defrag_sched(float thold) override;

    void set_full() override;

    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;

    bool find_slot(const llama_ubatch & batch) override;

    int32_t get_n_tokens()   const override;
    int32_t get_used_cells() const override;

    // TODO: better data structures to reduce the cost of this operation
    llama_pos get_pos_max() const override;

    bool get_can_shift() const override;

    // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
    int32_t s_copy(int i) const;
    float   s_mask(int i) const;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;

    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
    uint32_t size = 0; // total number of cells, shared across all sequences
    uint32_t used = 0; // used cells (i.e. at least one seq_id)

    // computed before each graph build
    uint32_t n = 0;

    std::vector<kv_cell> cells;

    std::vector<ggml_tensor *> k_l; // per layer
    std::vector<ggml_tensor *> v_l;

private:
    //const llama_model & model;
    const llama_hparams & hparams;

    // commit/restore cache
    // TODO: rework for recurrent cache
    struct slot_range {
        uint32_t c0 = 0; // note: these are cell indices, not sequence positions
        uint32_t c1 = 0;
    };

    // pending cell updates that are not yet committed
    struct {
        std::vector<slot_range> ranges;
    } pending;

    ggml_type type_k = GGML_TYPE_F16;
    ggml_type type_v = GGML_TYPE_F16;

    std::vector<ggml_context_ptr>        ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    // find how many cells are currently in use
    uint32_t cell_max() const;

    size_t total_size() const;

    size_t size_k_bytes() const;
    size_t size_v_bytes() const;

    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;

    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};