quantize.cpp

#include "build-info.h"
#include "llama.h"

#include <cctype>   // std::toupper
#include <cstdio>
#include <cstdlib>  // exit
#include <cstring>
#include <string>
#include <vector>
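
// A single quantization preset: the name accepted on the command line, the
// corresponding llama_ftype value, and a short description shown in the usage text.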
struct quant_option {
    std::string name;
    llama_ftype ftype;
    std::string desc;
};
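
// Supported quantization types. File sizes and perplexity (ppl) deltas are
// relative to a 7B LLaMA model; the K-quant entries are only compiled in
// when ggml is built with GGML_USE_K_QUANTS.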
static const std::vector<struct quant_option> QUANT_OPTIONS = {
    { "Q4_0",   LLAMA_FTYPE_MOSTLY_Q4_0,   " 3.50G, +0.2499 ppl @ 7B - small, very high quality loss - legacy, prefer using Q3_K_M", },
    { "Q4_1",   LLAMA_FTYPE_MOSTLY_Q4_1,   " 3.90G, +0.1846 ppl @ 7B - small, substantial quality loss - legacy, prefer using Q3_K_L", },
    { "Q5_0",   LLAMA_FTYPE_MOSTLY_Q5_0,   " 4.30G, +0.0796 ppl @ 7B - medium, balanced quality - legacy, prefer using Q4_K_M", },
    { "Q5_1",   LLAMA_FTYPE_MOSTLY_Q5_1,   " 4.70G, +0.0415 ppl @ 7B - medium, low quality loss - legacy, prefer using Q5_K_M", },
#ifdef GGML_USE_K_QUANTS
    { "Q2_K",   LLAMA_FTYPE_MOSTLY_Q2_K,   " 2.67G, +0.8698 ppl @ 7B - smallest, extreme quality loss - not recommended", },
    { "Q3_K",   LLAMA_FTYPE_MOSTLY_Q3_K_M, "alias for Q3_K_M", },
    { "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S, " 2.75G, +0.5505 ppl @ 7B - very small, very high quality loss", },
    { "Q3_K_M", LLAMA_FTYPE_MOSTLY_Q3_K_M, " 3.06G, +0.2437 ppl @ 7B - very small, very high quality loss", },
    { "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L, " 3.35G, +0.1803 ppl @ 7B - small, substantial quality loss", },
    { "Q4_K",   LLAMA_FTYPE_MOSTLY_Q4_K_M, "alias for Q4_K_M", },
    { "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S, " 3.56G, +0.1149 ppl @ 7B - small, significant quality loss", },
    { "Q4_K_M", LLAMA_FTYPE_MOSTLY_Q4_K_M, " 3.80G, +0.0535 ppl @ 7B - medium, balanced quality - *recommended*", },
    { "Q5_K",   LLAMA_FTYPE_MOSTLY_Q5_K_M, "alias for Q5_K_M", },
    { "Q5_K_S", LLAMA_FTYPE_MOSTLY_Q5_K_S, " 4.33G, +0.0353 ppl @ 7B - large, low quality loss - *recommended*", },
    { "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 4.45G, +0.0142 ppl @ 7B - large, very low quality loss - *recommended*", },
    { "Q6_K",   LLAMA_FTYPE_MOSTLY_Q6_K,   " 5.15G, +0.0044 ppl @ 7B - very large, extremely low quality loss", },
#endif
    { "Q8_0",   LLAMA_FTYPE_MOSTLY_Q8_0,   " 6.70G, +0.0004 ppl @ 7B - very large, extremely low quality loss - not recommended", },
    { "F16",    LLAMA_FTYPE_MOSTLY_F16,    "13.00G @ 7B - extremely large, virtually no quality loss - not recommended", },
    { "F32",    LLAMA_FTYPE_ALL_F32,       "26.00G @ 7B - absolutely huge, lossless - not recommended", },
};
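
// Map a user-supplied type string to a llama_ftype. Matching is
// case-insensitive, by name or by the numeric enum value.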
bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftype, std::string & ftype_str_out) {
    // uppercase the input so matching is case-insensitive
    std::string ftype_str;
    for (auto ch : ftype_str_in) {
        ftype_str.push_back(std::toupper(ch));
    }
    // first try to match by name
    for (auto & it : QUANT_OPTIONS) {
        if (it.name == ftype_str) {
            ftype = it.ftype;
            ftype_str_out = it.name;
            return true;
        }
    }
    // then try to match by numeric enum value
    try {
        int ftype_int = std::stoi(ftype_str);
        for (auto & it : QUANT_OPTIONS) {
            if (it.ftype == ftype_int) {
                ftype = it.ftype;
                ftype_str_out = it.name;
                return true;
            }
        }
    }
    catch (...) {
        // stoi failed
    }
    return false;
}
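
// For example, try_parse_ftype("q4_k_m", ftype, name) succeeds with
// ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M and name = "Q4_K_M"; passing the
// enum's numeric value as a string (e.g. "2" for Q4_0) works as well.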
// usage:
//  ./quantize [--allow-requantize] [--leave-output-tensor] models/llama/ggml-model.bin [models/llama/ggml-model-quant.bin] type [nthreads]
//
void usage(const char * executable) {
    fprintf(stderr, "usage: %s [--help] [--allow-requantize] [--leave-output-tensor] model-f32.bin [model-quant.bin] type [nthreads]\n\n", executable);
    fprintf(stderr, "  --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
    fprintf(stderr, "  --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
    fprintf(stderr, "\nAllowed quantization types:\n");
    for (auto & it : QUANT_OPTIONS) {
        // print to stderr like the rest of the usage text, so the output is not interleaved
        fprintf(stderr, "  %2d or %-6s : %s\n", it.ftype, it.name.c_str(), it.desc.c_str());
    }
    exit(1);
}
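
// main: parse optional "--" flags, then the positional arguments
// (input model, optional output path, quantization type, optional
// thread count), and hand off to llama_model_quantize().
// For example, "./quantize models/llama/ggml-model-f16.bin Q4_K_M"
// writes models/llama/ggml-model-Q4_K_M.bin.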
int main(int argc, char ** argv) {
    if (argc < 3) {
        usage(argv[0]);
    }

    llama_model_quantize_params params = llama_model_quantize_default_params();

    // consume any leading "--" flags
    int arg_idx = 1;
    for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
        if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
            params.quantize_output_tensor = false;
        } else if (strcmp(argv[arg_idx], "--allow-requantize") == 0) {
            params.allow_requantize = true;
        } else {
            usage(argv[0]);
        }
    }

    // need at least the input model and the quantization type;
    // the output filename is optional, per the usage text above
    if (argc - arg_idx < 2) {
        usage(argv[0]);
    }

    llama_backend_init(false); // false: no NUMA optimizations

    // parse command line arguments
    const std::string fname_inp = argv[arg_idx];
    arg_idx++;

    std::string fname_out;
    std::string ftype_str;
    if (try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) {
        // the output name was omitted: derive it from the input path
        std::string fpath;
        const size_t pos = fname_inp.find_last_of('/');
        if (pos != std::string::npos) {
            fpath = fname_inp.substr(0, pos + 1);
        }
        // export as [inp path]/ggml-model-[ftype].bin
        fname_out = fpath + "ggml-model-" + ftype_str + ".bin";
        arg_idx++;
    }
    else {
        fname_out = argv[arg_idx];
        arg_idx++;

        if (argc <= arg_idx) {
            fprintf(stderr, "%s: missing ftype\n", __func__);
            return 1;
        }
        if (!try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) {
            fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, argv[arg_idx]);
            return 1;
        }
        arg_idx++;
    }

    // parse nthreads; if not given, params.nthread keeps the library
    // default (0 = let llama.cpp choose the thread count)
    if (argc > arg_idx) {
        try {
            params.nthread = std::stoi(argv[arg_idx]);
        }
        catch (const std::exception & e) {
            fprintf(stderr, "%s: invalid nthread '%s' (%s)\n", __func__, argv[arg_idx], e.what());
            return 1;
        }
    }

    fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
    fprintf(stderr, "%s: quantizing '%s' to '%s' as %s", __func__, fname_inp.c_str(), fname_out.c_str(), ftype_str.c_str());
    if (params.nthread > 0) {
        fprintf(stderr, " using %d threads", params.nthread);
    }
    fprintf(stderr, "\n");

    const int64_t t_main_start_us = llama_time_us();

    int64_t t_quantize_us = 0;

    // quantize the model
    {
        const int64_t t_start_us = llama_time_us();

        if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), &params)) {
            fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
            return 1;
        }

        t_quantize_us = llama_time_us() - t_start_us;
    }

    // report timing
    {
        const int64_t t_main_end_us = llama_time_us();

        printf("\n");
        printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0);
        printf("%s: total time    = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0);
    }

    llama_backend_free();

    return 0;
}