llama-kv-cache-iswa.h 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. #pragma once
  2. #include "llama-kv-cache.h"
  3. #include <vector>
  4. //
  5. // llama_kv_cache_iswa
  6. //
  7. // utilizes two instances of llama_kv_cache
  8. // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
  9. class llama_kv_cache_iswa : public llama_memory_i {
  10. public:
  11. llama_kv_cache_iswa(
  12. const llama_model & model,
  13. ggml_type type_k,
  14. ggml_type type_v,
  15. bool v_trans,
  16. bool offload,
  17. bool swa_full,
  18. bool unified,
  19. uint32_t kv_size,
  20. uint32_t n_seq_max,
  21. uint32_t n_ubatch,
  22. uint32_t n_pad,
  23. const layer_filter_cb & filter,
  24. const layer_reuse_cb & reuse);
  25. ~llama_kv_cache_iswa() = default;
  26. //
  27. // llama_memory_i
  28. //
  29. llama_memory_context_ptr init_batch(
  30. llama_batch_allocr & balloc,
  31. uint32_t n_ubatch,
  32. bool embd_all) override;
  33. llama_memory_context_ptr init_full() override;
  34. llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
  35. bool get_can_shift() const override;
  36. void clear(bool data) override;
  37. bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
  38. void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
  39. void seq_keep(llama_seq_id seq_id) override;
  40. void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
  41. void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
  42. llama_pos seq_pos_min(llama_seq_id seq_id) const override;
  43. llama_pos seq_pos_max(llama_seq_id seq_id) const override;
  44. std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
  45. // state write/load
  46. void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
  47. void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
  48. //
  49. // llama_kv_cache_iswa specific API
  50. //
  51. llama_kv_cache * get_base() const;
  52. llama_kv_cache * get_swa () const;
  53. private:
  54. const llama_hparams & hparams;
  55. const bool unified;
  56. std::unique_ptr<llama_kv_cache> kv_base;
  57. std::unique_ptr<llama_kv_cache> kv_swa;
  58. };
  59. class llama_kv_cache_iswa_context : public llama_memory_context_i {
  60. public:
  61. using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
  62. // used for errors
  63. llama_kv_cache_iswa_context(llama_memory_status status);
  64. // used to create a full-cache context
  65. llama_kv_cache_iswa_context(
  66. llama_kv_cache_iswa * kv);
  67. // used to create an update context
  68. llama_kv_cache_iswa_context(
  69. llama_kv_cache_iswa * kv,
  70. llama_context * lctx,
  71. bool optimize);
  72. // used to create a batch processing context from a batch
  73. llama_kv_cache_iswa_context(
  74. llama_kv_cache_iswa * kv,
  75. slot_info_vec_t sinfos_base,
  76. slot_info_vec_t sinfos_swa,
  77. std::vector<llama_ubatch> ubatches);
  78. virtual ~llama_kv_cache_iswa_context();
  79. //
  80. // llama_memory_context_i
  81. //
  82. bool next() override;
  83. bool apply() override;
  84. llama_memory_status get_status() const override;
  85. const llama_ubatch & get_ubatch() const override;
  86. //
  87. // llama_kv_cache_iswa_context specific API
  88. //
  89. const llama_kv_cache_context * get_base() const;
  90. const llama_kv_cache_context * get_swa() const;
  91. private:
  92. //llama_kv_cache_iswa * kv;
  93. // the index of the next ubatch to process
  94. size_t i_next = 0;
  95. std::vector<llama_ubatch> ubatches;
  96. const llama_memory_context_ptr ctx_base;
  97. const llama_memory_context_ptr ctx_swa;
  98. const llama_memory_status status;
  99. };