llama-memory-hybrid.h 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. #pragma once
  2. #include "llama-batch.h"
  3. #include "llama-graph.h"
  4. #include "llama-kv-cache-unified.h"
  5. #include "llama-memory.h"
  6. #include "llama-memory-recurrent.h"
  7. #include <memory>
  8. #include <vector>
  9. //
  10. // llama_memory_hybrid
  11. //
  12. // utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to
  13. // support models where each layer may be either attention-based or recurrent
  14. class llama_memory_hybrid : public llama_memory_i {
  15. public:
  16. // this callback is used to filter out layers that should not be included in the cache
  17. using layer_filter_cb = std::function<bool(int32_t il)>;
  18. llama_memory_hybrid(
  19. const llama_model & model,
  20. /* attn */
  21. ggml_type type_k,
  22. ggml_type type_v,
  23. bool v_trans,
  24. uint32_t kv_size,
  25. uint32_t n_pad,
  26. uint32_t n_swa,
  27. llama_swa_type swa_type,
  28. /* recurrent */
  29. ggml_type type_r,
  30. ggml_type type_s,
  31. uint32_t rs_size,
  32. /* common */
  33. uint32_t n_seq_max,
  34. bool offload,
  35. bool unified,
  36. /* layer filters */
  37. layer_filter_cb && filter_attn = nullptr,
  38. layer_filter_cb && filter_recr = nullptr);
  39. ~llama_memory_hybrid() = default;
  40. //
  41. // llama_memory_i
  42. //
  43. llama_memory_context_ptr init_batch(
  44. llama_batch_allocr & balloc,
  45. uint32_t n_ubatch,
  46. bool embd_all) override;
  47. llama_memory_context_ptr init_full() override;
  48. llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
  49. bool get_can_shift() const override;
  50. void clear(bool data) override;
  51. bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
  52. void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
  53. void seq_keep(llama_seq_id seq_id) override;
  54. void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
  55. void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
  56. llama_pos seq_pos_min(llama_seq_id seq_id) const override;
  57. llama_pos seq_pos_max(llama_seq_id seq_id) const override;
  58. // state write/load
  59. void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
  60. void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
  61. //
  62. // llama_memory_hybrid specific API
  63. //
  64. llama_kv_cache_unified * get_mem_attn() const;
  65. llama_memory_recurrent * get_mem_recr() const;
  66. private:
  67. const llama_hparams & hparams;
  68. const std::unique_ptr<llama_kv_cache_unified> mem_attn;
  69. const std::unique_ptr<llama_memory_recurrent> mem_recr;
  70. };
  71. class llama_memory_hybrid_context : public llama_memory_context_i {
  72. public:
  73. using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
  74. // init failure
  75. explicit llama_memory_hybrid_context(llama_memory_status status);
  76. // init full
  77. explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
  78. // init update
  79. explicit llama_memory_hybrid_context(
  80. llama_memory_hybrid * mem,
  81. llama_context * lctx,
  82. bool optimize);
  83. // init success
  84. llama_memory_hybrid_context(
  85. llama_memory_hybrid * mem,
  86. slot_info_vec_t sinfos_attn,
  87. std::vector<llama_ubatch> ubatches);
  88. ~llama_memory_hybrid_context() = default;
  89. bool next() override;
  90. bool apply() override;
  91. llama_memory_status get_status() const override;
  92. const llama_ubatch & get_ubatch() const override;
  93. //
  94. // llama_memory_hybrid_context
  95. //
  96. const llama_kv_cache_unified_context * get_attn() const;
  97. const llama_memory_recurrent_context * get_recr() const;
  98. private:
  99. // the index of the next ubatch to process
  100. size_t i_next = 0;
  101. std::vector<llama_ubatch> ubatches;
  102. const llama_memory_context_ptr ctx_attn;
  103. const llama_memory_context_ptr ctx_recr;
  104. const llama_memory_status status;
  105. };