quantize.cpp

#include "build-info.h"
#include "llama.h"

#include <cstdio>
#include <map>
#include <string>

static const std::map<std::string, llama_ftype> LLAMA_FTYPE_MAP = {
    {"q4_0", LLAMA_FTYPE_MOSTLY_Q4_0},
    {"q4_1", LLAMA_FTYPE_MOSTLY_Q4_1},
    {"q5_0", LLAMA_FTYPE_MOSTLY_Q5_0},
    {"q5_1", LLAMA_FTYPE_MOSTLY_Q5_1},
    {"q8_0", LLAMA_FTYPE_MOSTLY_Q8_0},
};

bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::string & ftype_str_out) {
    auto it = LLAMA_FTYPE_MAP.find(ftype_str);
    if (it != LLAMA_FTYPE_MAP.end()) {
        ftype = it->second;
        ftype_str_out = it->first;
        return true;
    }
    // try to parse as an integer
    try {
        int ftype_int = std::stoi(ftype_str);
        for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) {
            if (it->second == ftype_int) {
                ftype = it->second;
                ftype_str_out = it->first;
                return true;
            }
        }
    }
    catch (...) {
        // stoi failed
    }
    return false;
}
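
// for reference: try_parse_ftype accepts either a type name from LLAMA_FTYPE_MAP or the
// decimal value of the corresponding llama_ftype enum; either way ftype_str_out is set to
// the canonical name, e.g.
//
//   llama_ftype ftype;
//   std::string name;
//   try_parse_ftype("q5_1", ftype, name); // -> ftype == LLAMA_FTYPE_MOSTLY_Q5_1, name == "q5_1"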

// usage:
//  ./quantize models/llama/ggml-model.bin [models/llama/ggml-model-quant.bin] type [nthreads]
//
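// examples (illustrative file names, not part of the usage line above):
//  ./quantize models/llama/ggml-model-f16.bin q4_0 8
//      -> writes models/llama/ggml-model-q4_0.bin using 8 threads
//  ./quantize models/llama/ggml-model-f16.bin models/llama/custom-q5_1.bin q5_1
//      -> writes the explicitly named output file; nthreads defaults to 0
//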
int main(int argc, char ** argv) {
    if (argc < 3) {
        fprintf(stderr, "usage: %s model-f32.bin [model-quant.bin] type [nthreads]\n", argv[0]);
        for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) {
            fprintf(stderr, " type = \"%s\" or %d\n", it->first.c_str(), it->second);
        }
        return 1;
    }

    llama_init_backend();

    // parse command line arguments
    const std::string fname_inp = argv[1];
    std::string fname_out;
    int nthread;
    llama_ftype ftype;

    int arg_idx = 2;
    std::string ftype_str;
    if (try_parse_ftype(argv[arg_idx], ftype, ftype_str)) {
        // argv[2] is the ftype
        std::string fpath;
        const size_t pos = fname_inp.find_last_of('/');
        if (pos != std::string::npos) {
            fpath = fname_inp.substr(0, pos + 1);
        }
        // export as [inp path]/ggml-model-[ftype].bin
        fname_out = fpath + "ggml-model-" + ftype_str + ".bin";
        arg_idx++;
    }
    else {
        // argv[2] is the output path
        fname_out = argv[arg_idx];
        arg_idx++;

        if (argc <= arg_idx) {
            fprintf(stderr, "%s: missing ftype\n", __func__);
            return 1;
        }
        // argv[3] is the ftype
        if (!try_parse_ftype(argv[arg_idx], ftype, ftype_str)) {
            fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, argv[3]);
            return 1;
        }
        arg_idx++;
    }

    // parse nthreads
    if (argc > arg_idx) {
        try {
            nthread = std::stoi(argv[arg_idx]);
        }
        catch (const std::exception & e) {
            fprintf(stderr, "%s: invalid nthread '%s' (%s)\n", __func__, argv[arg_idx], e.what());
            return 1;
        }
    } else {
        nthread = 0;
    }

    fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);

    fprintf(stderr, "%s: quantizing '%s' to '%s' as %s", __func__, fname_inp.c_str(), fname_out.c_str(), ftype_str.c_str());
    if (nthread > 0) {
        fprintf(stderr, " using %d threads", nthread);
    }
    fprintf(stderr, "\n");

    const int64_t t_main_start_us = llama_time_us();

    int64_t t_quantize_us = 0;

    // quantize the model
    {
        const int64_t t_start_us = llama_time_us();

        if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), ftype, nthread)) {
            fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
            return 1;
        }

        t_quantize_us = llama_time_us() - t_start_us;
    }

    // report timing
    {
        const int64_t t_main_end_us = llama_time_us();

        printf("\n");
        printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0);
        printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0);
    }

    return 0;
}