// llama-kv-cache-recurrent.h (5.5 KB)

  1. #pragma once
  2. #include "llama-batch.h"
  3. #include "llama-graph.h"
  4. #include "llama-memory.h"
  5. #include <set>
  6. #include <vector>
  7. //
  8. // llama_kv_cache_recurrent
  9. //
  10. // TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
  11. // see the implementation of llama_kv_cache_unified_state_i for an example how to do it
  12. class llama_kv_cache_recurrent : public llama_memory_i {
  13. public:
  14. llama_kv_cache_recurrent(
  15. const llama_model & model,
  16. ggml_type type_k,
  17. ggml_type type_v,
  18. bool offload,
  19. uint32_t kv_size,
  20. uint32_t n_seq_max);
  21. ~llama_kv_cache_recurrent() = default;
  22. //
  23. // llama_memory_i
  24. //
  25. llama_memory_state_ptr init_batch(
  26. const llama_batch & batch,
  27. uint32_t n_ubatch,
  28. bool embd_pooled) override;
  29. llama_memory_state_ptr init_full() override;
  30. llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
  31. void clear(bool data) override;
  32. bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
  33. void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
  34. void seq_keep(llama_seq_id seq_id) override;
  35. void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
  36. void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
  37. llama_pos seq_pos_min(llama_seq_id seq_id) const override;
  38. llama_pos seq_pos_max(llama_seq_id seq_id) const override;
  39. bool prepare(const std::vector<llama_ubatch> & ubatches);
  40. // find a contiguous slot of kv cells and emplace the ubatch there
  41. bool find_slot(const llama_ubatch & ubatch);
  42. bool get_can_shift() const override;
  43. // state write/load
  44. void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
  45. void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
  46. uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
  47. uint32_t size = 0; // total number of cells, shared across all sequences
  48. uint32_t used = 0; // used cells (i.e. at least one seq_id)
  49. // computed before each graph build
  50. uint32_t n = 0;
  51. // first zero-ed state
  52. int32_t rs_z = -1;
  53. // TODO: optimize for recurrent state needs
  54. struct kv_cell {
  55. llama_pos pos = -1;
  56. int32_t src = -1; // used to know where states should be copied from
  57. int32_t src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
  58. int32_t tail = -1;
  59. std::set<llama_seq_id> seq_id;
  60. bool has_seq_id(const llama_seq_id & id) const {
  61. return seq_id.find(id) != seq_id.end();
  62. }
  63. bool is_empty() const {
  64. return seq_id.empty();
  65. }
  66. bool is_same_seq(const kv_cell & other) const {
  67. return seq_id == other.seq_id;
  68. }
  69. };
  70. std::vector<kv_cell> cells;
  71. std::vector<ggml_tensor *> k_l; // per layer
  72. std::vector<ggml_tensor *> v_l;
  73. private:
  74. //const llama_model & model;
  75. const llama_hparams & hparams;
  76. const uint32_t n_seq_max = 1;
  77. std::vector<ggml_context_ptr> ctxs;
  78. std::vector<ggml_backend_buffer_ptr> bufs;
  79. size_t total_size() const;
  80. size_t size_k_bytes() const;
  81. size_t size_v_bytes() const;
  82. void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
  83. void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
  84. bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
  85. bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
  86. };
  87. class llama_kv_cache_recurrent_state : public llama_memory_state_i {
  88. public:
  89. // used for errors
  90. llama_kv_cache_recurrent_state(llama_memory_status status);
  91. // used to create a full-cache state
  92. llama_kv_cache_recurrent_state(
  93. llama_memory_status status,
  94. llama_kv_cache_recurrent * kv);
  95. // used to create a state from a batch
  96. llama_kv_cache_recurrent_state(
  97. llama_memory_status status,
  98. llama_kv_cache_recurrent * kv,
  99. llama_sbatch sbatch,
  100. std::vector<llama_ubatch> ubatches);
  101. virtual ~llama_kv_cache_recurrent_state();
  102. //
  103. // llama_memory_state_i
  104. //
  105. bool next() override;
  106. bool apply() override;
  107. std::vector<int64_t> & out_ids() override;
  108. llama_memory_status get_status() const override;
  109. const llama_ubatch & get_ubatch() const override;
  110. //
  111. // llama_kv_cache_recurrent_state specific API
  112. //
  113. uint32_t get_n_kv() const;
  114. uint32_t get_head() const;
  115. int32_t get_rs_z() const;
  116. uint32_t get_size() const;
  117. ggml_tensor * get_k_l(int32_t il) const;
  118. ggml_tensor * get_v_l(int32_t il) const;
  119. int32_t s_copy(int i) const;
  120. private:
  121. const llama_memory_status status;
  122. llama_kv_cache_recurrent * kv;
  123. llama_sbatch sbatch;
  124. size_t i_next = 0;
  125. std::vector<llama_ubatch> ubatches;
  126. //
  127. // data needed for building the compute graph for the current ubatch:
  128. // TODO: extract all the state like `head` and `n` here
  129. //
  130. const bool is_full = false;
  131. };