// gguf-split.cpp — split a GGUF model file into multiple shards, or merge shards back into a single file.
  1. #include "llama.h"
  2. #include "common.h"
  3. #include <algorithm>
  4. #include <cmath>
  5. #include <cstdlib>
  6. #include <fstream>
  7. #include <string>
  8. #include <vector>
  9. #include <stdio.h>
  10. #include <string.h>
  11. #include <climits>
  12. #include <stdexcept>
  13. #if defined(_WIN32)
  14. #include <windows.h>
  15. #ifndef PATH_MAX
  16. #define PATH_MAX MAX_PATH
  17. #endif
  18. #include <io.h>
  19. #endif
// Operation selected on the command line: split one GGUF into several
// files, or merge several split files back into one.
enum split_operation : uint8_t {
    SPLIT_OP_SPLIT,
    SPLIT_OP_MERGE,
};
// GGUF metadata keys written into every split so the shards can be
// identified and re-assembled later (read back by gguf_merge()).
static const char * const LLM_KV_SPLIT_NO            = "split.no";            // zero-based index of this shard
static const char * const LLM_KV_SPLIT_COUNT         = "split.count";         // total number of shards
static const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"; // tensor count of the whole model
// Parsed command-line options for the tool.
struct split_params {
    split_operation operation = SPLIT_OP_SPLIT; // --split (default) or --merge
    int n_split_tensors = 128;                  // --split-max-tensors: max tensors per output shard
    std::string input;                          // GGUF_IN positional argument
    std::string output;                         // GGUF_OUT positional argument
};
  33. static void split_print_usage(const char * executable) {
  34. const split_params default_params;
  35. printf("\n");
  36. printf("usage: %s [options] GGUF_IN GGUF_OUT\n", executable);
  37. printf("\n");
  38. printf("Apply a GGUF operation on IN to OUT.");
  39. printf("\n");
  40. printf("options:\n");
  41. printf(" -h, --help show this help message and exit\n");
  42. printf(" --version show version and build info\n");
  43. printf(" --split split GGUF to multiple GGUF (default)\n");
  44. printf(" --split-max-tensors max tensors in each split: default(%d)\n", default_params.n_split_tensors);
  45. printf(" --merge merge multiple GGUF to a single GGUF\n");
  46. printf("\n");
  47. }
  48. static bool split_params_parse_ex(int argc, const char ** argv, split_params & params) {
  49. std::string arg;
  50. const std::string arg_prefix = "--";
  51. bool invalid_param = false;
  52. int arg_idx = 1;
  53. for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
  54. arg = argv[arg_idx];
  55. if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
  56. std::replace(arg.begin(), arg.end(), '_', '-');
  57. }
  58. bool arg_found = false;
  59. if (arg == "-h" || arg == "--help") {
  60. split_print_usage(argv[0]);
  61. exit(0);
  62. }
  63. if (arg == "--version") {
  64. fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
  65. fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET);
  66. exit(0);
  67. }
  68. if (arg == "--merge") {
  69. arg_found = true;
  70. params.operation = SPLIT_OP_MERGE;
  71. }
  72. if (arg == "--split") {
  73. arg_found = true;
  74. params.operation = SPLIT_OP_SPLIT;
  75. }
  76. if (arg == "--split-max-tensors") {
  77. if (++arg_idx >= argc) {
  78. invalid_param = true;
  79. break;
  80. }
  81. arg_found = true;
  82. params.n_split_tensors = atoi(argv[arg_idx]);
  83. }
  84. if (!arg_found) {
  85. throw std::invalid_argument("error: unknown argument: " + arg);
  86. }
  87. }
  88. if (invalid_param) {
  89. throw std::invalid_argument("error: invalid parameter for argument: " + arg);
  90. }
  91. if (argc - arg_idx < 2) {
  92. printf("%s: bad arguments\n", argv[0]);
  93. split_print_usage(argv[0]);
  94. return false;
  95. }
  96. params.input = argv[arg_idx++];
  97. params.output = argv[arg_idx++];
  98. return true;
  99. }
  100. static bool split_params_parse(int argc, const char ** argv, split_params & params) {
  101. bool result = true;
  102. try {
  103. if (!split_params_parse_ex(argc, argv, params)) {
  104. split_print_usage(argv[0]);
  105. exit(EXIT_FAILURE);
  106. }
  107. }
  108. catch (const std::invalid_argument & ex) {
  109. fprintf(stderr, "%s\n", ex.what());
  110. split_print_usage(argv[0]);
  111. exit(EXIT_FAILURE);
  112. }
  113. return result;
  114. }
  115. static void zeros(std::ofstream & file, size_t n) {
  116. char zero = 0;
  117. for (size_t i = 0; i < n; ++i) {
  118. file.write(&zero, 1);
  119. }
  120. }
