llama-quant.cpp 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934
  1. #include "llama-quant.h"
  2. #include "llama-impl.h"
  3. #include "llama-model.h"
  4. #include "llama-model-loader.h"
  5. #include <algorithm>
  6. #include <cmath>
  7. #include <cstring>
  8. #include <cinttypes>
  9. #include <fstream>
  10. #include <mutex>
  11. #include <thread>
  12. #include <unordered_map>
  // Appends exactly `n` zero bytes to `file` using unformatted write(); the
  // stream's own state/exception mask governs what happens on a write failure.
  // The bytes come from a zero-initialised static buffer, emitted in pieces of
  // at most sizeof(zero_block) bytes.
  static void zeros(std::ofstream & file, size_t n) {
      static const char zero_block[256] = {};
      size_t remaining = n;
      while (remaining > 0) {
          const size_t piece = std::min(remaining, sizeof(zero_block));
          file.write(zero_block, (std::streamsize) piece);
          remaining -= piece;
      }
  }
  19. struct quantize_state_impl {
  20. const llama_model & model;
  21. const llama_model_quantize_params * params;
  22. int n_attention_wv = 0;
  23. int n_ffn_down = 0;
  24. int n_ffn_gate = 0;
  25. int n_ffn_up = 0;
  26. int i_attention_wv = 0;
  27. int i_ffn_down = 0;
  28. int i_ffn_gate = 0;
  29. int i_ffn_up = 0;
  30. int n_k_quantized = 0;
  31. int n_fallback = 0;
  32. bool has_imatrix = false;
  33. // used to figure out if a model shares tok_embd with the output weight
  34. bool has_output = false;
  35. quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params)
  36. : model(model)
  37. , params(params)
  38. {}
  39. };
  40. static void llama_tensor_dequantize_impl(
  41. struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
  42. const size_t nelements, const int nthread
  43. ) {
  44. if (output.size() < nelements) {
  45. output.resize(nelements);
  46. }
  47. float * f32_output = (float *) output.data();
  48. const ggml_type_traits * qtype = ggml_get_type_traits(tensor->type);
  49. if (ggml_is_quantized(tensor->type)) {
  50. if (qtype->to_float == NULL) {
  51. throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type)));
  52. }
  53. } else if (tensor->type != GGML_TYPE_F16 &&
  54. tensor->type != GGML_TYPE_BF16) {
  55. throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type)));
  56. }
  57. if (nthread < 2) {
  58. if (tensor->type == GGML_TYPE_F16) {
  59. ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements);
  60. } else if (tensor->type == GGML_TYPE_BF16) {
  61. ggml_bf16_to_fp32_row((ggml_bf16_t *)tensor->data, f32_output, nelements);
  62. } else if (ggml_is_quantized(tensor->type)) {
  63. qtype->to_float(tensor->data, f32_output, nelements);
  64. } else {
  65. GGML_ABORT("fatal error"); // unreachable
  66. }
  67. return;
  68. }
  69. size_t block_size;
  70. if (tensor->type == GGML_TYPE_F16 ||
  71. tensor->type == GGML_TYPE_BF16) {
  72. block_size = 1;
  73. } else {
  74. block_size = (size_t)ggml_blck_size(tensor->type);
  75. }
  76. size_t block_size_bytes = ggml_type_size(tensor->type);
  77. GGML_ASSERT(nelements % block_size == 0);
  78. size_t nblocks = nelements / block_size;
  79. size_t blocks_per_thread = nblocks / nthread;
  80. size_t spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count
  81. size_t in_buff_offs = 0;
  82. size_t out_buff_offs = 0;
  83. for (int tnum = 0; tnum < nthread; tnum++) {
  84. size_t thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread
  85. size_t thr_elems = thr_blocks * block_size; // number of elements for this thread
  86. size_t thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread
  87. auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) {
  88. if (typ == GGML_TYPE_F16) {
  89. ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels);
  90. } else if (typ == GGML_TYPE_BF16) {
  91. ggml_bf16_to_fp32_row((ggml_bf16_t *)inbuf, outbuf, nels);
  92. } else {
  93. qtype->to_float(inbuf, outbuf, nels);
  94. }
  95. };
  96. workers.emplace_back(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems);
  97. in_buff_offs += thr_block_bytes;
  98. out_buff_offs += thr_elems;
  99. }
  100. for (auto & w : workers) { w.join(); }
  101. workers.clear();
  102. }
  // Chooses the ggml type used to store `tensor` when quantizing with file type `ftype`.
  // `new_type` is the caller's starting type (presumably the default type for `ftype` -
  // confirm at the call site); it is returned unchanged unless one of the name-based rules
  // below overrides it. For the output tensor and the token embeddings, an explicit user
  // override in qs.params (output_tensor_type / token_embedding_type < GGML_TYPE_COUNT) wins.
  // Side effects on `qs`: advances the running counters i_attention_wv (attn_v tensors),
  // i_ffn_down / i_ffn_gate / i_ffn_up, and increments either n_k_quantized or n_fallback.
  // If the chosen type's block size does not divide the row width (ne[0]), a fallback type
  // is substituted (F16 as a last resort).
  // Throws std::runtime_error when no fallback exists for the chosen type, or when a MoE
  // tensor name cannot be mapped to a layer in [0, n_layer).
  103. static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
  104. const std::string name = ggml_get_name(tensor);
  105. // TODO: avoid hardcoded tensor names - use the TN_* constants
  106. const llm_arch arch = qs.model.arch;
  107. const auto tn = LLM_TN(arch);
  // True for layers in the first eighth, the last eighth, and every third layer in between;
  // used to give those layers a higher-precision type.
  108. auto use_more_bits = [](int i_layer, int n_layers) -> bool {
  109. return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2;
  110. };
  111. const int n_expert = std::max(1, (int)qs.model.hparams.n_expert);
  // Returns (i_layer, n_layer). For dense models the running counter passed in is used
  // as-is; for MoE models (n_expert > 1) the layer index is parsed from the "blk.%d."
  // prefix of the tensor name and range-checked instead.
  112. auto layer_info = [n_expert] (int i_layer, int n_layer, const char * name) {
  113. if (n_expert > 1) {
  114. // Believe it or not, "experts" in the FFN of Mixtral-8x7B are not consecutive, but occasionally randomly
  115. // sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work
  116. // for getting the current layer as I initially thought, and we need to resort to parsing the
  117. // tensor name.
  118. if (sscanf(name, "blk.%d.", &i_layer) != 1) {
  119. throw std::runtime_error(format("Failed to determine layer for tensor %s", name));
  120. }
  121. if (i_layer < 0 || i_layer >= n_layer) {
  122. throw std::runtime_error(format("Bad layer %d for tensor %s. Must be in [0, %d)", i_layer, name, n_layer));
  123. }
  124. }
  125. return std::make_pair(i_layer, n_layer);
  126. };
  127. // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings
  128. // with the quantization of the output tensor
  // Without a user override: Falcon, or a row width not divisible by the starting type's
  // block size, gets Q8_0; the IQ1/IQ2/IQ3_XXS ftypes get Q5_K; everything else except Q8_0 gets Q6_K.
  129. if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) {
  130. if (qs.params->output_tensor_type < GGML_TYPE_COUNT) {
  131. new_type = qs.params->output_tensor_type;
  132. } else {
  133. const int64_t nx = tensor->ne[0];
  134. const int64_t qk_k = ggml_blck_size(new_type);
  135. if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) {
  136. new_type = GGML_TYPE_Q8_0;
  137. }
  138. else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
  139. ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ||
  140. ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
  141. new_type = GGML_TYPE_Q5_K;
  142. }
  143. else if (new_type != GGML_TYPE_Q8_0) {
  144. new_type = GGML_TYPE_Q6_K;
  145. }
  146. }
  // NOTE(review): only reached for token_embd when qs.has_output is true (otherwise the branch
  // above already matched it), assuming tn(LLM_TENSOR_TOKEN_EMBD, "weight") == "token_embd.weight".
  // The literal name here is one of the hardcoded names the TODO above refers to.
  147. } else if (name == "token_embd.weight") {
  148. if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
  149. new_type = qs.params->token_embedding_type;
  150. } else {
  151. if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS ||
  152. ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
  153. new_type = GGML_TYPE_Q2_K;
  154. }
  155. else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) {
  156. new_type = GGML_TYPE_IQ3_S;
  157. }
  158. else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
  159. new_type = GGML_TYPE_IQ3_S;
  160. }
  161. else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) {
  162. new_type = GGML_TYPE_Q4_K;
  163. }
  164. }
  // Lowest-bit ftypes (IQ1_*, IQ2_*): raise the type of attn_v, attn_k (8 experts),
  // the first eighth of ffn_down, and attn_output.
  165. } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
  166. ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
  167. if (name.find("attn_v.weight") != std::string::npos) {
  168. if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K;
  169. else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
  170. ++qs.i_attention_wv;
  171. }
  172. else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) {
  173. new_type = GGML_TYPE_Q4_K;
  174. }
  // Uses the running counter directly (no layer_info name parsing in this branch).
  175. else if (name.find("ffn_down") != std::string::npos) {
  176. if (qs.i_ffn_down < qs.n_ffn_down/8) {
  177. new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
  178. }
  179. ++qs.i_ffn_down;
  180. }
  181. else if (name.find("attn_output.weight") != std::string::npos) {
  182. if (qs.model.hparams.n_expert == 8) {
  183. new_type = GGML_TYPE_Q5_K;
  184. } else {
  185. if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_XXS;
  186. else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S;
  187. }
  188. }
  // All other ftypes: per-tensor rules keyed on substrings of the tensor name.
  // attn_v: most ftypes bump the type, with extra bumps for GQA >= 4, 70B models, and
  // 8-expert models (always Q8_0); i_attention_wv is advanced for every attn_v tensor.
  189. } else if (name.find("attn_v.weight") != std::string::npos) {
  190. if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
  191. new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
  192. }
  193. else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) {
  194. new_type = GGML_TYPE_Q4_K;
  195. }
  196. else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
  197. new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS;
  198. }
  199. else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S) && qs.model.hparams.n_gqa() >= 4) {
  200. new_type = GGML_TYPE_Q4_K;
  201. }
  202. else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
  203. new_type = GGML_TYPE_Q4_K;
  204. }
  205. else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
  206. new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
  207. }
  208. else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
  209. else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 4) {
  210. new_type = GGML_TYPE_Q5_K;
  211. }
  212. else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
  213. use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K;
  214. else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
  215. if (qs.model.type == LLM_TYPE_70B) {
  216. // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
  217. // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
  218. // nearly negligible increase in model size by quantizing this tensor with more bits:
  219. if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K;
  220. }
  221. if (qs.model.hparams.n_expert == 8) {
  222. // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
  223. // TODO: explore better strategies
  224. new_type = GGML_TYPE_Q8_0;
  225. }
  226. ++qs.i_attention_wv;
  // attn_k / attn_q: Q8_0 for 8-expert models (attn_k only); IQ3_XS/IQ3_XXS drop to a smaller IQ type.
  227. } else if (name.find("attn_k.weight") != std::string::npos) {
  228. if (qs.model.hparams.n_expert == 8) {
  229. // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
  230. // TODO: explore better strategies
  231. new_type = GGML_TYPE_Q8_0;
  232. }
  233. else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
  234. new_type = GGML_TYPE_IQ3_XXS;
  235. }
  236. else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
  237. new_type = GGML_TYPE_IQ2_S;
  238. }
  239. } else if (name.find("attn_q.weight") != std::string::npos) {
  240. if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
  241. new_type = GGML_TYPE_IQ3_XXS;
  242. }
  243. else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
  244. new_type = GGML_TYPE_IQ2_S;
  245. }
  // ffn_down: layer-position-dependent bumps (first 1/16 or 1/8 of layers, or use_more_bits);
  // the first matching ftype rule wins. i_ffn_down is advanced for every ffn_down tensor.
  246. } else if (name.find("ffn_down") != std::string::npos) {
  247. auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str());
  248. int i_layer = info.first, n_layer = info.second;
  249. if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
  250. else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) {
  251. if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K;
  252. }
  253. else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && !qs.has_imatrix) {
  254. new_type = i_layer < n_layer/8 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
  255. }
  // Parses as: first 1/16 -> Q5_K; else (non-Falcon || use_more_bits) ? Q4_K : Q3_K.
  256. else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
  257. new_type = i_layer < n_layer/16 ? GGML_TYPE_Q5_K
  258. : arch != LLM_ARCH_FALCON || use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q4_K
  259. : GGML_TYPE_Q3_K;
  260. }
  261. else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M && (i_layer < n_layer/8 ||
  262. (qs.model.hparams.n_expert == 8 && use_more_bits(i_layer, n_layer)))) {
  263. new_type = GGML_TYPE_Q4_K;
  264. }
  265. else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) {
  266. new_type = arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K;
  267. }
  268. else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
  269. if (arch == LLM_ARCH_FALCON) {
  270. new_type = i_layer < n_layer/16 ? GGML_TYPE_Q6_K :
  271. use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
  272. } else {
  273. if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
  274. }
  275. }
  276. else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && !qs.has_imatrix) {
  277. new_type = GGML_TYPE_Q5_K;
  278. }
  279. else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
  280. else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) {
  281. new_type = GGML_TYPE_Q5_K;
  282. }
  283. else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || ftype == LLAMA_FTYPE_MOSTLY_Q5_0)
  284. && qs.has_imatrix && i_layer < n_layer/8) {
  285. // Guard against craziness in the first few ffn_down layers that can happen even with imatrix for Q4_0/Q5_0.
  286. // We only do it when an imatrix is provided because a) we want to make sure that one can always get the
  287. // same quantization as before imatrix stuff, and b) Q4_1/Q5_1 do go crazy on ffn_down without an imatrix.
  288. new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1;
  289. }
  290. ++qs.i_ffn_down;
  // attn_output: bumps for 8-expert models, other non-Falcon models, and Falcon with Q3_K_L.
  291. } else if (name.find("attn_output.weight") != std::string::npos) {
  292. if (arch != LLM_ARCH_FALCON) {
  293. if (qs.model.hparams.n_expert == 8) {
  294. if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
  295. ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL ||
  296. ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
  297. ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) {
  298. new_type = GGML_TYPE_Q5_K;
  299. }
  300. } else {
  301. if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K;
  302. else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_IQ3_S;
  303. else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K;
  304. else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K;
  305. else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M ) new_type = GGML_TYPE_Q4_K;
  306. }
  307. } else {
  308. if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K;
  309. }
  310. }
  // NOTE(review): unlike attn_v, the fused attn_qkv branch does not advance i_attention_wv,
  // although attn_qkv tensors appear to be counted in n_attention_wv by the caller - confirm intent.
  311. else if (name.find("attn_qkv.weight") != std::string::npos) {
  312. if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
  313. new_type = GGML_TYPE_Q4_K;
  314. }
  315. else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
  316. else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
  317. }
  // ffn_gate / ffn_up: IQ3_XS uses IQ3_XXS for the middle layers (outside the first and last eighth).
  318. else if (name.find("ffn_gate") != std::string::npos) {
  319. auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str());
  320. int i_layer = info.first, n_layer = info.second;
  321. if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
  322. new_type = GGML_TYPE_IQ3_XXS;
  323. }
  324. ++qs.i_ffn_gate;
  325. }
  326. else if (name.find("ffn_up") != std::string::npos) {
  327. auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str());
  328. int i_layer = info.first, n_layer = info.second;
  329. if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
  330. new_type = GGML_TYPE_IQ3_XXS;
  331. }
  332. ++qs.i_ffn_up;
  333. }
  // NOTE(review): the next two commented lines look like the tail of a disabled rule whose
  // opening line is missing; everything in this commented section is inactive.
  334. // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
  335. //}
  336. // IK: let's remove this, else Q2_K is almost the same as Q3_K_S
  337. //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) {
  338. // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
  339. //}
  340. // This can be used to reduce the size of the Q5_K_S model.
  341. // The associated PPL increase is fully in line with the size reduction
  342. //else {
  343. // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
  344. //}
  // The chosen type can only be used if its block size divides the row width (ne[0]);
  // otherwise flag the tensor for a fallback, else count it in n_k_quantized.
  345. bool convert_incompatible_tensor = false;
  346. {
  347. const int64_t nx = tensor->ne[0];
  348. const int64_t ny = tensor->ne[1];
  349. const int64_t qk_k = ggml_blck_size(new_type);
  350. if (nx % qk_k != 0) {
  351. LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type));
  352. convert_incompatible_tensor = true;
  353. } else {
  354. ++qs.n_k_quantized;
  355. }
  356. }
  // Map the incompatible type to a fallback per the switch below; if the fallback's block
  // size still does not divide ne[0], use F16. Types without a mapping throw.
  357. if (convert_incompatible_tensor) {
  358. switch (new_type) {
  359. case GGML_TYPE_TQ1_0:
  360. case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead
  361. case GGML_TYPE_IQ2_XXS:
  362. case GGML_TYPE_IQ2_XS:
  363. case GGML_TYPE_IQ2_S:
  364. case GGML_TYPE_IQ3_XXS:
  365. case GGML_TYPE_IQ3_S:
  366. case GGML_TYPE_IQ1_S:
  367. case GGML_TYPE_IQ1_M:
  368. case GGML_TYPE_Q2_K:
  369. case GGML_TYPE_Q3_K:
  370. case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
  371. case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break;
  372. case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break;
  373. case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break;
  374. default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
  375. }
  376. if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
  377. new_type = GGML_TYPE_F16;
  378. }
  379. LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
  380. ++qs.n_fallback;
  381. }
  382. return new_type;
  383. }
  384. static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
  385. if (nthread < 2) {
  386. // single-thread
  387. size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix);
  388. if (!ggml_validate_row_data(new_type, new_data, new_size)) {
  389. throw std::runtime_error("quantized data validation failed");
  390. }
  391. return new_size;
  392. }
  393. std::mutex mutex;
  394. int64_t counter = 0;
  395. size_t new_size = 0;
  396. bool valid = true;
  397. auto compute = [&mutex, &counter, &new_size, &valid, new_type, f32_data, new_data, chunk_size,
  398. nrows, n_per_row, imatrix]() {
  399. const int64_t nrows_per_chunk = chunk_size / n_per_row;
  400. size_t local_size = 0;
  401. while (true) {
  402. std::unique_lock<std::mutex> lock(mutex);
  403. int64_t first_row = counter; counter += nrows_per_chunk;
  404. if (first_row >= nrows) {
  405. if (local_size > 0) {
  406. new_size += local_size;
  407. }
  408. break;
  409. }
  410. lock.unlock();
  411. const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk);
  412. size_t this_size = ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix);
  413. local_size += this_size;
  414. // validate the quantized data
  415. const size_t row_size = ggml_row_size(new_type, n_per_row);
  416. void * this_data = (char *) new_data + first_row * row_size;
  417. if (!ggml_validate_row_data(new_type, this_data, this_size)) {
  418. std::unique_lock<std::mutex> lock(mutex);
  419. valid = false;
  420. break;
  421. }
  422. }
  423. };
  424. for (int it = 0; it < nthread - 1; ++it) {
  425. workers.emplace_back(compute);
  426. }
  427. compute();
  428. for (auto & w : workers) { w.join(); }
  429. workers.clear();
  430. if (!valid) {
  431. throw std::runtime_error("quantized data validation failed");
  432. }
  433. return new_size;
  434. }
  435. static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
  436. ggml_type default_type;
  437. llama_ftype ftype = params->ftype;
  438. switch (params->ftype) {
  439. case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break;
  440. case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break;
  441. case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break;
  442. case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
  443. case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
  444. case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break;
  445. case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
  446. case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break;
  447. // K-quants
  448. case LLAMA_FTYPE_MOSTLY_Q2_K_S:
  449. case LLAMA_FTYPE_MOSTLY_Q2_K: default_type = GGML_TYPE_Q2_K; break;
  450. case LLAMA_FTYPE_MOSTLY_IQ3_XS: default_type = GGML_TYPE_IQ3_S; break;
  451. case LLAMA_FTYPE_MOSTLY_Q3_K_S:
  452. case LLAMA_FTYPE_MOSTLY_Q3_K_M:
  453. case LLAMA_FTYPE_MOSTLY_Q3_K_L: default_type = GGML_TYPE_Q3_K; break;
  454. case LLAMA_FTYPE_MOSTLY_Q4_K_S:
  455. case LLAMA_FTYPE_MOSTLY_Q4_K_M: default_type = GGML_TYPE_Q4_K; break;
  456. case LLAMA_FTYPE_MOSTLY_Q5_K_S:
  457. case LLAMA_FTYPE_MOSTLY_Q5_K_M: default_type = GGML_TYPE_Q5_K; break;
  458. case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break;
  459. case LLAMA_FTYPE_MOSTLY_TQ1_0: default_type = GGML_TYPE_TQ1_0; break;
  460. case LLAMA_FTYPE_MOSTLY_TQ2_0: default_type = GGML_TYPE_TQ2_0; break;
  461. case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break;
  462. case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break;
  463. case LLAMA_FTYPE_MOSTLY_IQ2_S: default_type = GGML_TYPE_IQ2_XS; break;
  464. case LLAMA_FTYPE_MOSTLY_IQ2_M: default_type = GGML_TYPE_IQ2_S; break;
  465. case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break;
  466. case LLAMA_FTYPE_MOSTLY_IQ1_S: default_type = GGML_TYPE_IQ1_S; break;
  467. case LLAMA_FTYPE_MOSTLY_IQ1_M: default_type = GGML_TYPE_IQ1_M; break;
  468. case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break;
  469. case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;
  470. case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break;
  471. case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break;
  472. default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
  473. }
  474. int nthread = params->nthread;
  475. if (nthread <= 0) {
  476. nthread = std::thread::hardware_concurrency();
  477. }
  478. // mmap consistently increases speed Linux, and also increases speed on Windows with
  479. // hot cache. It may cause a slowdown on macOS, possibly related to free memory.
  480. #if defined(__linux__) || defined(_WIN32)
  481. constexpr bool use_mmap = true;
  482. #else
  483. constexpr bool use_mmap = false;
  484. #endif
  485. llama_model_kv_override * kv_overrides = nullptr;
  486. if (params->kv_overrides) {
  487. auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
  488. kv_overrides = v->data();
  489. }
  490. std::vector<std::string> splits = {};
  491. llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides);
  492. ml.init_mappings(false); // no prefetching
  493. llama_model model(llama_model_default_params());
  494. model.load_arch (ml);
  495. model.load_hparams(ml);
  496. model.load_stats (ml);
  497. struct quantize_state_impl qs(model, params);
  498. if (params->only_copy) {
  499. ftype = ml.ftype;
  500. }
  501. const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr;
  502. if (params->imatrix) {
  503. imatrix_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
  504. if (imatrix_data) {
  505. LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size()));
  506. qs.has_imatrix = true;
  507. // check imatrix for nans or infs
  508. for (const auto & kv : *imatrix_data) {
  509. for (float f : kv.second) {
  510. if (!std::isfinite(f)) {
  511. throw std::runtime_error(format("imatrix contains non-finite value %f\n", f));
  512. }
  513. }
  514. }
  515. }
  516. }
  517. const size_t align = GGUF_DEFAULT_ALIGNMENT;
  518. gguf_context_ptr ctx_out { gguf_init_empty() };
  519. // copy the KV pairs from the input file
  520. gguf_set_kv (ctx_out.get(), ml.meta.get());
  521. gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
  522. gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV
  523. // Remove split metadata
  524. gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str());
  525. gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str());
  526. gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str());
  527. if (params->kv_overrides) {
  528. const std::vector<llama_model_kv_override> & overrides = *(const std::vector<llama_model_kv_override> *)params->kv_overrides;
  529. for (const auto & o : overrides) {
  530. if (o.key[0] == 0) break;
  531. if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
  532. gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64);
  533. } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
  534. gguf_set_val_i32(ctx_out.get(), o.key, o.val_i64);
  535. } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
  536. gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool);
  537. } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
  538. gguf_set_val_str(ctx_out.get(), o.key, o.val_str);
  539. } else {
  540. LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key);
  541. }
  542. }
  543. }
  544. // make a list of weights
  545. std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
  546. tensors.reserve(ml.weights_map.size());
  547. for (const auto & it : ml.weights_map) {
  548. tensors.push_back(&it.second);
  549. }
  550. // keep_split requires that the weights are sorted by split index
  551. if (params->keep_split) {
  552. std::sort(tensors.begin(), tensors.end(), [](const llama_model_loader::llama_tensor_weight * a, const llama_model_loader::llama_tensor_weight * b) {
  553. if (a->idx == b->idx) {
  554. return a->offs < b->offs;
  555. }
  556. return a->idx < b->idx;
  557. });
  558. }
  559. for (const auto * it : tensors) {
  560. const struct ggml_tensor * tensor = it->tensor;
  561. const std::string name = ggml_get_name(tensor);
  562. // TODO: avoid hardcoded tensor names - use the TN_* constants
  563. if (name.find("attn_v.weight") != std::string::npos ||
  564. name.find("attn_qkv.weight") != std::string::npos ||
  565. name.find("attn_kv_b.weight")!= std::string::npos) {
  566. ++qs.n_attention_wv;
  567. } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
  568. qs.has_output = true;
  569. }
  570. }
  571. qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
  572. // sanity checks for models that have attention layers
  573. if (qs.n_attention_wv != 0)
  574. {
  575. const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
  576. // attention layers have a non-zero number of kv heads
  577. int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
  578. if (llama_model_has_encoder(&model)) {
  579. n_attn_layer *= 3;
  580. }
  581. GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
  582. }
  583. size_t total_size_org = 0;
  584. size_t total_size_new = 0;
  585. std::vector<std::thread> workers;
  586. workers.reserve(nthread);
  587. int idx = 0;
  588. std::vector<no_init<uint8_t>> read_data;
  589. std::vector<no_init<uint8_t>> work;
  590. std::vector<no_init<float>> f32_conv_buf;
  591. uint16_t n_split = 1;
  592. // Assume split index is continuous
  593. if (params->keep_split) {
  594. for (const auto * it : tensors) {
  595. n_split = std::max(uint16_t(it->idx + 1), n_split);
  596. }
  597. }
  598. std::vector<gguf_context_ptr> ctx_outs(n_split);
  599. ctx_outs[0] = std::move(ctx_out);
  600. // populate the original tensors so we get an initial meta data
  601. for (const auto * it : tensors) {
  602. uint16_t i_split = params->keep_split ? it->idx : 0;
  603. struct ggml_tensor * tensor = it->tensor;
  604. if (!ctx_outs[i_split]) {
  605. ctx_outs[i_split].reset(gguf_init_empty());
  606. }
  607. gguf_add_tensor(ctx_outs[i_split].get(), tensor);
  608. }
  609. // Set split info if needed
  610. if (n_split > 1) {
  611. for (size_t i = 0; i < ctx_outs.size(); ++i) {
  612. gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
  613. gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
  614. gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors);
  615. }
  616. }
  617. int cur_split = -1;
  618. std::ofstream fout;
  619. auto close_ofstream = [&]() {
  620. // Write metadata and close file handler
  621. if (fout.is_open()) {
  622. fout.seekp(0);
  623. std::vector<uint8_t> data(gguf_get_meta_size(ctx_outs[cur_split].get()));
  624. gguf_get_meta_data(ctx_outs[cur_split].get(), data.data());
  625. fout.write((const char *) data.data(), data.size());
  626. fout.close();
  627. }
  628. };
  629. auto new_ofstream = [&](int index) {
  630. cur_split = index;
  631. GGML_ASSERT(ctx_outs[cur_split] && "Find uninitialized gguf_context");
  632. std::string fname = fname_out;
  633. if (params->keep_split) {
  634. std::vector<char> split_path(llama_path_max(), 0);
  635. llama_split_path(split_path.data(), split_path.size(), fname_out.c_str(), cur_split, n_split);
  636. fname = std::string(split_path.data());
  637. }
  638. fout = std::ofstream(fname, std::ios::binary);
  639. fout.exceptions(std::ofstream::failbit); // fail fast on write errors
  640. const size_t meta_size = gguf_get_meta_size(ctx_outs[cur_split].get());
  641. // placeholder for the meta data
  642. ::zeros(fout, meta_size);
  643. };
  644. const auto tn = LLM_TN(model.arch);
  645. new_ofstream(0);
  646. for (const auto * it : tensors) {
  647. const auto & weight = *it;
  648. struct ggml_tensor * tensor = weight.tensor;
  649. if (weight.idx != cur_split && params->keep_split) {
  650. close_ofstream();
  651. new_ofstream(weight.idx);
  652. }
  653. const std::string name = ggml_get_name(tensor);
  654. if (!ml.use_mmap) {
  655. if (read_data.size() < ggml_nbytes(tensor)) {
  656. read_data.resize(ggml_nbytes(tensor));
  657. }
  658. tensor->data = read_data.data();
  659. }
  660. ml.load_data_for(tensor);
  661. LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ",
  662. ++idx, ml.n_tensors,
  663. ggml_get_name(tensor),
  664. llama_format_tensor_shape(tensor).c_str(),
  665. ggml_type_name(tensor->type));
  666. // This used to be a regex, but <regex> has an extreme cost to compile times.
  667. bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
  668. // quantize only 2D and 3D tensors (experts)
  669. quantize &= (ggml_n_dims(tensor) >= 2);
  670. // do not quantize norm tensors
  671. quantize &= name.find("_norm.weight") == std::string::npos;
  672. quantize &= params->quantize_output_tensor || name != "output.weight";
  673. quantize &= !params->only_copy;
  674. // do not quantize expert gating tensors
  675. // NOTE: can't use LLM_TN here because the layer number is not known
  676. quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
  677. // do not quantize positional embeddings and token types (BERT)
  678. quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight");
  679. quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
  680. // do not quantize Mamba's small yet 2D weights
  681. // NOTE: can't use LLM_TN here because the layer number is not known
  682. quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
  683. // do not quantize RWKV's time_mix_first tensors
  684. quantize &= name.find("time_mix_first.weight") == std::string::npos;
  685. quantize &= name.find("time_mix_w1.weight") == std::string::npos;
  686. quantize &= name.find("time_mix_w2.weight") == std::string::npos;
  687. quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
  688. quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
  689. quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
  690. // do not quantize relative position bias (T5)
  691. quantize &= name.find("attn_rel_b.weight") == std::string::npos;
  692. enum ggml_type new_type;
  693. void * new_data;
  694. size_t new_size;
  695. if (quantize) {
  696. new_type = default_type;
  697. // get more optimal quantization type based on the tensor shape, layer, etc.
  698. if (!params->pure && ggml_is_quantized(default_type)) {
  699. new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
  700. }
  701. if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
  702. new_type = params->token_embedding_type;
  703. }
  704. if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) {
  705. new_type = params->output_tensor_type;
  706. }
  707. // If we've decided to quantize to the same type the tensor is already
  708. // in then there's nothing to do.
  709. quantize = tensor->type != new_type;
  710. }
  711. if (!quantize) {
  712. new_type = tensor->type;
  713. new_data = tensor->data;
  714. new_size = ggml_nbytes(tensor);
  715. LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0);
  716. } else {
  717. const int64_t nelements = ggml_nelements(tensor);
  718. const float * imatrix = nullptr;
  719. if (imatrix_data) {
  720. auto it = imatrix_data->find(tensor->name);
  721. if (it == imatrix_data->end()) {
  722. LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
  723. } else {
  724. if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) {
  725. imatrix = it->second.data();
  726. } else {
  727. LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__,
  728. int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name);
  729. // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix
  730. // this is a significant error and it may be good idea to abort the process if this happens,
  731. // since many people will miss the error and not realize that most of the model is being quantized without an imatrix
  732. // tok_embd should be ignored in this case, since it always causes this warning
  733. if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) {
  734. throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s",
  735. int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name));
  736. }
  737. }
  738. }
  739. }
  740. if ((new_type == GGML_TYPE_IQ2_XXS ||
  741. new_type == GGML_TYPE_IQ2_XS ||
  742. new_type == GGML_TYPE_IQ2_S ||
  743. new_type == GGML_TYPE_IQ1_S ||
  744. (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight")) ||
  745. (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) {
  746. LLAMA_LOG_ERROR("\n\n============================================================\n");
  747. LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name);
  748. LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n");
  749. LLAMA_LOG_ERROR("============================================================\n\n");
  750. throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name));
  751. }
  752. float * f32_data;
  753. if (tensor->type == GGML_TYPE_F32) {
  754. f32_data = (float *) tensor->data;
  755. } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) {
  756. throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type)));
  757. } else {
  758. llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread);
  759. f32_data = (float *) f32_conv_buf.data();
  760. }
  761. LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type));
  762. fflush(stdout);
  763. if (work.size() < (size_t)nelements * 4) {
  764. work.resize(nelements * 4); // upper bound on size
  765. }
  766. new_data = work.data();
  767. const int64_t n_per_row = tensor->ne[0];
  768. const int64_t nrows = tensor->ne[1];
  769. static const int64_t min_chunk_size = 32 * 512;
  770. const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row));
  771. const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1];
  772. const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
  773. const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1;
  774. // quantize each expert separately since they have different importance matrices
  775. new_size = 0;
  776. for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) {
  777. const float * f32_data_03 = f32_data + i03 * nelements_matrix;
  778. void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows;
  779. const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;
  780. new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
  781. }
  782. LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
  783. }
  784. total_size_org += ggml_nbytes(tensor);
  785. total_size_new += new_size;
  786. // update the gguf meta data as we go
  787. gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type);
  788. GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size);
  789. gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data);
  790. // write tensor data + padding
  791. fout.write((const char *) new_data, new_size);
  792. zeros(fout, GGML_PAD(new_size, align) - new_size);
  793. }
  794. close_ofstream();
  795. LLAMA_LOG_INFO("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
  796. LLAMA_LOG_INFO("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
  797. if (qs.n_fallback > 0) {
  798. LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n",
  799. __func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback);
  800. }
  801. }
  802. //
  803. // interface implementation
  804. //
  805. struct llama_model_quantize_params llama_model_quantize_default_params() {
  806. struct llama_model_quantize_params result = {
  807. /*.nthread =*/ 0,
  808. /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
  809. /*.output_tensor_type =*/ GGML_TYPE_COUNT,
  810. /*.token_embedding_type =*/ GGML_TYPE_COUNT,
  811. /*.allow_requantize =*/ false,
  812. /*.quantize_output_tensor =*/ true,
  813. /*.only_copy =*/ false,
  814. /*.pure =*/ false,
  815. /*.keep_split =*/ false,
  816. /*.imatrix =*/ nullptr,
  817. /*.kv_overrides =*/ nullptr,
  818. };
  819. return result;
  820. }
  821. uint32_t llama_model_quantize(
  822. const char * fname_inp,
  823. const char * fname_out,
  824. const llama_model_quantize_params * params) {
  825. try {
  826. llama_model_quantize_impl(fname_inp, fname_out, params);
  827. } catch (const std::exception & err) {
  828. LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what());
  829. return 1;
  830. }
  831. return 0;
  832. }