| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- #pragma once
- #include "llama-batch.h"
- #include "llama-graph.h"
- #include "llama-kv-cache.h"
- #include "llama-memory.h"
- #include "llama-memory-recurrent.h"
- #include <memory>
- #include <vector>
- //
- // llama_memory_hybrid
- //
- // utilizes instances of llama_memory_recurrent and llama_kv_cache to
- // support models where each layer may be either attention-based or recurrent
- class llama_memory_hybrid : public llama_memory_i {
- public:
- llama_memory_hybrid(
- const llama_model & model,
- /* attn */
- ggml_type type_k,
- ggml_type type_v,
- bool v_trans,
- uint32_t kv_size,
- uint32_t n_pad,
- uint32_t n_swa,
- llama_swa_type swa_type,
- /* recurrent */
- ggml_type type_r,
- ggml_type type_s,
- uint32_t rs_size,
- /* common */
- uint32_t n_seq_max,
- bool offload,
- bool unified,
- /* layer filters */
- const layer_filter_cb & filter_attn = nullptr,
- const layer_filter_cb & filter_recr = nullptr);
- ~llama_memory_hybrid() = default;
- //
- // llama_memory_i
- //
- llama_memory_context_ptr init_batch(
- llama_batch_allocr & balloc,
- uint32_t n_ubatch,
- bool embd_all) override;
- llama_memory_context_ptr init_full() override;
- llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
- bool get_can_shift() const override;
- void clear(bool data) override;
- bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
- void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
- void seq_keep(llama_seq_id seq_id) override;
- void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
- void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
- llama_pos seq_pos_min(llama_seq_id seq_id) const override;
- llama_pos seq_pos_max(llama_seq_id seq_id) const override;
- std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
- // state write/load
- void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
- void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
- //
- // llama_memory_hybrid specific API
- //
- llama_kv_cache * get_mem_attn() const;
- llama_memory_recurrent * get_mem_recr() const;
- private:
- const llama_hparams & hparams;
- const std::unique_ptr<llama_kv_cache> mem_attn;
- const std::unique_ptr<llama_memory_recurrent> mem_recr;
- };
- class llama_memory_hybrid_context : public llama_memory_context_i {
- public:
- using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
- // init failure
- explicit llama_memory_hybrid_context(llama_memory_status status);
- // init full
- explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
- // init update
- explicit llama_memory_hybrid_context(
- llama_memory_hybrid * mem,
- llama_context * lctx,
- bool optimize);
- // init success
- llama_memory_hybrid_context(
- llama_memory_hybrid * mem,
- slot_info_vec_t sinfos_attn,
- std::vector<llama_ubatch> ubatches);
- ~llama_memory_hybrid_context() = default;
- bool next() override;
- bool apply() override;
- llama_memory_status get_status() const override;
- const llama_ubatch & get_ubatch() const override;
- //
- // llama_memory_hybrid_context
- //
- const llama_kv_cache_context * get_attn() const;
- const llama_memory_recurrent_context * get_recr() const;
- private:
- // the index of the next ubatch to process
- size_t i_next = 0;
- std::vector<llama_ubatch> ubatches;
- const llama_memory_context_ptr ctx_attn;
- const llama_memory_context_ptr ctx_recr;
- const llama_memory_status status;
- };
|