llama-kv-cache.h 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. #pragma once
  2. #include "llama.h"
  3. #include "llama-memory.h"
  4. class llama_io_write_i;
  5. class llama_io_read_i;
  6. struct llama_kv_cache : public llama_memory_i {
  7. virtual ~llama_kv_cache() = default;
  8. // TODO: move the init_ interfaces to llama_memory_i
  9. // split the input batch into a set of ubatches and verify that they can fit into the cache
  10. // return a state object containing the ubatches and KV cache state required to process them
  11. // check the llama_memory_state_i::get_status() for the result
  12. virtual llama_memory_state_ptr init_batch(
  13. const llama_batch & batch,
  14. uint32_t n_ubatch,
  15. bool embd_pooled,
  16. bool logits_all) = 0;
  17. // simulate full cache, used for allocating worst-case compute buffers
  18. virtual llama_memory_state_ptr init_full() = 0;
  19. // prepare for any pending memory updates, such as shifts, defrags, etc.
  20. // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
  21. virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
  22. // getters
  23. virtual bool get_can_shift() const = 0;
  24. bool get_can_edit() const override { return get_can_shift(); }
  25. //
  26. // state write/read
  27. //
  28. virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
  29. virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
  30. };