// Stateful helper that walks the tensors of the input GGUF and streams them
// into a sequence of split files.
//
// Protocol (driven by gguf_split()): call split_start(), then next_tensor()
// once per tensor; whenever should_split() reports a shard boundary, call
// split_end() followed by split_start(); finish with a final split_end().
struct split_strategy {
    const split_params params;

    // input stream, repositioned by next_tensor() for each tensor's data
    std::ifstream & f_input;

    // GGUF/GGML contexts of the input file (loaded with no_alloc, so
    // ctx_meta holds tensor metadata only — data is streamed from f_input)
    struct gguf_context * ctx_gguf;
    struct ggml_context * ctx_meta = NULL;

    const int n_tensors;

    // total number of output shards: ceil(n_tensors / n_split_tensors)
    const int n_split;
    int i_split = 0;  // index of the shard currently being written
    int i_tensor = 0; // index of the next tensor to copy

    // scratch buffer reused across tensors, grown on demand
    std::vector<uint8_t> read_data;

    // metadata context and output stream of the currently open shard
    struct gguf_context * ctx_out;
    std::ofstream fout;

    split_strategy(const split_params & params,
            std::ifstream & f_input,
            struct gguf_context * ctx_gguf,
            struct ggml_context * ctx_meta) :
        params(params),
        f_input(f_input),
        ctx_gguf(ctx_gguf),
        ctx_meta(ctx_meta),
        n_tensors(gguf_get_n_tensors(ctx_gguf)),
        n_split(std::ceil(1. * n_tensors / params.n_split_tensors)) {
        }

    // true when i_tensor sits on a shard boundary and tensors remain, i.e.
    // the caller should close the current shard and open the next one
    bool should_split() const {
        return i_tensor < n_tensors && i_tensor % params.n_split_tensors == 0;
    }

    // open the next shard: build its metadata context, open the file, and
    // reserve a zero-filled placeholder where the metadata will go (it is
    // written for real in split_end(), once tensor offsets are final)
    void split_start() {
        ctx_out = gguf_init_empty();

        // Save all metadata in first split only
        if (i_split == 0) {
            gguf_set_kv(ctx_out, ctx_gguf);
        }
        gguf_set_val_u16(ctx_out, LLM_KV_SPLIT_NO, i_split);
        gguf_set_val_u16(ctx_out, LLM_KV_SPLIT_COUNT, n_split);
        gguf_set_val_i32(ctx_out, LLM_KV_SPLIT_TENSORS_COUNT, n_tensors);

        // populate the original tensors, so we get an initial metadata
        for (int i = i_split * params.n_split_tensors; i < n_tensors && i < (i_split + 1) * params.n_split_tensors; ++i) {
            struct ggml_tensor * meta = ggml_get_tensor(ctx_meta, gguf_get_tensor_name(ctx_gguf, i));
            gguf_add_tensor(ctx_out, meta);
        }

        // shard file name derived from the output prefix, e.g. name-00001-of-00005.gguf
        char split_path[PATH_MAX] = {0};
        llama_split_path(split_path, sizeof(split_path), params.output.c_str(), i_split, n_split);

        fprintf(stderr, "%s: %s ...", __func__, split_path);
        fout = std::ofstream(split_path, std::ios::binary);
        fout.exceptions(std::ofstream::failbit); // fail fast on write errors

        auto meta_size = gguf_get_meta_size(ctx_out);

        // placeholder for the meta data
        ::zeros(fout, meta_size);
        i_split++;
    }

    // copy the current tensor's data from the input file into the open
    // shard, padded to the GGUF alignment, then advance i_tensor
    void next_tensor() {
        const char * t_name = gguf_get_tensor_name(ctx_gguf, i_tensor);
        struct ggml_tensor * t = ggml_get_tensor(ctx_meta, t_name);
        auto n_bytes = ggml_nbytes(t);

        if (read_data.size() < n_bytes) {
            read_data.resize(n_bytes);
        }

        // tensor data location in the input = data section start + per-tensor offset
        auto offset = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i_tensor);
        f_input.seekg(offset);
        f_input.read((char *)read_data.data(), n_bytes);

        t->data = read_data.data();

        // write tensor data + padding
        fout.write((const char *)t->data, n_bytes);
        zeros(fout, GGML_PAD(n_bytes, GGUF_DEFAULT_ALIGNMENT) - n_bytes);

