// llama-model-loader.h
  1. #pragma once
  2. #include "llama.h"
  3. #include "llama-impl.h"
  4. #include "llama-arch.h"
  5. #include "llama-mmap.h"
  6. #include "ggml-cpp.h"
  7. #include <cstddef>
  8. #include <map>
  9. #include <stdexcept>
  10. #include <unordered_map>
// maps a buffer index to its backend buffer — presumably one entry per split
// file / backend buffer; confirm against llama_model_loader::load_all_data callers
using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>;

// GGUF file format versions recognized by the loader
enum llama_fver {
    GGUF_FILE_VERSION_V1 = 1,
    GGUF_FILE_VERSION_V2 = 2,
    GGUF_FILE_VERSION_V3 = 3,
};

// human-readable name for a GGUF file version (e.g. for logging)
const char * llama_file_version_name(llama_fver version);
  18. struct llama_model_loader {
  19. // Holds information on a model weight
  20. struct llama_tensor_weight {
  21. uint16_t idx; // source file index
  22. size_t offs; // tensor data offset in the original file
  23. ggml_tensor * tensor;
  24. llama_tensor_weight(const llama_file * file, uint16_t idx, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
  25. const int tensor_idx = gguf_find_tensor(gguf_ctx, ggml_get_name(tensor));
  26. if (tensor_idx < 0) {
  27. throw std::runtime_error(format("tensor '%s' not found in the model", ggml_get_name(tensor)));
  28. }
  29. offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
  30. if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size()) {
  31. throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", ggml_get_name(tensor)));
  32. }
  33. }
  34. };
  35. // custom comparator to sort weights more nicely by layer
  36. struct weight_name_comparer {
  37. bool operator()(const std::string & a, const std::string & b) const {
  38. int a_layer = -1;
  39. int b_layer = -1;
  40. sscanf(a.c_str(), "blk.%d.", &a_layer);
  41. sscanf(b.c_str(), "blk.%d.", &b_layer);
  42. if (a_layer != b_layer) {
  43. return a_layer < b_layer;
  44. }
  45. return a < b;
  46. }
  47. };
  48. static const int TENSOR_NOT_REQUIRED = 1;
  49. static const int TENSOR_DUPLICATED = 2;
  50. int n_kv = 0;
  51. int n_tensors = 0;
  52. int n_created = 0;
  53. uint64_t n_elements = 0;
  54. size_t n_bytes = 0;
  55. bool use_mmap = false;
  56. bool check_tensors;
  57. llama_files files;
  58. llama_ftype ftype;
  59. llama_fver fver;
  60. llama_mmaps mappings;
  61. std::map<std::string, struct llama_tensor_weight, weight_name_comparer> weights_map;
  62. std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
  63. gguf_context_ptr meta;
  64. std::vector<ggml_context_ptr> contexts;
  65. std::string arch_name;
  66. LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
  67. size_t size_done = 0;
  68. size_t size_data = 0;
  69. std::vector<std::pair<size_t, size_t>> mmaps_used;
  70. llama_model_loader(
  71. const std::string & fname,
  72. std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
  73. bool use_mmap,
  74. bool check_tensors,
  75. const struct llama_model_kv_override * param_overrides_p);
  76. template<typename T>
  77. typename std::enable_if<std::is_integral<T>::value, bool>::type
  78. get_arr_n(const std::string & key, T & result, bool required = true);
  79. template<typename T>
  80. typename std::enable_if<std::is_integral<T>::value, bool>::type
  81. get_arr_n(enum llm_kv kid, T & result, bool required = true);
  82. template<typename T>
  83. bool get_arr(const std::string & key, std::vector<T> & result, bool required = true);
  84. template<typename T, size_t N_MAX>
  85. bool get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required = true);
  86. template<typename T>
  87. bool get_arr(enum llm_kv kid, T & result, bool required = true);
  88. template<typename T>
  89. bool get_key(const std::string & key, T & result, bool required = true);
  90. template<typename T>
  91. bool get_key(enum llm_kv kid, T & result, bool required = true);
  92. template<typename T, size_t N_MAX>
  93. bool get_key_or_arr(const std::string & key, std::array<T, N_MAX> & result, uint32_t n, bool required = true);
  94. template<typename T>
  95. bool get_key_or_arr(enum llm_kv kid, T & result, uint32_t n, bool required = true);
  96. std::string get_arch_name() const;
  97. enum llm_arch get_arch() const;
  98. const llama_tensor_weight * get_weight(const char * name) const;
  99. const llama_tensor_weight & require_weight(const char * name) const;
  100. struct ggml_tensor * get_tensor_meta(const char * name) const;
  101. struct ggml_tensor * require_tensor_meta(const std::string & name) const;
  102. const struct ggml_tensor * check_tensor_dims(const std::string & name, const std::vector<int64_t> & ne, bool required) const;
  103. struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list<int64_t> & ne, int flags = 0);
  104. struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list<int64_t> & ne, size_t offset, bool required = true);
  105. void done_getting_tensors() const;
  106. void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr);
  107. void get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const;
  108. // for backwards compatibility, does not support ggml-backend
  109. void load_data_for(struct ggml_tensor * cur) const;
  110. // Returns false if cancelled by progress_callback
  111. bool load_all_data(
  112. struct ggml_context * ctx,
  113. llama_buf_map & bufs,
  114. llama_mlocks * lmlocks,
  115. llama_progress_callback progress_callback,
  116. void * progress_callback_user_data);
  117. std::string ftype_name() const;
  118. void print_info() const;
  119. };