1
0

quantize.cpp 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. #include "ggml.h"
  2. #include "llama.h"
  3. #include "build-info.h"
  4. #include <cstdio>
  5. #include <map>
  6. #include <string>
  7. static const std::map<std::string, llama_ftype> LLAMA_FTYPE_MAP = {
  8. {"q4_0", LLAMA_FTYPE_MOSTLY_Q4_0},
  9. {"q4_1", LLAMA_FTYPE_MOSTLY_Q4_1},
  10. {"q5_0", LLAMA_FTYPE_MOSTLY_Q5_0},
  11. {"q5_1", LLAMA_FTYPE_MOSTLY_Q5_1},
  12. {"q8_0", LLAMA_FTYPE_MOSTLY_Q8_0},
  13. };
  14. bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::string & ftype_str_out) {
  15. auto it = LLAMA_FTYPE_MAP.find(ftype_str);
  16. if (it != LLAMA_FTYPE_MAP.end()) {
  17. ftype = it->second;
  18. ftype_str_out = it->first;
  19. return true;
  20. }
  21. // try to parse as an integer
  22. try {
  23. int ftype_int = std::stoi(ftype_str);
  24. for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) {
  25. if (it->second == ftype_int) {
  26. ftype = it->second;
  27. ftype_str_out = it->first;
  28. return true;
  29. }
  30. }
  31. }
  32. catch (...) {
  33. // stoi failed
  34. }
  35. return false;
  36. }
  37. // usage:
  38. // ./quantize models/llama/ggml-model.bin [models/llama/ggml-model-quant.bin] type [nthreads]
  39. //
  40. int main(int argc, char ** argv) {
  41. ggml_time_init();
  42. if (argc < 3) {
  43. fprintf(stderr, "usage: %s model-f32.bin [model-quant.bin] type [nthreads]\n", argv[0]);
  44. for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) {
  45. fprintf(stderr, " type = \"%s\" or %d\n", it->first.c_str(), it->second);
  46. }
  47. return 1;
  48. }
  49. // needed to initialize f16 tables
  50. {
  51. struct ggml_init_params params = { 0, NULL, false };
  52. struct ggml_context * ctx = ggml_init(params);
  53. ggml_free(ctx);
  54. }
  55. // parse command line arguments
  56. const std::string fname_inp = argv[1];
  57. std::string fname_out;
  58. int nthread;
  59. llama_ftype ftype;
  60. int arg_idx = 2;
  61. std::string ftype_str;
  62. if (try_parse_ftype(argv[arg_idx], ftype, ftype_str)) {
  63. // argv[2] is the ftype
  64. std::string fpath;
  65. const size_t pos = fname_inp.find_last_of('/');
  66. if (pos != std::string::npos) {
  67. fpath = fname_inp.substr(0, pos + 1);
  68. }
  69. // export as [inp path]/ggml-model-[ftype].bin
  70. fname_out = fpath + "ggml-model-" + ftype_str + ".bin";
  71. arg_idx++;
  72. }
  73. else {
  74. // argv[2] is the output path
  75. fname_out = argv[arg_idx];
  76. arg_idx++;
  77. if (argc <= arg_idx) {
  78. fprintf(stderr, "%s: missing ftype\n", __func__);
  79. return 1;
  80. }
  81. // argv[3] is the ftype
  82. if (!try_parse_ftype(argv[arg_idx], ftype, ftype_str)) {
  83. fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, argv[3]);
  84. return 1;
  85. }
  86. arg_idx++;
  87. }
  88. // parse nthreads
  89. if (argc > arg_idx) {
  90. try {
  91. nthread = std::stoi(argv[arg_idx]);
  92. }
  93. catch (const std::exception & e) {
  94. fprintf(stderr, "%s: invalid nthread '%s' (%s)\n", __func__, argv[arg_idx], e.what());
  95. return 1;
  96. }
  97. } else {
  98. nthread = 0;
  99. }
  100. fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
  101. fprintf(stderr, "%s: quantizing '%s' to '%s' as %s", __func__, fname_inp.c_str(), fname_out.c_str(), ftype_str.c_str());
  102. if (nthread > 0) {
  103. fprintf(stderr, " using %d threads", nthread);
  104. }
  105. fprintf(stderr, "\n");
  106. const int64_t t_main_start_us = ggml_time_us();
  107. int64_t t_quantize_us = 0;
  108. // load the model
  109. {
  110. const int64_t t_start_us = ggml_time_us();
  111. if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), ftype, nthread)) {
  112. fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
  113. return 1;
  114. }
  115. t_quantize_us = ggml_time_us() - t_start_us;
  116. }
  117. // report timing
  118. {
  119. const int64_t t_main_end_us = ggml_time_us();
  120. printf("\n");
  121. printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0);
  122. printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0);
  123. }
  124. return 0;
  125. }