        i_tensor++;
    }

    // finalize the open shard: overwrite the placeholder at the start of the
    // file with the real metadata, then close and release the context
    void split_end() {
        // go back to beginning of file and write the updated metadata
        fout.seekp(0);
        std::vector<uint8_t> data(gguf_get_meta_size(ctx_out));
        gguf_get_meta_data(ctx_out, data.data());
        fout.write((const char *)data.data(), data.size());

        fout.close();
        gguf_free(ctx_out);

        // "\033[3D" backs up over the "..." so the line reads "<path> done"
        fprintf(stderr, "\033[3Ddone\n");
    }
};
  198. static void gguf_split(const split_params & split_params) {
  199. struct ggml_context * ctx_meta = NULL;
  200. struct gguf_init_params params = {
  201. /*.no_alloc = */ true,
  202. /*.ctx = */ &ctx_meta,
  203. };
  204. std::ifstream f_input(split_params.input.c_str(), std::ios::binary);
  205. if (!f_input.is_open()) {
  206. fprintf(stderr, "%s: failed to open input GGUF from %s\n", __func__, split_params.input.c_str());
  207. exit(EXIT_FAILURE);
  208. }
  209. auto * ctx_gguf = gguf_init_from_file(split_params.input.c_str(), params);
  210. if (!ctx_gguf) {
  211. fprintf(stderr, "%s: failed to load input GGUF from %s\n", __func__, split_params.input.c_str());
  212. exit(EXIT_FAILURE);
  213. }
  214. split_strategy strategy(split_params, f_input, ctx_gguf, ctx_meta);
  215. char first_split_path[PATH_MAX] = {0};
  216. llama_split_path(first_split_path, sizeof(first_split_path),
  217. split_params.output.c_str(), strategy.i_split, strategy.n_split);
  218. fprintf(stderr, "%s: %s -> %s (%d tensors per file)\n",
  219. __func__, split_params.input.c_str(),
  220. first_split_path,
  221. split_params.n_split_tensors);
  222. strategy.split_start();
  223. while (strategy.i_tensor < strategy.n_tensors) {
  224. strategy.next_tensor();
  225. if (strategy.should_split()) {
  226. strategy.split_end();
  227. strategy.split_start();
  228. }
  229. }
  230. strategy.split_end();
  231. gguf_free(ctx_gguf);
  232. f_input.close();
  233. fprintf(stderr, "%s: %d gguf split written with a total of %d tensors.\n",
  234. __func__, strategy.n_split, strategy.n_tensors);
  235. }
