llama-kv-cache-iswa.h 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. #pragma once
  2. #include "llama-kv-cache.h"
  3. #include <vector>
  4. //
  5. // llama_kv_cache_iswa
  6. //
  7. // utilizes two instances of llama_kv_cache
  8. // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
  9. class llama_kv_cache_iswa : public llama_memory_i {
  10. public:
  11. llama_kv_cache_iswa(
  12. const llama_model & model,
  13. ggml_type type_k,
  14. ggml_type type_v,
  15. bool v_trans,
  16. bool offload,
  17. bool swa_full,
  18. bool unified,
  19. uint32_t kv_size,
  20. uint32_t n_seq_max,
  21. uint32_t n_ubatch,
  22. uint32_t n_pad,
  23. const layer_filter_cb & filter,
  24. const layer_reuse_cb & reuse);
  25. ~llama_kv_cache_iswa() = default;
  26. //
  27. // llama_memory_i
  28. //
  29. llama_memory_context_ptr init_batch(
  30. llama_batch_allocr & balloc,
  31. uint32_t n_ubatch,
  32. bool embd_all) override;
  33. llama_memory_context_ptr init_full() override;
  34. llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
  35. bool get_can_shift() const override;
  36. void clear(bool data) override;
  37. bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
  38. void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
  39. void seq_keep(llama_seq_id seq_id) override;
  40. void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
  41. void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
  42. llama_pos seq_pos_min(llama_seq_id seq_id) const override;
  43. llama_pos seq_pos_max(llama_seq_id seq_id) const override;
  44. // state write/load
  45. void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
  46. void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
  47. //
  48. // llama_kv_cache_iswa specific API
  49. //
  50. llama_kv_cache * get_base() const;
  51. llama_kv_cache * get_swa () const;
  52. private:
  53. const llama_hparams & hparams;
  54. const bool unified;
  55. std::unique_ptr<llama_kv_cache> kv_base;
  56. std::unique_ptr<llama_kv_cache> kv_swa;
  57. };
  58. class llama_kv_cache_iswa_context : public llama_memory_context_i {
  59. public:
  60. using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
  61. // used for errors
  62. llama_kv_cache_iswa_context(llama_memory_status status);
  63. // used to create a full-cache context
  64. llama_kv_cache_iswa_context(
  65. llama_kv_cache_iswa * kv);
  66. // used to create an update context
  67. llama_kv_cache_iswa_context(
  68. llama_kv_cache_iswa * kv,
  69. llama_context * lctx,
  70. bool optimize);
  71. // used to create a batch processing context from a batch
  72. llama_kv_cache_iswa_context(
  73. llama_kv_cache_iswa * kv,
  74. slot_info_vec_t sinfos_base,
  75. slot_info_vec_t sinfos_swa,
  76. std::vector<llama_ubatch> ubatches);
  77. virtual ~llama_kv_cache_iswa_context();
  78. //
  79. // llama_memory_context_i
  80. //
  81. bool next() override;
  82. bool apply() override;
  83. llama_memory_status get_status() const override;
  84. const llama_ubatch & get_ubatch() const override;
  85. //
  86. // llama_kv_cache_iswa_context specific API
  87. //
  88. const llama_kv_cache_context * get_base() const;
  89. const llama_kv_cache_context * get_swa() const;
  90. private:
  91. //llama_kv_cache_iswa * kv;
  92. // the index of the next ubatch to process
  93. size_t i_next = 0;
  94. std::vector<llama_ubatch> ubatches;
  95. const llama_memory_context_ptr ctx_base;
  96. const llama_memory_context_ptr ctx_swa;
  97. const llama_memory_status status;
  98. };