// export-lora.cpp

#include "common.h"
#include "ggml.h"
#include "ggml-alloc.h"

#include <algorithm>
#include <cerrno>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <string>
#include <thread>
#include <vector>
static const size_t tensor_alignment = 32;

// one LoRA adapter file and the user-specified scale to apply it with
struct lora_info {
    std::string filename;
    float scale;
};

struct export_lora_params {
    std::string fn_model_base;
    std::string fn_model_out;
    std::vector<struct lora_info> lora;
    int n_threads;
};

// a loaded LoRA adapter: tensor metadata lives in ctx, raw bytes in data
struct lora_data {
    struct lora_info info;
    std::vector<uint8_t> data;
    struct ggml_context * ctx;

    uint32_t lora_r;
    uint32_t lora_alpha;
};
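
// simple stdio-based file helper; dies on short reads/writes and
// keeps the FILE * open so the same handle can be reused for tensor data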
struct llama_file {
    // use FILE * so we don't have to re-open the file to mmap
    FILE * fp;
    size_t size;

    llama_file(const char * fname, const char * mode) {
        fp = std::fopen(fname, mode);
        if (fp == NULL) {
            size = 0;
        } else {
            seek(0, SEEK_END);
            size = tell();
            seek(0, SEEK_SET);
        }
    }

    size_t tell() const {
#ifdef _WIN32
        __int64 ret = _ftelli64(fp);
#else
        long ret = std::ftell(fp);
#endif
        GGML_ASSERT(ret != -1); // this really shouldn't fail
        return (size_t) ret;
    }

    void seek(size_t offset, int whence) {
#ifdef _WIN32
        int ret = _fseeki64(fp, (__int64) offset, whence);
#else
        int ret = std::fseek(fp, (long) offset, whence);
#endif
        GGML_ASSERT(ret == 0); // same
    }

    void read_raw(void * ptr, size_t size) {
        if (size == 0) {
            return;
        }
        errno = 0;
        std::size_t ret = std::fread(ptr, size, 1, fp);
        if (ferror(fp)) {
            die_fmt("read error: %s", strerror(errno));
        }
        if (ret != 1) {
            die("unexpectedly reached end of file");
        }
    }

    std::uint32_t read_u32() {
        std::uint32_t ret;
        read_raw(&ret, sizeof(ret));
        return ret;
    }

    std::string read_string(std::uint32_t len) {
        std::vector<char> chars(len);
        read_raw(chars.data(), len);
        return std::string(chars.data(), len);
    }

    void write_raw(const void * ptr, size_t size) {
        if (size == 0) {
            return;
        }
        errno = 0;
        size_t ret = std::fwrite(ptr, size, 1, fp);
        if (ret != 1) {
            die_fmt("write error: %s", strerror(errno));
        }
    }

    void write_u32(std::uint32_t val) {
        write_raw(&val, sizeof(val));
    }

    bool eof() {
        return tell() >= size;
    }

    ~llama_file() {
        if (fp) {
            std::fclose(fp);
        }
    }
};

static struct export_lora_params get_default_export_lora_params() {
    struct export_lora_params result;
    result.fn_model_base = "";
    result.fn_model_out  = "";
    result.n_threads = GGML_DEFAULT_N_THREADS;
    return result;
}

static void export_lora_print_usage(int /*argc*/, char ** argv, const struct export_lora_params * params) {
    fprintf(stderr, "usage: %s [options]\n", argv[0]);
    fprintf(stderr, "\n");
    fprintf(stderr, "options:\n");
    fprintf(stderr, "  -h, --help                         show this help message and exit\n");
    fprintf(stderr, "  -m FNAME, --model-base FNAME       model path from which to load base model (default '%s')\n", params->fn_model_base.c_str());
    fprintf(stderr, "  -o FNAME, --model-out FNAME        path to save exported model (default '%s')\n", params->fn_model_out.c_str());
    fprintf(stderr, "  -l FNAME, --lora FNAME             apply LoRA adapter\n");
    fprintf(stderr, "  -s FNAME S, --lora-scaled FNAME S  apply LoRA adapter with user defined scaling S\n");
    fprintf(stderr, "  -t N, --threads N                  number of threads to use during computation (default: %d)\n", params->n_threads);
}
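
// Example invocation (illustrative file names, not from the repository):
//   ./export-lora -m base-model.gguf -o merged-model.gguf -l adapter1.bin -s adapter2.bin 0.5 -t 8
// Each -l/-s argument adds one adapter; -s additionally takes a user scale that multiplies alpha/r.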

static bool export_lora_params_parse(int argc, char ** argv, struct export_lora_params * params) {
    bool invalid_param = false;
    std::string arg;
    struct export_lora_params default_params = get_default_export_lora_params();
    const std::string arg_prefix = "--";

    for (int i = 1; i < argc; i++) {
        arg = argv[i];
        if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
            std::replace(arg.begin(), arg.end(), '_', '-');
        }

        if (arg == "-m" || arg == "--model-base") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->fn_model_base = argv[i];
        } else if (arg == "-o" || arg == "--model-out") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->fn_model_out = argv[i];
        } else if (arg == "-l" || arg == "--lora") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            struct lora_info lora;
            lora.filename = argv[i];
            lora.scale = 1.0f;
            params->lora.push_back(lora);
        } else if (arg == "-s" || arg == "--lora-scaled") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            struct lora_info lora;
            lora.filename = argv[i];
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            lora.scale = std::stof(argv[i]);
            params->lora.push_back(lora);
        } else if (arg == "-t" || arg == "--threads") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->n_threads = std::stoi(argv[i]);
            if (params->n_threads <= 0) {
                params->n_threads = std::thread::hardware_concurrency();
            }
        } else {
            fprintf(stderr, "error: unknown argument: '%s'\n", arg.c_str());
            export_lora_print_usage(argc, argv, &default_params);
            exit(1);
        }
    }

    if (params->fn_model_base == default_params.fn_model_base) {
        fprintf(stderr, "error: please specify a filename for model-base.\n");
        export_lora_print_usage(argc, argv, &default_params);
        exit(1);
    }
    if (params->fn_model_out == default_params.fn_model_out) {
        fprintf(stderr, "error: please specify a filename for model-out.\n");
        export_lora_print_usage(argc, argv, &default_params);
        exit(1);
    }
    if (invalid_param) {
        fprintf(stderr, "error: invalid parameter for argument: '%s'\n", arg.c_str());
        export_lora_print_usage(argc, argv, &default_params);
        exit(1);
    }
    return true;
}
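
// Binary LoRA adapter format parsed by load_lora() below:
//   magic 'ggla', version (u32), lora_r (u32), lora_alpha (u32),
//   then for each tensor: n_dims (u32), name length (u32), ggml type (u32),
//   the dimensions (u32 each), the name bytes, padding to a 32-byte boundary,
//   and finally the raw tensor data.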

static void free_lora(struct lora_data * lora) {
    if (lora->ctx != NULL) {
        ggml_free(lora->ctx);
    }
    delete lora;
}

static struct lora_data * load_lora(struct lora_info * info) {
    struct lora_data * result = new struct lora_data;
    result->info = *info;
    result->ctx = NULL;
    result->lora_r     = 1;
    result->lora_alpha = 1;

    struct llama_file file(info->filename.c_str(), "rb");
    if (file.fp == NULL) {
        fprintf(stderr, "warning: Could not open lora adapter '%s'. Ignoring this adapter.\n",
            info->filename.c_str());
        free_lora(result);
        return NULL;
    }

    struct ggml_init_params params_ggml;
    params_ggml.mem_size   = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE;
    params_ggml.mem_buffer = NULL;
    params_ggml.no_alloc   = true;
    result->ctx = ggml_init(params_ggml);

    uint32_t LLAMA_FILE_MAGIC_LORA = 0x67676C61; // 'ggla'
    uint32_t magic = file.read_u32();
    if (magic != LLAMA_FILE_MAGIC_LORA) {
        die_fmt("unexpected lora header file magic in '%s'", info->filename.c_str());
    }
    uint32_t version = file.read_u32();
    if (version != 1) {
        die_fmt("unexpected lora file version '%u' in '%s'", (unsigned) version, info->filename.c_str());
    }
    result->lora_r     = file.read_u32();
    result->lora_alpha = file.read_u32();

    // read tensor infos from file
    std::vector<char> name_buf;
    std::vector<struct ggml_tensor *> tensors;
    std::vector<size_t> tensors_offset;
    size_t total_nbytes_pad = 0;
    while(!file.eof()) {
        int64_t ne[4] = {1,1,1,1};
        uint32_t n_dims  = file.read_u32();
        uint32_t namelen = file.read_u32();
        uint32_t type    = file.read_u32();
        for (uint32_t k = 0; k < n_dims; ++k) {
            ne[k] = (int64_t)file.read_u32();
        }
        name_buf.clear();
        name_buf.resize(namelen + 1, '\0');
        file.read_raw(name_buf.data(), namelen);
        file.seek((0-file.tell()) & 31, SEEK_CUR);
        size_t offset = file.tell();
        struct ggml_tensor * tensor = ggml_new_tensor(result->ctx, (enum ggml_type) type, n_dims, ne);
        ggml_set_name(tensor, name_buf.data());
        size_t nbytes     = ggml_nbytes(tensor);
        size_t nbytes_pad = ggml_nbytes_pad(tensor);
        total_nbytes_pad += nbytes_pad;
        tensors.push_back(tensor);
        tensors_offset.push_back(offset);
        file.seek(nbytes, SEEK_CUR);
    }

    // read tensor data
    result->data.resize(total_nbytes_pad);
    size_t data_offset = 0;
    for (size_t i = 0; i < tensors.size(); ++i) {
        struct ggml_tensor * tensor = tensors[i];
        size_t offset     = tensors_offset[i];
        size_t nbytes     = ggml_nbytes(tensor);
        size_t nbytes_pad = ggml_nbytes_pad(tensor);
        file.seek(offset, SEEK_SET);
        tensor->data = result->data.data() + data_offset;
        file.read_raw(tensor->data, nbytes);
        data_offset += nbytes_pad;
    }
    return result;
}
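
// builds the graph that adds the scaled low-rank update to one base tensor in place:
//   tensor <- tensor + scaling * ggml_mul_mat(lora_a, lora_b)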
static struct ggml_cgraph * build_graph_lora(
    struct ggml_context * ctx,
    struct ggml_tensor * tensor,
    struct ggml_tensor * lora_a,
    struct ggml_tensor * lora_b,
    float scaling
) {
    struct ggml_tensor * ab = ggml_mul_mat(ctx, lora_a, lora_b);
    if (scaling != 1.0f) {
        ab = ggml_scale(ctx, ab, ggml_new_f32(ctx, scaling));
    }
    struct ggml_tensor * res = ggml_add_inplace(ctx, tensor, ab);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand (gf, res);
    return gf;
}
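
// Applies one adapter to one base tensor, if the adapter contains "<name>.loraA"/"<name>.loraB".
// The graph is built twice: once against a measure allocator to size the compute buffer,
// then again against a real allocator before running ggml_graph_compute.
// The effective scale is info.scale * lora_alpha / lora_r.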
static bool apply_lora(struct ggml_tensor * tensor, struct lora_data * lora, int n_threads) {
    if (lora->ctx == NULL) {
        return false;
    }
    std::string name   = ggml_get_name(tensor);
    std::string name_a = name + std::string(".loraA");
    std::string name_b = name + std::string(".loraB");
    struct ggml_tensor * lora_a = ggml_get_tensor(lora->ctx, name_a.c_str());
    struct ggml_tensor * lora_b = ggml_get_tensor(lora->ctx, name_b.c_str());
    if (lora_a == NULL || lora_b == NULL) {
        return false;
    }

    float scaling = lora->info.scale * (float)lora->lora_alpha / (float)lora->lora_r;

    struct ggml_init_params params;
    params.mem_size   = GGML_OBJECT_SIZE + ggml_graph_overhead() + ggml_tensor_overhead()*4 + GGML_MEM_ALIGN*5;
    params.mem_buffer = NULL;
    params.no_alloc   = true;
    struct ggml_context * ctx   = NULL;
    struct ggml_allocr  * alloc = NULL;
    struct ggml_cgraph  * gf    = NULL;

    ctx   = ggml_init(params);
    alloc = ggml_allocr_new_measure(tensor_alignment);
    gf    = build_graph_lora(ctx, tensor, lora_a, lora_b, scaling);
    size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf);
    ggml_allocr_free(alloc);
    ggml_free(ctx);

    static std::vector<uint8_t> data_compute;
    data_compute.resize(alloc_size + tensor_alignment);

    ctx   = ggml_init(params);
    alloc = ggml_allocr_new(data_compute.data(), data_compute.size(), tensor_alignment);
    gf    = build_graph_lora(ctx, tensor, lora_a, lora_b, scaling);
    ggml_allocr_alloc_graph(alloc, gf);
    ggml_allocr_free(alloc);

    struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads);
    static std::vector<uint8_t> data_work;
    data_work.resize(cplan.work_size);
    cplan.work_data = data_work.data();

    ggml_graph_compute(gf, &cplan);

    ggml_free(ctx);
    return true;
}
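
// Top-level export: load all adapters, copy the base model's GGUF metadata and tensor
// descriptions to the output, then stream each tensor through memory, apply every adapter
// that has matching loraA/loraB tensors, and write the (padded) result to the new file.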
static void export_lora(struct export_lora_params * params) {
    // load all loras
    std::vector<struct lora_data *> loras;
    for (size_t i = 0; i < params->lora.size(); ++i) {
        struct lora_data * lora = load_lora(&params->lora[i]);
        if (lora != NULL) {
            loras.push_back(lora);
        }
    }
    if (loras.size() == 0) {
        fprintf(stderr, "warning: no lora adapters will be applied.\n");
    }

    // open input file
    struct llama_file fin(params->fn_model_base.c_str(), "rb");
    if (!fin.fp) {
        die_fmt("Could not open file '%s'\n", params->fn_model_base.c_str());
    }

    // open base model gguf, read tensors without their data
    struct ggml_context * ctx_in;
    struct gguf_init_params params_gguf;
    params_gguf.no_alloc = true;
    params_gguf.ctx      = &ctx_in;
    struct gguf_context * gguf_in = gguf_init_from_file(params->fn_model_base.c_str(), params_gguf);

    // create new gguf
    struct gguf_context * gguf_out = gguf_init_empty();

    // copy meta data from base model: kv and tensors
    gguf_set_kv(gguf_out, gguf_in);
    int n_tensors = gguf_get_n_tensors(gguf_in);
    for (int i=0; i < n_tensors; ++i) {
        const char * name = gguf_get_tensor_name(gguf_in, i);
        struct ggml_tensor * tensor = ggml_get_tensor(ctx_in, name);
        gguf_add_tensor(gguf_out, tensor);
    }

    // create output file
    struct llama_file fout(params->fn_model_out.c_str(), "wb");
    if (!fout.fp) {
        die_fmt("Could not create file '%s'\n", params->fn_model_out.c_str());
    }

    // write gguf meta data
    std::vector<uint8_t> meta;
    meta.resize(gguf_get_meta_size(gguf_out));
    gguf_get_meta_data(gguf_out, meta.data());
    fout.write_raw(meta.data(), meta.size());

    std::vector<uint8_t> data;
    std::vector<uint8_t> padding;
    for (int i=0; i < n_tensors; ++i) {
        const char * name = gguf_get_tensor_name(gguf_in, i);
        struct ggml_tensor * tensor = ggml_get_tensor(ctx_in, name);

        // read tensor data
        data.resize(ggml_nbytes(tensor));
        tensor->data = data.data();
        size_t offset = gguf_get_tensor_offset(gguf_in, i);
        fin.seek(offset + meta.size(), SEEK_SET);
        fin.read_raw(data.data(), data.size());

        // apply all loras
        for (size_t k = 0; k < loras.size(); ++k) {
            apply_lora(tensor, loras[k], params->n_threads);
        }

        // write tensor data + padding
        padding.clear();
        padding.resize(GGML_PAD(data.size(), gguf_get_alignment(gguf_out)) - data.size(), 0);

        GGML_ASSERT(fout.tell() == offset + meta.size());
        // fout.seek(offset + meta.size(), SEEK_SET);
        fout.write_raw(data.data(), data.size());
        fout.write_raw(padding.data(), padding.size());

        if (i % 2 == 0) {
            printf(".");
        }
    }
    printf("\n");

    // close gguf
    gguf_free(gguf_out);
    gguf_free(gguf_in);

    // free loras
    for (size_t i = 0; i < loras.size(); ++i) {
        free_lora(loras[i]);
    }
}

int main(int argc, char ** argv) {
    struct export_lora_params params = get_default_export_lora_params();

    if (!export_lora_params_parse(argc, argv, &params)) {
        return 1;
    }

    export_lora(&params);
    return 0;
}