llama-memory-hybrid.h 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. #pragma once
  2. #include "llama-batch.h"
  3. #include "llama-graph.h"
  4. #include "llama-kv-cache-unified.h"
  5. #include "llama-memory.h"
  6. #include "llama-memory-recurrent.h"
  7. #include <memory>
  8. #include <vector>
  9. //
  10. // llama_memory_hybrid
  11. //
  12. // utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to
  13. // support models where each layer may be either attention-based or recurrent
  14. class llama_memory_hybrid : public llama_memory_i {
  15. public:
  16. // this callback is used to filter out layers that should not be included in the cache
  17. using layer_filter_cb = std::function<bool(int32_t il)>;
  18. llama_memory_hybrid(
  19. const llama_model & model,
  20. /* attn */
  21. ggml_type type_k,
  22. ggml_type type_v,
  23. bool v_trans,
  24. uint32_t kv_size,
  25. uint32_t n_pad,
  26. uint32_t n_swa,
  27. llama_swa_type swa_type,
  28. /* recurrent */
  29. ggml_type type_r,
  30. ggml_type type_s,
  31. uint32_t rs_size,
  32. /* common */
  33. uint32_t n_seq_max,
  34. bool offload,
  35. /* layer filters */
  36. layer_filter_cb && filter_attn = nullptr,
  37. layer_filter_cb && filter_recr = nullptr);
  38. ~llama_memory_hybrid() = default;
  39. //
  40. // llama_memory_i
  41. //
  42. llama_memory_state_ptr init_batch(
  43. const llama_batch & batch,
  44. uint32_t n_ubatch,
  45. bool embd_pooled) override;
  46. llama_memory_state_ptr init_full() override;
  47. llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
  48. bool get_can_shift() const override;
  49. void clear(bool data) override;
  50. bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
  51. void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
  52. void seq_keep(llama_seq_id seq_id) override;
  53. void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
  54. void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
  55. llama_pos seq_pos_min(llama_seq_id seq_id) const override;
  56. llama_pos seq_pos_max(llama_seq_id seq_id) const override;
  57. // state write/load
  58. void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
  59. void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
  60. //
  61. // llama_memory_hybrid specific API
  62. //
  63. llama_kv_cache_unified * get_mem_attn() const;
  64. llama_memory_recurrent * get_mem_recr() const;
  65. private:
  66. const llama_hparams & hparams;
  67. const std::unique_ptr<llama_kv_cache_unified> mem_attn;
  68. const std::unique_ptr<llama_memory_recurrent> mem_recr;
  69. };
  70. class llama_memory_hybrid_state : public llama_memory_state_i {
  71. public:
  72. // init failure
  73. explicit llama_memory_hybrid_state(llama_memory_status status);
  74. // init full
  75. explicit llama_memory_hybrid_state(llama_memory_hybrid * mem);
  76. // init update
  77. explicit llama_memory_hybrid_state(
  78. llama_memory_hybrid * mem,
  79. llama_context * lctx,
  80. bool optimize);
  81. // init success
  82. llama_memory_hybrid_state(
  83. llama_memory_hybrid * mem,
  84. llama_sbatch sbatch,
  85. std::vector<uint32_t> heads_attn,
  86. std::vector<llama_ubatch> ubatches);
  87. ~llama_memory_hybrid_state() = default;
  88. bool next() override;
  89. bool apply() override;
  90. std::vector<int64_t> & out_ids() override;
  91. llama_memory_status get_status() const override;
  92. const llama_ubatch & get_ubatch() const override;
  93. //
  94. // llama_memory_hybrid_state
  95. //
  96. const llama_kv_cache_unified_state * get_state_attn() const;
  97. const llama_memory_recurrent_state * get_state_recr() const;
  98. private:
  99. llama_sbatch sbatch;
  100. // the index of the next ubatch to process
  101. size_t i_next = 0;
  102. std::vector<llama_ubatch> ubatches;
  103. const llama_memory_state_ptr state_attn;
  104. const llama_memory_state_ptr state_recr;
  105. const llama_memory_status status;
  106. };