llama-memory-hybrid.h 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. #pragma once
  2. #include "llama-batch.h"
  3. #include "llama-graph.h"
  4. #include "llama-kv-cache.h"
  5. #include "llama-memory.h"
  6. #include "llama-memory-recurrent.h"
  7. #include <memory>
  8. #include <vector>
  9. //
  10. // llama_memory_hybrid
  11. //
  12. // utilizes instances of llama_memory_recurrent and llama_kv_cache to
  13. // support models where each layer may be either attention-based or recurrent
  14. class llama_memory_hybrid : public llama_memory_i {
  15. public:
  16. llama_memory_hybrid(
  17. const llama_model & model,
  18. /* attn */
  19. ggml_type type_k,
  20. ggml_type type_v,
  21. bool v_trans,
  22. uint32_t kv_size,
  23. uint32_t n_pad,
  24. uint32_t n_swa,
  25. llama_swa_type swa_type,
  26. /* recurrent */
  27. ggml_type type_r,
  28. ggml_type type_s,
  29. uint32_t rs_size,
  30. /* common */
  31. uint32_t n_seq_max,
  32. bool offload,
  33. bool unified,
  34. /* layer filters */
  35. const layer_filter_cb & filter_attn = nullptr,
  36. const layer_filter_cb & filter_recr = nullptr);
  37. ~llama_memory_hybrid() = default;
  38. //
  39. // llama_memory_i
  40. //
  41. llama_memory_context_ptr init_batch(
  42. llama_batch_allocr & balloc,
  43. uint32_t n_ubatch,
  44. bool embd_all) override;
  45. llama_memory_context_ptr init_full() override;
  46. llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
  47. bool get_can_shift() const override;
  48. void clear(bool data) override;
  49. bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
  50. void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
  51. void seq_keep(llama_seq_id seq_id) override;
  52. void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
  53. void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
  54. llama_pos seq_pos_min(llama_seq_id seq_id) const override;
  55. llama_pos seq_pos_max(llama_seq_id seq_id) const override;
  56. std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
  57. // state write/load
  58. void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
  59. void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
  60. //
  61. // llama_memory_hybrid specific API
  62. //
  63. llama_kv_cache * get_mem_attn() const;
  64. llama_memory_recurrent * get_mem_recr() const;
  65. private:
  66. const llama_hparams & hparams;
  67. const std::unique_ptr<llama_kv_cache> mem_attn;
  68. const std::unique_ptr<llama_memory_recurrent> mem_recr;
  69. };
  70. class llama_memory_hybrid_context : public llama_memory_context_i {
  71. public:
  72. using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
  73. // init failure
  74. explicit llama_memory_hybrid_context(llama_memory_status status);
  75. // init full
  76. explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
  77. // init update
  78. explicit llama_memory_hybrid_context(
  79. llama_memory_hybrid * mem,
  80. llama_context * lctx,
  81. bool optimize);
  82. // init success
  83. llama_memory_hybrid_context(
  84. llama_memory_hybrid * mem,
  85. slot_info_vec_t sinfos_attn,
  86. std::vector<llama_ubatch> ubatches);
  87. ~llama_memory_hybrid_context() = default;
  88. bool next() override;
  89. bool apply() override;
  90. llama_memory_status get_status() const override;
  91. const llama_ubatch & get_ubatch() const override;
  92. //
  93. // llama_memory_hybrid_context
  94. //
  95. const llama_kv_cache_context * get_attn() const;
  96. const llama_memory_recurrent_context * get_recr() const;
  97. private:
  98. // the index of the next ubatch to process
  99. size_t i_next = 0;
  100. std::vector<llama_ubatch> ubatches;
  101. const llama_memory_context_ptr ctx_attn;
  102. const llama_memory_context_ptr ctx_recr;
  103. const llama_memory_status status;
  104. };