llama-memory-hybrid.h 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. #pragma once
  2. #include "llama-batch.h"
  3. #include "llama-graph.h"
  4. #include "llama-kv-cache.h"
  5. #include "llama-memory.h"
  6. #include "llama-memory-recurrent.h"
  7. #include <memory>
  8. #include <vector>
  9. //
  10. // llama_memory_hybrid
  11. //
  12. // utilizes instances of llama_memory_recurrent and llama_kv_cache to
  13. // support models where each layer may be either attention-based or recurrent
  14. class llama_memory_hybrid : public llama_memory_i {
  15. public:
  16. llama_memory_hybrid(
  17. const llama_model & model,
  18. /* attn */
  19. ggml_type type_k,
  20. ggml_type type_v,
  21. bool v_trans,
  22. uint32_t kv_size,
  23. uint32_t n_pad,
  24. uint32_t n_swa,
  25. llama_swa_type swa_type,
  26. /* recurrent */
  27. ggml_type type_r,
  28. ggml_type type_s,
  29. uint32_t rs_size,
  30. /* common */
  31. uint32_t n_seq_max,
  32. bool offload,
  33. bool unified,
  34. /* layer filters */
  35. const layer_filter_cb & filter_attn = nullptr,
  36. const layer_filter_cb & filter_recr = nullptr);
  37. ~llama_memory_hybrid() = default;
  38. //
  39. // llama_memory_i
  40. //
  41. llama_memory_context_ptr init_batch(
  42. llama_batch_allocr & balloc,
  43. uint32_t n_ubatch,
  44. bool embd_all) override;
  45. llama_memory_context_ptr init_full() override;
  46. llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
  47. bool get_can_shift() const override;
  48. void clear(bool data) override;
  49. bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
  50. void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
  51. void seq_keep(llama_seq_id seq_id) override;
  52. void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
  53. void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
  54. llama_pos seq_pos_min(llama_seq_id seq_id) const override;
  55. llama_pos seq_pos_max(llama_seq_id seq_id) const override;
  56. // state write/load
  57. void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
  58. void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
  59. //
  60. // llama_memory_hybrid specific API
  61. //
  62. llama_kv_cache * get_mem_attn() const;
  63. llama_memory_recurrent * get_mem_recr() const;
  64. private:
  65. const llama_hparams & hparams;
  66. const std::unique_ptr<llama_kv_cache> mem_attn;
  67. const std::unique_ptr<llama_memory_recurrent> mem_recr;
  68. };
  69. class llama_memory_hybrid_context : public llama_memory_context_i {
  70. public:
  71. using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
  72. // init failure
  73. explicit llama_memory_hybrid_context(llama_memory_status status);
  74. // init full
  75. explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
  76. // init update
  77. explicit llama_memory_hybrid_context(
  78. llama_memory_hybrid * mem,
  79. llama_context * lctx,
  80. bool optimize);
  81. // init success
  82. llama_memory_hybrid_context(
  83. llama_memory_hybrid * mem,
  84. slot_info_vec_t sinfos_attn,
  85. std::vector<llama_ubatch> ubatches);
  86. ~llama_memory_hybrid_context() = default;
  87. bool next() override;
  88. bool apply() override;
  89. llama_memory_status get_status() const override;
  90. const llama_ubatch & get_ubatch() const override;
  91. //
  92. // llama_memory_hybrid_context
  93. //
  94. const llama_kv_cache_context * get_attn() const;
  95. const llama_memory_recurrent_context * get_recr() const;
  96. private:
  97. // the index of the next ubatch to process
  98. size_t i_next = 0;
  99. std::vector<llama_ubatch> ubatches;
  100. const llama_memory_context_ptr ctx_attn;
  101. const llama_memory_context_ptr ctx_recr;
  102. const llama_memory_status status;
  103. };