llama-memory.h

#pragma once

#include "llama.h"

#include <memory>
#include <functional>

struct llama_ubatch;

class llama_batch_allocr;

class llama_io_write_i;
class llama_io_read_i;
struct llama_memory_params {
    // kv cache
    ggml_type type_k;
    ggml_type type_v;

    // use full-size SWA cache
    bool swa_full;
};
enum llama_memory_status {
    LLAMA_MEMORY_STATUS_SUCCESS = 0,
    LLAMA_MEMORY_STATUS_NO_UPDATE,
    LLAMA_MEMORY_STATUS_FAILED_PREPARE,
    LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
};
// helper function for combining the status of two memory contexts
// useful for implementing hybrid memory types (e.g. iSWA)
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);

// helper function for checking if a memory status indicates a failure
bool llama_memory_status_is_fail(llama_memory_status status);
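// a plausible sketch of the helpers' semantics (illustrative, not the actual definitions):
//
//     bool llama_memory_status_is_fail(llama_memory_status status) {
//         return status == LLAMA_MEMORY_STATUS_FAILED_PREPARE ||
//                status == LLAMA_MEMORY_STATUS_FAILED_COMPUTE;
//     }
//
//     // combine: a failure in either input dominates; otherwise the result is SUCCESS
//     // if at least one input is SUCCESS, and NO_UPDATE only if both are NO_UPDATE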
// the interface for managing the memory context during batch processing
// this interface is implemented per memory type. see:
//  - llama_kv_cache_context
//  - llama_kv_cache_iswa_context
//  ...
//
// the only method that should mutate the memory and the memory context is llama_memory_context_i::apply()
struct llama_memory_context_i {
    virtual ~llama_memory_context_i() = default;

    // consume the current ubatch from the context and proceed to the next one
    // return false if we are done
    virtual bool next() = 0;

    // apply the memory state for the current ubatch to the memory object
    // return false on failure
    virtual bool apply() = 0;

    // get the current ubatch
    virtual const llama_ubatch & get_ubatch() const = 0;

    // get the status of the memory context - used for error handling and checking if any updates would be applied
    virtual llama_memory_status get_status() const = 0;
};
using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
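// typical consumer loop over a memory context (a sketch of assumed usage, not a prescribed
// contract; mem, balloc, n_ubatch and embd_all are caller-side names introduced for illustration):
//
//     llama_memory_context_ptr mctx = mem->init_batch(balloc, n_ubatch, embd_all);
//     if (llama_memory_status_is_fail(mctx->get_status())) {
//         // handle the failure
//     }
//     do {
//         if (!mctx->apply()) {
//             // handle the failure
//         }
//         const llama_ubatch & ubatch = mctx->get_ubatch();
//         // ... compute with ubatch ...
//     } while (mctx->next());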
// general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types
struct llama_memory_i {
    // this callback is used to filter out layers that should not be included in the cache
    using layer_filter_cb = std::function<bool(int32_t il)>;

    // this callback is used to specify which layers should reuse memory from other layers
    // return a negative value to indicate that the layer il should not reuse memory
    using layer_reuse_cb = std::function<int32_t(int32_t il)>;
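    // for example (hypothetical callbacks, shown for illustration only):
    //     layer_filter_cb filter = [](int32_t il) { return il % 2 == 0; }; // cache only even layers
    //     layer_reuse_cb  reuse  = [](int32_t il) { return -1; };          // no layer reuses another's memory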
    virtual ~llama_memory_i() = default;

    // split the input batch into a set of ubatches and verify that they can fit into the cache
    // return a context object containing the ubatches and memory state required to process them
    // check the llama_memory_context_i::get_status() for the result
    virtual llama_memory_context_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) = 0;
    // simulate full cache, used for allocating worst-case compute buffers
    virtual llama_memory_context_ptr init_full() = 0;

    // prepare for any pending memory updates, such as shifts, copies, etc.
    // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
    virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
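    // e.g. (a minimal sketch of the update path; mem is a caller-side llama_memory_ptr, for illustration):
    //     llama_memory_context_ptr mctx = mem->init_update(lctx, /*optimize=*/false);
    //     if (mctx->get_status() == LLAMA_MEMORY_STATUS_NO_UPDATE) {
    //         // nothing pending - the apply step can be skipped
    //     }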
    // getters
    virtual bool get_can_shift() const = 0;

    //
    // ops
    //

    // if data == true, the data buffers will also be cleared together with the metadata
    virtual void clear(bool data) = 0;

    virtual bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
    virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
    virtual void seq_keep(llama_seq_id seq_id) = 0;
    virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
    virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;

    virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
    virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
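    // e.g. (illustrative only - the half-open [p0, p1) range convention is assumed from the llama.h sequence API):
    //     mem->seq_add(0, 10, 20, -5); // shift positions in [10, 20) of sequence 0 back by 5
    //     mem->seq_div(0, 10, 20, 2);  // divide the same positions by 2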
    //
    // state write/read
    //

    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const = 0;
    virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0;
};
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
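// end-to-end lifecycle (a sketch, assuming a concrete implementation such as the KV cache behind
// llama_kv_cache_context; io is a caller-side llama_io_write_i, shown for illustration):
//
//     llama_memory_ptr mem = ...; // created during model/context setup
//     mem->clear(/*data=*/true);  // wipe the data buffers together with the metadata
//     // ... process batches via init_batch() - see llama_memory_context_i above ...
//     mem->state_write(io);       // seq_id = -1, presumably persisting the state of all sequences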