#pragma once

#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-memory.h"

#include <functional> // std::function, used by layer_filter_cb below
#include <set>
#include <vector>

//
// llama_memory_recurrent
//
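
// memory module for recurrent model architectures (e.g. Mamba, RWKV):
// instead of a per-token KV cache it keeps a pool of cells, each holding
// the recurrent state of one or more sequences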

// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
//       see the implementation of llama_kv_cache_unified_context_i for an example of how to do it
class llama_memory_recurrent : public llama_memory_i {
public:

    // this callback is used to filter out layers that should not be included in the cache
    using layer_filter_cb = std::function<bool(int32_t il)>;
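
    // type_r/type_s: tensor types used for the per-layer r_l/s_l state tensors
    // mem_size:      number of state cells to allocate (shared across all sequences)
    // n_seq_max:     maximum number of sequences tracked at the same time
    // offload:       whether to allocate the state tensors on the layers' devices instead of on the CPU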
    llama_memory_recurrent(
            const llama_model & model,
            layer_filter_cb  && filter,
            ggml_type type_r,
            ggml_type type_s,
            bool      offload,
            uint32_t  mem_size,
            uint32_t  n_seq_max);

    ~llama_memory_recurrent() = default;

    //
    // llama_memory_i
    //
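
    // split the input batch into a sequence of ubatches and verify that they all fit
    // into the available state cells; returns a context over the prepared ubatches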
    llama_memory_context_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) override;
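
    // simulate a full memory - used when allocating worst-case compute buffers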
    llama_memory_context_ptr init_full() override;
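
    // prepare a context for applying pending memory updates, if any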
    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

    void clear(bool data) override;

    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
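
    // check whether all the given ubatches can be placed into the memory cells
    // (used by init_batch to verify that a batch split fits)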
    bool prepare(const std::vector<llama_ubatch> & ubatches);

    // find a contiguous slot of memory cells and emplace the ubatch there
    bool find_slot(const llama_ubatch & ubatch);

    bool get_can_shift() const override;

    // state write/load
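    // (when seq_id >= 0, only that sequence is serialized/restored)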

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1)       override;

    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
    uint32_t size = 0; // total number of cells, shared across all sequences
    uint32_t used = 0; // used cells (i.e. at least one seq_id)

    // computed before each graph build
    uint32_t n = 0;

    // first zero-ed state
    int32_t rs_z = -1;

    // TODO: optimize for recurrent state needs
    struct mem_cell {
        llama_pos pos  = -1;
        int32_t   src  = -1; // used to know where states should be copied from
        int32_t   src0 = -1; // like src, but only used when setting the inputs (allows copying the state only once)
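        // index of the cell holding the latest state of the sequence that maps to this cell, -1 if none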
        int32_t   tail = -1;

        std::set<llama_seq_id> seq_id;

        bool has_seq_id(const llama_seq_id & id) const {
            return seq_id.find(id) != seq_id.end();
        }

        bool is_empty() const {
            return seq_id.empty();
        }

        bool is_same_seq(const mem_cell & other) const {
            return seq_id == other.seq_id;
        }
    };

    std::vector<mem_cell> cells;

    // per layer
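    // one tensor per layer, each holding the states of all cells
    // (allocated with the type_r/type_s types passed to the constructor)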
    std::vector<ggml_tensor *> r_l;
    std::vector<ggml_tensor *> s_l;

private:
    //const llama_model & model;
    const llama_hparams & hparams;

    // maximum number of sequences that can be tracked at the same time
    const uint32_t n_seq_max = 1;

    // ggml contexts and backend buffers that own the r_l/s_l tensors
    std::vector<ggml_context_ptr>        ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    // total size of the state buffers in bytes
    size_t total_size() const;

    size_t size_r_bytes() const;
    size_t size_s_bytes() const;

    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;

    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};

class llama_memory_recurrent_context : public llama_memory_context_i {
public:
    // used for errors
    llama_memory_recurrent_context(llama_memory_status status);

    // used to create a full-cache or update context
    llama_memory_recurrent_context(
            llama_memory_recurrent * mem);

    // used to create a batch processing context from a batch
    llama_memory_recurrent_context(
            llama_memory_recurrent * mem,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_memory_recurrent_context();

    //
    // llama_memory_context_i
    //
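
    // the context iterates over the ubatches prepared by init_batch():
    // apply() reserves cells for the current ubatch, next() advances to the following one
    // and returns false once all ubatches have been processed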
    bool next()  override;
    bool apply() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_memory_recurrent_context specific API
    //
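
    // get_n_rs(): number of recurrent states taking part in the current graph
    // get_head(): first cell used for the current ubatch
    // get_rs_z(): index of the first zero-ed state (see rs_z in llama_memory_recurrent)
    // s_copy(i):  source cell from which state i should be copied before use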
    uint32_t get_n_rs() const;
    uint32_t get_head() const;
    int32_t  get_rs_z() const;
    uint32_t get_size() const;

    ggml_tensor * get_r_l(int32_t il) const;
    ggml_tensor * get_s_l(int32_t il) const;

    int32_t s_copy(int i) const;

private:
    const llama_memory_status status;

    llama_memory_recurrent * mem;

    // index of the next ubatch to process
    size_t i_next = 0;

    std::vector<llama_ubatch> ubatches;

    //
    // data needed for building the compute graph for the current ubatch:
    // TODO: extract all the state like `head` and `n` here
    //

    // true when the context was created over the full memory (see the full-cache/update constructor above)
    const bool is_full = false;
};
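
// illustrative sketch (not a declaration): how a memory context returned by init_batch()
// is typically consumed - `mem`, `balloc` and `n_ubatch` below are placeholder names
//
//     auto mctx = mem->init_batch(balloc, n_ubatch, /*embd_all =*/ false);
//
//     if (mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
//         do {
//             mctx->apply();                            // reserve cells for the current ubatch
//             const auto & ubatch = mctx->get_ubatch(); // build and compute the graph using this ubatch
//             // ... graph build/compute ...
//         } while (mctx->next());                       // advance until all ubatches are processed
//     }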