gguf-split.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. #include "llama.h"
  2. #include "ggml.h"
  3. #include "common.h"
  4. #include <algorithm>
  5. #include <cmath>
  6. #include <cstdint>
  7. #include <cstdlib>
  8. #include <fstream>
  9. #include <ios>
  10. #include <string>
  11. #include <vector>
  12. #include <stdio.h>
  13. #include <fcntl.h>
  14. #include <unistd.h>
  15. #include <string.h>
  16. #include <unistd.h>
  // Operation selected on the command line (--split / --merge).
  enum split_operation : uint8_t {
      SPLIT_OP_SPLIT = 0, // --split: write the input as several GGUF files
      SPLIT_OP_MERGE = 1  // --merge: join split GGUF files into one
  };
  // GGUF metadata keys: index of this split and total number of splits.
  static constexpr const char * LLM_KV_GENERAL_SPLIT_I_SPLIT = "general.split";
  static constexpr const char * LLM_KV_GENERAL_SPLIT_N_SPLIT = "general.split_count";

  // Size of the stack buffer used to format a split file name.
  static constexpr int SPLIT_FILENAME_MAX = 256;
  // "<prefix>-<1-based index>-of-<count>.gguf", numbers zero-padded to at least 5 digits.
  static constexpr const char * SPLIT_FILENAME_FORMAT = "%s-%05d-of-%05d.gguf";
  25. struct split_params {
  26. split_operation operation = SPLIT_OP_SPLIT;
  27. int n_split_tensors = 128;
  28. std::string input;
  29. std::string output;
  30. };
  31. static void split_print_usage(const char * executable) {
  32. const split_params default_params;
  33. printf("\n");
  34. printf("usage: %s [options] GGUF_IN GGUF_OUT\n", executable);
  35. printf("\n");
  36. printf("Apply a GGUF operation on IN to OUT.");
  37. printf("\n");
  38. printf("options:\n");
  39. printf(" -h, --help show this help message and exit\n");
  40. printf(" --version show version and build info\n");
  41. printf(" --split split GGUF to multiple GGUF (default)\n");
  42. printf(" --split-max-tensors max tensors in each split: default(%d)\n", default_params.n_split_tensors);
  43. printf(" --merge merge multiple GGUF to a single GGUF\n");
  44. printf("\n");
  45. }
  46. static bool split_params_parse_ex(int argc, const char ** argv, split_params & params) {
  47. std::string arg;
  48. const std::string arg_prefix = "--";
  49. bool invalid_param = false;
  50. int arg_idx = 1;
  51. for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
  52. arg = argv[arg_idx];
  53. if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
  54. std::replace(arg.begin(), arg.end(), '_', '-');
  55. }
  56. bool arg_found = false;
  57. if (arg == "-h" || arg == "--help") {
  58. split_print_usage(argv[0]);
  59. exit(0);
  60. }
  61. if (arg == "--version") {
  62. fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
  63. fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET);
  64. exit(0);
  65. }
  66. if (arg == "--merge") {
  67. arg_found = true;
  68. params.operation = SPLIT_OP_MERGE;
  69. }
  70. if (arg == "--split") {
  71. arg_found = true;
  72. params.operation = SPLIT_OP_SPLIT;
  73. }
  74. if (arg == "--split-max-tensors") {
  75. if (++arg_idx >= argc) {
  76. invalid_param = true;
  77. break;
  78. }
  79. arg_found = true;
  80. params.n_split_tensors = atoi(argv[arg_idx]);
  81. }
  82. if (!arg_found) {
  83. throw std::invalid_argument("error: unknown argument: " + arg);
  84. }
  85. }
  86. if (invalid_param) {
  87. throw std::invalid_argument("error: invalid parameter for argument: " + arg);
  88. }
  89. if (argc - arg_idx < 2) {
  90. printf("%s: bad arguments\n", argv[0]);
  91. split_print_usage(argv[0]);
  92. return false;
  93. }
  94. params.input = argv[arg_idx++];
  95. params.output = argv[arg_idx++];
  96. return true;
  97. }
  98. static bool split_params_parse(int argc, const char ** argv, split_params & params) {
  99. bool result = true;
  100. try {
  101. if (!split_params_parse_ex(argc, argv, params)) {
  102. split_print_usage(argv[0]);
  103. exit(1);
  104. }
  105. }
  106. catch (const std::invalid_argument & ex) {
  107. fprintf(stderr, "%s\n", ex.what());
  108. split_print_usage(argv[0]);
  109. exit(1);
  110. }
  111. return result;
  112. }
  // Appends `n` zero bytes to `file`. It is used for the metadata placeholder
  // and to pad tensor data up to the GGUF alignment.
  static void zeros(std::ofstream & file, size_t n) {
      // emit the zeros in blocks rather than one write() call per byte
      static const char zero_block[256] = {};
      for (size_t remaining = n; remaining > 0; ) {
          const size_t count = remaining < sizeof(zero_block) ? remaining : sizeof(zero_block);
          file.write(zero_block, static_cast<std::streamsize>(count));
          remaining -= count;
      }
  }
  119. static std::string split_file_name(const std::string & path, int i_split, int n_split) {
  120. char f_split[SPLIT_FILENAME_MAX] = {0};
  121. snprintf(f_split, sizeof(f_split), SPLIT_FILENAME_FORMAT, path.c_str(), i_split + 1, n_split);
  122. return std::string(f_split);
  123. }
  124. struct split_strategy {
  125. const split_params params;
  126. std::ifstream & f_input;
  127. struct gguf_context * ctx_gguf;
  128. struct ggml_context * ctx_meta = NULL;
  129. const int n_tensors;
  130. const int n_split;
  131. int i_split = 0;
  132. int i_tensor = 0;
  133. std::vector<uint8_t> read_data;
  134. struct gguf_context * ctx_out;
  135. std::ofstream fout;
  136. split_strategy(const split_params & params,
  137. std::ifstream & f_input,
  138. struct gguf_context * ctx_gguf,
  139. struct ggml_context * ctx_meta) :
  140. params(params),
  141. f_input(f_input),
  142. ctx_gguf(ctx_gguf),
  143. ctx_meta(ctx_meta),
  144. n_tensors(gguf_get_n_tensors(ctx_gguf)),
  145. n_split(std::ceil(1. * n_tensors / params.n_split_tensors)) {
  146. }
  147. bool should_split() const {
  148. return i_tensor < n_tensors && i_tensor % params.n_split_tensors == 0;
  149. }
  150. void split_start() {
  151. ctx_out = gguf_init_empty();
  152. // Save all metadata in first split only
  153. if (i_split == 0) {
  154. gguf_set_kv(ctx_out, ctx_gguf);
  155. }
  156. gguf_set_val_u8(ctx_out, LLM_KV_GENERAL_SPLIT_I_SPLIT, i_split);
  157. gguf_set_val_u8(ctx_out, LLM_KV_GENERAL_SPLIT_N_SPLIT, n_split);
  158. // populate the original tensors, so we get an initial metadata
  159. for (int i = i_split * params.n_split_tensors; i < n_tensors && i < (i_split + 1) * params.n_split_tensors; ++i) {
  160. struct ggml_tensor * meta = ggml_get_tensor(ctx_meta, gguf_get_tensor_name(ctx_gguf, i));
  161. gguf_add_tensor(ctx_out, meta);
  162. }
  163. auto split_name = split_file_name(params.output, i_split, n_split);
  164. fprintf(stderr, "%s: %s ...", __func__, split_name.c_str());
  165. fout = std::ofstream(split_name, std::ios::binary);
  166. fout.exceptions(std::ofstream::failbit); // fail fast on write errors
  167. auto meta_size = gguf_get_meta_size(ctx_out);
  168. // placeholder for the meta data
  169. ::zeros(fout, meta_size);
  170. i_split++;
  171. }
  172. void next_tensor() {
  173. const char * t_name = gguf_get_tensor_name(ctx_gguf, i_tensor);
  174. struct ggml_tensor * t = ggml_get_tensor(ctx_meta, t_name);
  175. auto n_bytes = ggml_nbytes(t);
  176. if (read_data.size() < n_bytes) {
  177. read_data.resize(n_bytes);
  178. }
  179. auto offset = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i_tensor);
  180. f_input.seekg(offset);
  181. f_input.read((char *)read_data.data(), n_bytes);
  182. t->data = read_data.data();
  183. // write tensor data + padding
  184. fout.write((const char *)t->data, n_bytes);
  185. zeros(fout, GGML_PAD(n_bytes, GGUF_DEFAULT_ALIGNMENT) - n_bytes);
  186. i_tensor++;
  187. }
  188. void split_end() {
  189. // go back to beginning of file and write the updated metadata
  190. fout.seekp(0);
  191. std::vector<uint8_t> data(gguf_get_meta_size(ctx_out));
  192. gguf_get_meta_data(ctx_out, data.data());
  193. fout.write((const char *)data.data(), data.size());
  194. fout.close();
  195. gguf_free(ctx_out);
  196. fprintf(stderr, "\033[3Ddone\n");
  197. }
  198. };
  199. static void gguf_split(const split_params & split_params) {
  200. struct ggml_context * ctx_meta = NULL;
  201. struct gguf_init_params params = {
  202. /*.no_alloc = */ true,
  203. /*.ctx = */ &ctx_meta,
  204. };
  205. std::ifstream f_input(split_params.input.c_str(), std::ios::binary);
  206. if (!f_input.is_open()) {
  207. fprintf(stderr, "%s: failed to open input GGUF from %s\n", __func__, split_params.input.c_str());
  208. exit(1);
  209. }
  210. auto * ctx_gguf = gguf_init_from_file(split_params.input.c_str(), params);
  211. if (!ctx_gguf) {
  212. fprintf(stderr, "%s: failed to load input GGUF from %s\n", __func__, split_params.input.c_str());
  213. exit(1);
  214. }
  215. split_strategy strategy(split_params, f_input, ctx_gguf, ctx_meta);
  216. fprintf(stderr, "%s: %s -> %s (%d tensors per file)\n",
  217. __func__, split_params.input.c_str(),
  218. split_file_name(split_params.output, strategy.i_split, strategy.n_split).c_str(),
  219. split_params.n_split_tensors);
  220. strategy.split_start();
  221. while (strategy.i_tensor < strategy.n_tensors) {
  222. strategy.next_tensor();
  223. if (strategy.should_split()) {
  224. strategy.split_end();
  225. strategy.split_start();
  226. }
  227. }
  228. strategy.split_end();
  229. gguf_free(ctx_gguf);
  230. f_input.close();
  231. fprintf(stderr, "%s: %d gguf split written with a total of %d tensors.\n",
  232. __func__, strategy.n_split, strategy.n_tensors);
  233. }
  234. static void gguf_merge(const split_params & split_params) {
  235. fprintf(stderr, "%s: %s -> %s\n",
  236. __func__, split_params.input.c_str(),
  237. split_params.output.c_str());
  238. int n_split = 1;
  239. int total_tensors = 0;
  240. auto * ctx_out = gguf_init_empty();
  241. std::ofstream fout(split_params.output.c_str(), std::ios::binary);
  242. fout.exceptions(std::ofstream::failbit); // fail fast on write errors
  243. std::vector<uint8_t> read_data;
  244. std::vector<ggml_context *> ctx_metas;
  245. std::vector<gguf_context *> ctx_ggufs;
  246. std::string split_prefix;
  247. // First pass to find KV and tensors metadata
  248. for (int i_split = 0; i_split < n_split; i_split++) {
  249. struct ggml_context * ctx_meta = NULL;
  250. struct gguf_init_params params = {
  251. /*.no_alloc = */ true,
  252. /*.ctx = */ &ctx_meta,
  253. };
  254. auto split_name = split_params.input;
  255. if (i_split > 0) {
  256. split_name = split_file_name(split_prefix, i_split, n_split);
  257. }
  258. fprintf(stderr, "%s: reading metadata %s ...", __func__, split_name.c_str());
  259. auto * ctx_gguf = gguf_init_from_file(split_name.c_str(), params);
  260. if (!ctx_gguf) {
  261. fprintf(stderr, "\n%s: failed to load input GGUF from %s\n", __func__, split_params.input.c_str());
  262. exit(1);
  263. }
  264. ctx_ggufs.push_back(ctx_gguf);
  265. ctx_metas.push_back(ctx_meta);
  266. if (i_split == 0) {
  267. auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_GENERAL_SPLIT_N_SPLIT);
  268. if (key_n_split < 0) {
  269. fprintf(stderr,
  270. "\n%s: input file does not contain %s metadata\n",
  271. __func__,
  272. LLM_KV_GENERAL_SPLIT_N_SPLIT);
  273. gguf_free(ctx_gguf);
  274. gguf_free(ctx_out);
  275. fout.close();
  276. exit(1);
  277. }
  278. n_split = gguf_get_val_u8(ctx_gguf, key_n_split);
  279. if (n_split < 1) {
  280. fprintf(stderr,
  281. "\n%s: input file does not contain a valid split count %d\n",
  282. __func__,
  283. n_split);
  284. gguf_free(ctx_gguf);
  285. gguf_free(ctx_out);
  286. fout.close();
  287. exit(1);
  288. }
  289. // Do not trigger merge if we try to merge again the output
  290. gguf_set_val_u8(ctx_out, LLM_KV_GENERAL_SPLIT_N_SPLIT, 0);
  291. // Set metadata from the first split
  292. gguf_set_kv(ctx_out, ctx_gguf);
  293. }
  294. // Verify the file naming
  295. {
  296. int i_split_file = 0;
  297. int n_split_file = 0;
  298. const char * i_split_format = "-00000-of-00000.gguf";
  299. if (split_name.size() < strlen(i_split_format)) {
  300. fprintf(stderr, "\n%s: unexpected input file name: %s\n", __func__, split_params.input.c_str());
  301. for (auto * _ctx_gguf : ctx_ggufs) {
  302. gguf_free(_ctx_gguf);
  303. }
  304. gguf_free(ctx_out);
  305. fout.close();
  306. exit(1);
  307. }
  308. split_prefix = split_name.substr(0, split_name.size() - strlen(i_split_format));
  309. const char * split_name_c_str = split_name.c_str();
  310. int n_part = sscanf(&split_name_c_str[0] + split_prefix.size(), "-%d-of-%d", &i_split_file, &n_split_file);
  311. if (n_part != 2 || i_split_file - 1 != i_split || n_split_file != n_split) {
  312. fprintf(stderr, "\n%s: unexpected input file name: %s"
  313. " i_split=%d i_split_file=%d"
  314. " n_split=%d n_split_file=%d\n", __func__,
  315. split_params.input.c_str(),
  316. i_split, i_split_file,
  317. n_split, n_split_file);
  318. for (auto * _ctx_gguf : ctx_ggufs) {
  319. gguf_free(_ctx_gguf);
  320. }
  321. gguf_free(ctx_out);
  322. fout.close();
  323. exit(1);
  324. }
  325. }
  326. auto n_tensors = gguf_get_n_tensors(ctx_gguf);
  327. for (int i_tensor = 0; i_tensor < n_tensors; i_tensor++) {
  328. const char * t_name = gguf_get_tensor_name(ctx_gguf, i_tensor);
  329. struct ggml_tensor * t = ggml_get_tensor(ctx_meta, t_name);
  330. gguf_add_tensor(ctx_out, t);
  331. }
  332. total_tensors += n_tensors;
  333. fprintf(stderr, "\033[3Ddone\n");
  334. }
  335. // placeholder for the meta data
  336. {
  337. auto meta_size = gguf_get_meta_size(ctx_out);
  338. ::zeros(fout, meta_size);
  339. }
  340. // Write tensors data
  341. for (int i_split = 0; i_split < n_split; i_split++) {
  342. auto split_name = split_file_name(split_prefix, i_split, n_split);
  343. std::ifstream f_input(split_name.c_str(), std::ios::binary);
  344. if (!f_input.is_open()) {
  345. fprintf(stderr, "%s: failed to open input GGUF from %s\n", __func__, split_name.c_str());
  346. for (auto * _ctx_gguf : ctx_ggufs) {
  347. gguf_free(_ctx_gguf);
  348. }
  349. gguf_free(ctx_out);
  350. fout.close();
  351. exit(1);
  352. }
  353. fprintf(stderr, "%s: writing tensors %s ...", __func__, split_name.c_str());
  354. auto * ctx_gguf = ctx_ggufs[i_split];
  355. auto * ctx_meta = ctx_metas[i_split];
  356. auto n_tensors = gguf_get_n_tensors(ctx_gguf);
  357. for (int i_tensor = 0; i_tensor < n_tensors; i_tensor++) {
  358. const char * t_name = gguf_get_tensor_name(ctx_gguf, i_tensor);
  359. struct ggml_tensor * t = ggml_get_tensor(ctx_meta, t_name);
  360. auto n_bytes = ggml_nbytes(t);
  361. if (read_data.size() < n_bytes) {
  362. read_data.resize(n_bytes);
  363. }
  364. auto offset = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i_tensor);
  365. f_input.seekg(offset);
  366. f_input.read((char *)read_data.data(), n_bytes);
  367. // write tensor data + padding
  368. fout.write((const char *)read_data.data(), n_bytes);
  369. zeros(fout, GGML_PAD(n_bytes, GGUF_DEFAULT_ALIGNMENT) - n_bytes);
  370. }
  371. gguf_free(ctx_gguf);
  372. ggml_free(ctx_meta);
  373. f_input.close();
  374. fprintf(stderr, "\033[3Ddone\n");
  375. }
  376. {
  377. // go back to beginning of file and write the updated metadata
  378. fout.seekp(0);
  379. std::vector<uint8_t> data(gguf_get_meta_size(ctx_out));
  380. gguf_get_meta_data(ctx_out, data.data());
  381. fout.write((const char *)data.data(), data.size());
  382. fout.close();
  383. gguf_free(ctx_out);
  384. }
  385. fprintf(stderr, "%s: %s merged from %d split with %d tensors.\n",
  386. __func__, split_params.output.c_str(), n_split, total_tensors);
  387. }
  388. int main(int argc, const char ** argv) {
  389. if (argc < 3) {
  390. split_print_usage(argv[0]);
  391. }
  392. split_params params;
  393. split_params_parse(argc, argv, params);
  394. switch (params.operation) {
  395. case SPLIT_OP_SPLIT: gguf_split(params);
  396. break;
  397. case SPLIT_OP_MERGE: gguf_merge(params);
  398. break;
  399. default:split_print_usage(argv[0]);
  400. exit(1);
  401. }
  402. return 0;
  403. }