// Merge a sequence of split GGUF files back into one output file.
//
// split_params.input names the first shard; the remaining shard names are
// derived from the prefix extracted via llama_split_prefix(). Two passes:
// pass 1 loads every shard's metadata (KV pairs from the first shard only,
// tensor infos from all), pass 2 streams all tensor data into the output.
static void gguf_merge(const split_params & split_params) {
    fprintf(stderr, "%s: %s -> %s\n",
            __func__, split_params.input.c_str(),
            split_params.output.c_str());
    int n_split = 1; // provisional; real count read from the first shard's metadata
    int total_tensors = 0;

    auto * ctx_out = gguf_init_empty();
    std::ofstream fout(split_params.output.c_str(), std::ios::binary);
    fout.exceptions(std::ofstream::failbit); // fail fast on write errors

    std::vector<uint8_t> read_data;
    // contexts are kept alive across both passes, indexed by shard
    std::vector<ggml_context *> ctx_metas;
    std::vector<gguf_context *> ctx_ggufs;

    char split_path[PATH_MAX] = {0};
    strncpy(split_path, split_params.input.c_str(), sizeof(split_path) - 1);
    char split_prefix[PATH_MAX] = {0};

    // First pass to find KV and tensors metadata
    for (int i_split = 0; i_split < n_split; i_split++) {
        struct ggml_context * ctx_meta = NULL;

        struct gguf_init_params params = {
            /*.no_alloc = */ true,
            /*.ctx      = */ &ctx_meta,
        };

        // shards after the first are located by rebuilding the name from the prefix
        if (i_split > 0) {
            llama_split_path(split_path, sizeof(split_path), split_prefix, i_split, n_split);
        }
        fprintf(stderr, "%s: reading metadata %s ...", __func__, split_path);

        auto * ctx_gguf = gguf_init_from_file(split_path, params);
        if (!ctx_gguf) {
            fprintf(stderr, "\n%s: failed to load input GGUF from %s\n", __func__, split_params.input.c_str());
            exit(EXIT_FAILURE);
        }
        ctx_ggufs.push_back(ctx_gguf);
        ctx_metas.push_back(ctx_meta);

        if (i_split == 0) {
            // the first shard must carry split.count; it also determines how
            // many further iterations this loop runs (n_split is updated below)
            auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
            if (key_n_split < 0) {
                fprintf(stderr,
                        "\n%s: input file does not contain %s metadata\n",
                        __func__,
                        LLM_KV_SPLIT_COUNT);
                gguf_free(ctx_gguf);
                ggml_free(ctx_meta);
                gguf_free(ctx_out);
                fout.close();
                exit(EXIT_FAILURE);
            }

            n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
            if (n_split < 1) {
                fprintf(stderr,
                        "\n%s: input file does not contain a valid split count %d\n",
                        __func__,
                        n_split);
                gguf_free(ctx_gguf);
                ggml_free(ctx_meta);
                gguf_free(ctx_out);
                fout.close();
                exit(EXIT_FAILURE);
            }

            // Verify the file naming and extract split_prefix
            if (!llama_split_prefix(split_prefix, sizeof (split_prefix), split_path, i_split, n_split)) {
                fprintf(stderr, "\n%s: unexpected input file name: %s"
                                " i_split=%d"
                                " n_split=%d\n", __func__,
                        split_path, i_split, n_split);
                gguf_free(ctx_gguf);
                ggml_free(ctx_meta);
                gguf_free(ctx_out);
                fout.close();
                exit(EXIT_FAILURE);
            }

            // Do not trigger merge if we try to merge again the output
            gguf_set_val_u16(ctx_gguf, LLM_KV_SPLIT_COUNT, 0);

            // Set metadata from the first split
            gguf_set_kv(ctx_out, ctx_gguf);
        }

        // register every tensor of this shard in the merged metadata
        auto n_tensors = gguf_get_n_tensors(ctx_gguf);
        for (int i_tensor = 0; i_tensor < n_tensors; i_tensor++) {
            const char * t_name = gguf_get_tensor_name(ctx_gguf, i_tensor);
            struct ggml_tensor * t = ggml_get_tensor(ctx_meta, t_name);
            gguf_add_tensor(ctx_out, t);
        }
        total_tensors += n_tensors;

        // "\033[3D" backs up over the "..." so the line reads "<path> done"
        fprintf(stderr, "\033[3Ddone\n");
    }

    // placeholder for the meta data
    {
        auto meta_size = gguf_get_meta_size(ctx_out);
        ::zeros(fout, meta_size);
    }

    // Write tensors data
    for (int i_split = 0; i_split < n_split; i_split++) {
        llama_split_path(split_path, sizeof(split_path), split_prefix, i_split, n_split);
        std::ifstream f_input(split_path, std::ios::binary);
        if (!f_input.is_open()) {
            fprintf(stderr, "%s: failed to open input GGUF from %s\n", __func__, split_path);
            // release every context collected in pass 1 before bailing out
            for (uint32_t i = 0; i < ctx_ggufs.size(); i++) {
                gguf_free(ctx_ggufs[i]);
                ggml_free(ctx_metas[i]);
            }
            gguf_free(ctx_out);
            fout.close();
            exit(EXIT_FAILURE);
        }
        fprintf(stderr, "%s: writing tensors %s ...", __func__, split_path);

        auto * ctx_gguf = ctx_ggufs[i_split];
        auto * ctx_meta = ctx_metas[i_split];

        auto n_tensors = gguf_get_n_tensors(ctx_gguf);
        for (int i_tensor = 0; i_tensor < n_tensors; i_tensor++) {
            const char * t_name = gguf_get_tensor_name(ctx_gguf, i_tensor);
            struct ggml_tensor * t = ggml_get_tensor(ctx_meta, t_name);

            auto n_bytes = ggml_nbytes(t);

            if (read_data.size() < n_bytes) {
                read_data.resize(n_bytes);
            }

            // tensor data location in the shard = data section start + per-tensor offset
            auto offset = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i_tensor);
            f_input.seekg(offset);
            f_input.read((char *)read_data.data(), n_bytes);

            // write tensor data + padding
            fout.write((const char *)read_data.data(), n_bytes);
            zeros(fout, GGML_PAD(n_bytes, GGUF_DEFAULT_ALIGNMENT) - n_bytes);
        }

        // this shard's contexts are no longer needed once its data is copied
        gguf_free(ctx_gguf);
        ggml_free(ctx_meta);
        f_input.close();
        fprintf(stderr, "\033[3Ddone\n");
    }

    {
        // go back to beginning of file and write the updated metadata
        fout.seekp(0);
        std::vector<uint8_t> data(gguf_get_meta_size(ctx_out));
        gguf_get_meta_data(ctx_out, data.data());
        fout.write((const char *)data.data(), data.size());

        fout.close();
        gguf_free(ctx_out);
    }

    fprintf(stderr, "%s: %s merged from %d split with %d tensors.\n",
            __func__, split_params.output.c_str(), n_split, total_tensors);
}
  374. int main(int argc, const char ** argv) {
  375. if (argc < 3) {
  376. split_print_usage(argv[0]);
  377. }
  378. split_params params;
  379. split_params_parse(argc, argv, params);
  380. switch (params.operation) {
  381. case SPLIT_OP_SPLIT: gguf_split(params);
  382. break;
  383. case SPLIT_OP_MERGE: gguf_merge(params);
  384. break;
  385. default: split_print_usage(argv[0]);
  386. exit(EXIT_FAILURE);
  387. }
  388. return 0;
  389. }