llama-quant.cpp

#include "llama-quant.h"

#include "llama-impl.h"
#include "llama-model.h"
#include "llama-model-loader.h"

#include <algorithm>
#include <cmath>
#include <cstring>
#include <cinttypes>
#include <fstream>
#include <mutex>
#include <regex>
#include <thread>
#include <unordered_map>

// Quantization types. Changes to this struct must be replicated in quantize.cpp
struct tensor_quantization {
    std::string name;
    ggml_type quant = GGML_TYPE_COUNT;
};

static void zeros(std::ofstream & file, size_t n) {
    char zero = 0;
    for (size_t i = 0; i < n; ++i) {
        file.write(&zero, 1);
    }
}
static std::string remap_layer(const std::string & orig_name, const std::vector<int> & prune, std::map<int, std::string> & mapped, int & next_id) {
    if (prune.empty()) {
        return orig_name;
    }

    static const std::regex pattern(R"(blk\.(\d+)\.)");
    if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
        const int blk = std::stoi(match[1]);
        std::string new_name = orig_name;

        if (mapped.count(blk)) {
            // Already mapped, do nothing
        } else if (std::find(prune.begin(), prune.end(), blk) != prune.end()) {
            mapped[blk] = "";
        } else if (blk < prune.front()) {
            mapped[blk] = std::to_string(blk);
            next_id = blk + 1;
        } else {
            mapped[blk] = std::to_string(next_id);
            ++next_id;
        }

        return mapped[blk].empty() ? mapped[blk] : new_name.replace(match.position(1), match.length(1), mapped[blk]);
    }

    return orig_name;
}
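
// Illustrative note (added for clarity, not upstream code): with prune = {2} on a 4-block model,
// remap_layer() leaves blk.0 and blk.1 untouched, returns an empty string for the pruned blk.2
// (so the caller drops the tensor), and renames blk.3 to blk.2, keeping the surviving blocks
// contiguous. The mapping it records in `mapped` is reused by remap_imatrix() below.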
static std::string remap_imatrix (const std::string & orig_name, const std::map<int, std::string> & mapped) {
    if (mapped.empty()) {
        return orig_name;
    }

    static const std::regex pattern(R"(blk\.(\d+)\.)");
    if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
        const std::string blk(match[1]);
        std::string new_name = orig_name;

        for (const auto & p : mapped) {
            if (p.second == blk) {
                LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first);
                return new_name.replace(match.position(1), match.length(1), std::to_string(p.first));
            }
        }
        GGML_ABORT("\n%s: imatrix mapping error for %s\n", __func__, orig_name.c_str());
    }

    return orig_name;
}
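
// Illustrative note (added for clarity, not upstream code): remap_imatrix() is the inverse lookup of
// the table built by remap_layer(). After pruning, tensors carry the new block indices, while the
// imatrix data is still keyed by the original ones, so the renamed block index is translated back to
// the original block number before the imatrix entry is looked up.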
struct quantize_state_impl {
    const llama_model & model;
    const llama_model_quantize_params * params;

    int n_attention_wv = 0;
    int n_ffn_down     = 0;
    int n_ffn_gate     = 0;
    int n_ffn_up       = 0;
    int i_attention_wv = 0;
    int i_ffn_down     = 0;
    int i_ffn_gate     = 0;
    int i_ffn_up       = 0;

    int n_k_quantized = 0;
    int n_fallback    = 0;

    bool has_imatrix = false;

    // used to figure out if a model shares tok_embd with the output weight
    bool has_output = false;

    quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params)
        : model(model)
        , params(params)
        {}
};
static void llama_tensor_dequantize_impl(
    ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
    const size_t nelements, const int nthread
) {
    if (output.size() < nelements) {
        output.resize(nelements);
    }
    float * f32_output = (float *) output.data();

    const ggml_type_traits * qtype = ggml_get_type_traits(tensor->type);
    if (ggml_is_quantized(tensor->type)) {
        if (qtype->to_float == NULL) {
            throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type)));
        }
    } else if (tensor->type != GGML_TYPE_F16 &&
               tensor->type != GGML_TYPE_BF16) {
        throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type)));
    }

    if (nthread < 2) {
        if (tensor->type == GGML_TYPE_F16) {
            ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements);
        } else if (tensor->type == GGML_TYPE_BF16) {
            ggml_bf16_to_fp32_row((ggml_bf16_t *)tensor->data, f32_output, nelements);
        } else if (ggml_is_quantized(tensor->type)) {
            qtype->to_float(tensor->data, f32_output, nelements);
        } else {
            GGML_ABORT("fatal error"); // unreachable
        }
        return;
    }

    size_t block_size;
    if (tensor->type == GGML_TYPE_F16 ||
        tensor->type == GGML_TYPE_BF16) {
        block_size = 1;
    } else {
        block_size = (size_t)ggml_blck_size(tensor->type);
    }

    size_t block_size_bytes = ggml_type_size(tensor->type);

    GGML_ASSERT(nelements % block_size == 0);
    size_t nblocks = nelements / block_size;
    size_t blocks_per_thread = nblocks / nthread;
    size_t spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count

    size_t in_buff_offs = 0;
    size_t out_buff_offs = 0;

    for (int tnum = 0; tnum < nthread; tnum++) {
        size_t thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread
        size_t thr_elems = thr_blocks * block_size; // number of elements for this thread
        size_t thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread

        auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) {
            if (typ == GGML_TYPE_F16) {
                ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels);
            } else if (typ == GGML_TYPE_BF16) {
                ggml_bf16_to_fp32_row((ggml_bf16_t *)inbuf, outbuf, nels);
            } else {
                qtype->to_float(inbuf, outbuf, nels);
            }
        };
        workers.emplace_back(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems);
        in_buff_offs += thr_block_bytes;
        out_buff_offs += thr_elems;
    }
    for (auto & w : workers) { w.join(); }
    workers.clear();
}
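
// Illustrative note (added for clarity, not upstream code): the loop above hands each thread
// floor(nblocks / nthread) blocks and gives the remainder to the last thread. For example, with
// nblocks = 10 and nthread = 3 the threads dequantize 3, 3 and 4 blocks respectively, each reading
// a consecutive byte range of tensor->data and writing a consecutive range of f32_output.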
static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
    const std::string name = ggml_get_name(tensor);

    // TODO: avoid hardcoded tensor names - use the TN_* constants
    const llm_arch arch = qs.model.arch;
    const auto       tn = LLM_TN(arch);

    auto use_more_bits = [](int i_layer, int n_layers) -> bool {
        return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2;
    };
    const int n_expert = std::max(1, (int)qs.model.hparams.n_expert);
    auto layer_info = [n_expert] (int i_layer, int n_layer, const char * name) {
        if (n_expert > 1) {
            // Believe it or not, "experts" in the FFN of Mixtral-8x7B are not consecutive, but occasionally randomly
            // sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work
            // for getting the current layer as I initially thought, and we need to resort to parsing the
            // tensor name.
            if (sscanf(name, "blk.%d.", &i_layer) != 1) {
                throw std::runtime_error(format("Failed to determine layer for tensor %s", name));
            }
            if (i_layer < 0 || i_layer >= n_layer) {
                throw std::runtime_error(format("Bad layer %d for tensor %s. Must be in [0, %d)", i_layer, name, n_layer));
            }
        }
        return std::make_pair(i_layer, n_layer);
    };

    // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings
    // with the quantization of the output tensor
    if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) {
        if (qs.params->output_tensor_type < GGML_TYPE_COUNT) {
            new_type = qs.params->output_tensor_type;
        } else {
            const int64_t nx = tensor->ne[0];
            const int64_t qk_k = ggml_blck_size(new_type);

            if (ftype == LLAMA_FTYPE_MOSTLY_MXFP4_MOE) {
                new_type = GGML_TYPE_Q8_0;
            }
            else if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) {
                new_type = GGML_TYPE_Q8_0;
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
                     ftype == LLAMA_FTYPE_MOSTLY_IQ1_S   || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S  || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M   ||
                     ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
                new_type = GGML_TYPE_Q5_K;
            }
            else if (new_type != GGML_TYPE_Q8_0) {
                new_type = GGML_TYPE_Q6_K;
            }
        }
    } else if (ftype == LLAMA_FTYPE_MOSTLY_MXFP4_MOE) {
        // MoE tensors -> MXFP4
        // other tensors -> Q8_0
        if (tensor->ne[2] > 1) {
            new_type = GGML_TYPE_MXFP4;
        } else {
            new_type = GGML_TYPE_Q8_0;
        }
    } else if (name == "token_embd.weight" || name == "per_layer_token_embd.weight") {
        if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
            new_type = qs.params->token_embedding_type;
        } else {
            if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS ||
                ftype == LLAMA_FTYPE_MOSTLY_IQ1_S   || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
                new_type = GGML_TYPE_Q2_K;
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) {
                new_type = GGML_TYPE_IQ3_S;
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
                new_type = GGML_TYPE_IQ3_S;
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) {
                new_type = GGML_TYPE_Q4_K;
            }
        }
    } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
               ftype == LLAMA_FTYPE_MOSTLY_IQ2_S   || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
        if (name.find("attn_v.weight") != std::string::npos) {
            if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K;
            else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
            ++qs.i_attention_wv;
        }
        else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (name.find("ffn_down") != std::string::npos) {
            if (qs.i_ffn_down < qs.n_ffn_down/8) {
                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
            }
            ++qs.i_ffn_down;
        }
        else if (name.find("attn_output.weight") != std::string::npos) {
            if (qs.model.hparams.n_expert == 8) {
                new_type = GGML_TYPE_Q5_K;
            } else {
                if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_XXS;
                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S;
            }
        }
    } else if (name.find("attn_v.weight") != std::string::npos) {
        if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
            new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
            new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS;
        }
        else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S) && qs.model.hparams.n_gqa() >= 4) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
            new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
        else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 4) {
            new_type = GGML_TYPE_Q5_K;
        }
        else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
                 use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K;
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
        if (qs.model.type == LLM_TYPE_70B) {
            // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
            // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
            // nearly negligible increase in model size by quantizing this tensor with more bits:
            if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K;
        }
        if (qs.model.hparams.n_expert == 8) {
            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
            // TODO: explore better strategies
            new_type = GGML_TYPE_Q8_0;
        }
        ++qs.i_attention_wv;
    } else if (name.find("attn_k.weight") != std::string::npos) {
        if (qs.model.hparams.n_expert == 8) {
            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
            // TODO: explore better strategies
            new_type = GGML_TYPE_Q8_0;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
            new_type = GGML_TYPE_IQ3_XXS;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
            new_type = GGML_TYPE_IQ2_S;
        }
    } else if (name.find("attn_q.weight") != std::string::npos) {
        if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
            new_type = GGML_TYPE_IQ3_XXS;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
            new_type = GGML_TYPE_IQ2_S;
        }
    } else if (name.find("ffn_down") != std::string::npos) {
        auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str());
        int i_layer = info.first, n_layer = info.second;
        if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) {
            if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && !qs.has_imatrix) {
            new_type = i_layer < n_layer/8 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
            new_type = i_layer < n_layer/16 ? GGML_TYPE_Q5_K
                     : arch != LLM_ARCH_FALCON || use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q4_K
                     : GGML_TYPE_Q3_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M && (i_layer < n_layer/8 ||
                    (qs.model.hparams.n_expert == 8 && use_more_bits(i_layer, n_layer)))) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) {
            new_type = arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
            if (arch == LLM_ARCH_FALCON) {
                new_type = i_layer < n_layer/16 ? GGML_TYPE_Q6_K :
                           use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
            } else {
                if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
            }
        }
        else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && !qs.has_imatrix) {
            new_type = GGML_TYPE_Q5_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) {
            new_type = GGML_TYPE_Q5_K;
        }
        else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || ftype == LLAMA_FTYPE_MOSTLY_Q5_0)
                && qs.has_imatrix && i_layer < n_layer/8) {
            // Guard against craziness in the first few ffn_down layers that can happen even with imatrix for Q4_0/Q5_0.
            // We only do it when an imatrix is provided because a) we want to make sure that one can always get the
            // same quantization as before imatrix stuff, and b) Q4_1/Q5_1 do go crazy on ffn_down without an imatrix.
            new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1;
        }
        ++qs.i_ffn_down;
    } else if (name.find("attn_output.weight") != std::string::npos) {
        if (arch != LLM_ARCH_FALCON) {
            if (qs.model.hparams.n_expert == 8) {
                if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K   || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
                    ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL  ||
                    ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S   ||
                    ftype == LLAMA_FTYPE_MOSTLY_IQ3_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) {
                    new_type = GGML_TYPE_Q5_K;
                }
            } else {
                if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K   ) new_type = GGML_TYPE_Q3_K;
                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_IQ3_S;
                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K;
                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K;
                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M  ) new_type = GGML_TYPE_Q4_K;
            }
        } else {
            if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K;
        }
    }
    else if (name.find("attn_qkv.weight") != std::string::npos) {
        if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
    }
    else if (name.find("ffn_gate") != std::string::npos) {
        auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str());
        int i_layer = info.first, n_layer = info.second;
        if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
            new_type = GGML_TYPE_IQ3_XXS;
        }
        ++qs.i_ffn_gate;
    }
    else if (name.find("ffn_up") != std::string::npos) {
        auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str());
        int i_layer = info.first, n_layer = info.second;
        if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
            new_type = GGML_TYPE_IQ3_XXS;
        }
        ++qs.i_ffn_up;
    }

    //    if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
    //}
    // IK: let's remove this, else Q2_K is almost the same as Q3_K_S
    //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) {
    //    if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
    //}
    // This can be used to reduce the size of the Q5_K_S model.
    // The associated PPL increase is fully in line with the size reduction
    //else {
    //    if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
    //}
    bool convert_incompatible_tensor = false;
    {
        const int64_t nx = tensor->ne[0];
        const int64_t ny = tensor->ne[1];
        const int64_t qk_k = ggml_blck_size(new_type);

        if (nx % qk_k != 0) {
            LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type));
            convert_incompatible_tensor = true;
        } else {
            ++qs.n_k_quantized;
        }
    }

    if (convert_incompatible_tensor) {
        switch (new_type) {
            case GGML_TYPE_TQ1_0:
            case GGML_TYPE_TQ2_0:  new_type = GGML_TYPE_Q4_0; break;  // TODO: use a symmetric type instead
            case GGML_TYPE_IQ2_XXS:
            case GGML_TYPE_IQ2_XS:
            case GGML_TYPE_IQ2_S:
            case GGML_TYPE_IQ3_XXS:
            case GGML_TYPE_IQ3_S:
            case GGML_TYPE_IQ1_S:
            case GGML_TYPE_IQ1_M:
            case GGML_TYPE_Q2_K:
            case GGML_TYPE_Q3_K:
            case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
            case GGML_TYPE_Q4_K:   new_type = GGML_TYPE_Q5_0;   break;
            case GGML_TYPE_Q5_K:   new_type = GGML_TYPE_Q5_1;   break;
            case GGML_TYPE_Q6_K:   new_type = GGML_TYPE_Q8_0;   break;
            default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
        }
        if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
            new_type = GGML_TYPE_F16;
        }
        LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
        ++qs.n_fallback;
    }

    return new_type;
}
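
// Illustrative note (added for clarity, not upstream code): k-quants need the row size ne[0] to be a
// multiple of their block size (256). A tensor with ne[0] = 1440, for example, cannot be stored as
// Q4_K (1440 % 256 != 0), so the switch above drops it to a type with a 32-wide block (Q5_0 in that
// case); if even that block size does not divide the row, the tensor is kept as F16.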
static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
    if (nthread < 2) {
        // single-thread
        size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix);
        if (!ggml_validate_row_data(new_type, new_data, new_size)) {
            throw std::runtime_error("quantized data validation failed");
        }
        return new_size;
    }

    std::mutex mutex;
    int64_t counter = 0;
    size_t new_size = 0;
    bool valid = true;
    auto compute = [&mutex, &counter, &new_size, &valid, new_type, f32_data, new_data, chunk_size,
            nrows, n_per_row, imatrix]() {
        const int64_t nrows_per_chunk = chunk_size / n_per_row;
        size_t local_size = 0;
        while (true) {
            std::unique_lock<std::mutex> lock(mutex);
            int64_t first_row = counter; counter += nrows_per_chunk;
            if (first_row >= nrows) {
                if (local_size > 0) {
                    new_size += local_size;
                }
                break;
            }
            lock.unlock();
            const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk);
            size_t this_size = ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix);
            local_size += this_size;

            // validate the quantized data
            const size_t row_size = ggml_row_size(new_type, n_per_row);
            void * this_data = (char *) new_data + first_row * row_size;
            if (!ggml_validate_row_data(new_type, this_data, this_size)) {
                std::unique_lock<std::mutex> lock(mutex);
                valid = false;
                break;
            }
        }
    };
    for (int it = 0; it < nthread - 1; ++it) {
        workers.emplace_back(compute);
    }
    compute();
    for (auto & w : workers) { w.join(); }
    workers.clear();
    if (!valid) {
        throw std::runtime_error("quantized data validation failed");
    }
    return new_size;
}
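
// Illustrative note (added for clarity, not upstream code): the workers above pull work from a shared
// row counter guarded by the mutex, so chunks are claimed dynamically rather than pre-partitioned.
// The calling thread also runs compute(), which is why only nthread - 1 extra std::thread objects are
// spawned; each worker accumulates its byte count locally and folds it into new_size when it exits.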
static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
    ggml_type default_type;
    llama_ftype ftype = params->ftype;

    switch (params->ftype) {
        case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break;
        case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break;
        case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break;
        case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
        case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
        case LLAMA_FTYPE_MOSTLY_F16:  default_type = GGML_TYPE_F16;  break;
        case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
        case LLAMA_FTYPE_ALL_F32:     default_type = GGML_TYPE_F32;  break;

        case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: default_type = GGML_TYPE_MXFP4; break;

        // K-quants
        case LLAMA_FTYPE_MOSTLY_Q2_K_S:
        case LLAMA_FTYPE_MOSTLY_Q2_K:    default_type = GGML_TYPE_Q2_K;    break;
        case LLAMA_FTYPE_MOSTLY_IQ3_XS:  default_type = GGML_TYPE_IQ3_S;   break;
        case LLAMA_FTYPE_MOSTLY_Q3_K_S:
        case LLAMA_FTYPE_MOSTLY_Q3_K_M:
        case LLAMA_FTYPE_MOSTLY_Q3_K_L:  default_type = GGML_TYPE_Q3_K;    break;
        case LLAMA_FTYPE_MOSTLY_Q4_K_S:
        case LLAMA_FTYPE_MOSTLY_Q4_K_M:  default_type = GGML_TYPE_Q4_K;    break;
        case LLAMA_FTYPE_MOSTLY_Q5_K_S:
        case LLAMA_FTYPE_MOSTLY_Q5_K_M:  default_type = GGML_TYPE_Q5_K;    break;
        case LLAMA_FTYPE_MOSTLY_Q6_K:    default_type = GGML_TYPE_Q6_K;    break;
        case LLAMA_FTYPE_MOSTLY_TQ1_0:   default_type = GGML_TYPE_TQ1_0;   break;
        case LLAMA_FTYPE_MOSTLY_TQ2_0:   default_type = GGML_TYPE_TQ2_0;   break;
        case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break;
        case LLAMA_FTYPE_MOSTLY_IQ2_XS:  default_type = GGML_TYPE_IQ2_XS;  break;
        case LLAMA_FTYPE_MOSTLY_IQ2_S:   default_type = GGML_TYPE_IQ2_XS;  break;
        case LLAMA_FTYPE_MOSTLY_IQ2_M:   default_type = GGML_TYPE_IQ2_S;   break;
        case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break;
        case LLAMA_FTYPE_MOSTLY_IQ1_S:   default_type = GGML_TYPE_IQ1_S;   break;
        case LLAMA_FTYPE_MOSTLY_IQ1_M:   default_type = GGML_TYPE_IQ1_M;   break;
        case LLAMA_FTYPE_MOSTLY_IQ4_NL:  default_type = GGML_TYPE_IQ4_NL;  break;
        case LLAMA_FTYPE_MOSTLY_IQ4_XS:  default_type = GGML_TYPE_IQ4_XS;  break;
        case LLAMA_FTYPE_MOSTLY_IQ3_S:   default_type = GGML_TYPE_IQ3_S;   break;
        case LLAMA_FTYPE_MOSTLY_IQ3_M:   default_type = GGML_TYPE_IQ3_S;   break;

        default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
    }

    int nthread = params->nthread;

    if (nthread <= 0) {
        nthread = std::thread::hardware_concurrency();
    }

    // mmap consistently increases speed on Linux, and also increases speed on Windows with
    // hot cache. It may cause a slowdown on macOS, possibly related to free memory.
#if defined(__linux__) || defined(_WIN32)
    constexpr bool use_mmap = true;
#else
    constexpr bool use_mmap = false;
#endif
    llama_model_kv_override * kv_overrides = nullptr;
    if (params->kv_overrides) {
        auto * v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
        kv_overrides = v->data();
    }

    std::vector<std::string> splits = {};
    llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides, nullptr);
    ml.init_mappings(false); // no prefetching

    llama_model model(llama_model_default_params());

    model.load_arch   (ml);
    model.load_hparams(ml);
    model.load_stats  (ml);

    quantize_state_impl qs(model, params);

    if (params->only_copy) {
        ftype = ml.ftype;
    }
    const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr;
    if (params->imatrix) {
        imatrix_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
        if (imatrix_data) {
            LLAMA_LOG_INFO("================================ Have weights data with %d entries\n", int(imatrix_data->size()));
            qs.has_imatrix = true;

            // check imatrix for nans or infs
            for (const auto & kv : *imatrix_data) {
                for (float f : kv.second) {
                    if (!std::isfinite(f)) {
                        throw std::runtime_error(format("imatrix contains non-finite value %f\n", f));
                    }
                }
            }
        }
    }

    const size_t align = GGUF_DEFAULT_ALIGNMENT;
    gguf_context_ptr ctx_out { gguf_init_empty() };

    std::vector<int> prune_list = {};
    if (params->prune_layers) {
        prune_list = *static_cast<const std::vector<int> *>(params->prune_layers);
    }

    // copy the KV pairs from the input file
    gguf_set_kv     (ctx_out.get(), ml.meta.get());
    gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
    gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV

    // Remove split metadata
    gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str());
    gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str());
    gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str());

    if (params->kv_overrides) {
        const std::vector<llama_model_kv_override> & overrides = *(const std::vector<llama_model_kv_override> *)params->kv_overrides;
        for (const auto & o : overrides) {
            if (o.key[0] == 0) break;
            if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
                gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64);
            } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
                // Setting type to UINT32. See https://github.com/ggml-org/llama.cpp/pull/14182 for context
                gguf_set_val_u32(ctx_out.get(), o.key, (uint32_t)abs(o.val_i64));
            } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
                gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool);
            } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
                gguf_set_val_str(ctx_out.get(), o.key, o.val_str);
            } else {
                LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key);
            }
        }
    }
    std::map<int, std::string> mapped;
    int blk_id = 0;
    int pruned_attention_w = 0;

    // make a list of weights
    std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
    tensors.reserve(ml.weights_map.size());
    for (const auto & it : ml.weights_map) {
        const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id));
        if (remapped_name.empty()) {
            if (it.first.find("attn_v.weight") != std::string::npos ||
                it.first.find("attn_qkv.weight") != std::string::npos ||
                it.first.find("attn_kv_b.weight") != std::string::npos) {
                pruned_attention_w++;
            }
            LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
            continue;
        } else if (remapped_name != it.first) {
            ggml_set_name(it.second.tensor, remapped_name.c_str());
            LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
        }
        tensors.push_back(&it.second);
    }
    if (!prune_list.empty()) {
        gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_BLOCK_COUNT).c_str(), blk_id);
    }

    // keep_split requires that the weights are sorted by split index
    if (params->keep_split) {
        std::sort(tensors.begin(), tensors.end(), [](const llama_model_loader::llama_tensor_weight * a, const llama_model_loader::llama_tensor_weight * b) {
            if (a->idx == b->idx) {
                return a->offs < b->offs;
            }
            return a->idx < b->idx;
        });
    }

    for (const auto * it : tensors) {
        const struct ggml_tensor * tensor = it->tensor;

        const std::string name = ggml_get_name(tensor);

        // TODO: avoid hardcoded tensor names - use the TN_* constants
        if (name.find("attn_v.weight")   != std::string::npos ||
            name.find("attn_qkv.weight") != std::string::npos ||
            name.find("attn_kv_b.weight")!= std::string::npos) {
            ++qs.n_attention_wv;
        } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
            qs.has_output = true;
        }
    }

    qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;

    // sanity checks for models that have attention layers
    if (qs.n_attention_wv != 0)
    {
        const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
        // attention layers have a non-zero number of kv heads
        int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
        if (llama_model_has_encoder(&model)) {
            // now n_attn_layer is the number of attention layers in the encoder
            // for each decoder block, there are 2 attention layers
            n_attn_layer += 2 * model.hparams.dec_n_layer;
        }
        GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
    }
    size_t total_size_org = 0;
    size_t total_size_new = 0;

    std::vector<std::thread> workers;
    workers.reserve(nthread);

    int idx = 0;

    std::vector<no_init<uint8_t>> read_data;
    std::vector<no_init<uint8_t>> work;
    std::vector<no_init<float>> f32_conv_buf;

    uint16_t n_split = 1;

    // Assume split index is continuous
    if (params->keep_split) {
        for (const auto * it : tensors) {
            n_split = std::max(uint16_t(it->idx + 1), n_split);
        }
    }
    std::vector<gguf_context_ptr> ctx_outs(n_split);
    ctx_outs[0] = std::move(ctx_out);

    // populate the original tensors so we get an initial meta data
    for (const auto * it : tensors) {
        uint16_t i_split = params->keep_split ? it->idx : 0;
        ggml_tensor * tensor = it->tensor;
        if (!ctx_outs[i_split]) {
            ctx_outs[i_split].reset(gguf_init_empty());
        }
        gguf_add_tensor(ctx_outs[i_split].get(), tensor);
    }

    // Set split info if needed
    if (n_split > 1) {
        for (size_t i = 0; i < ctx_outs.size(); ++i) {
            gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
            gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
            gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), (int32_t)tensors.size());
        }
    }

    int cur_split = -1;
    std::ofstream fout;
    auto close_ofstream = [&]() {
        // Write metadata and close file handler
        if (fout.is_open()) {
            fout.seekp(0);
            std::vector<uint8_t> data(gguf_get_meta_size(ctx_outs[cur_split].get()));
            gguf_get_meta_data(ctx_outs[cur_split].get(), data.data());
            fout.write((const char *) data.data(), data.size());
            fout.close();
        }
    };
    auto new_ofstream = [&](int index) {
        cur_split = index;
        GGML_ASSERT(ctx_outs[cur_split] && "Find uninitialized gguf_context");
        std::string fname = fname_out;
        if (params->keep_split) {
            std::vector<char> split_path(llama_path_max(), 0);
            llama_split_path(split_path.data(), split_path.size(), fname_out.c_str(), cur_split, n_split);
            fname = std::string(split_path.data());
        }

        fout = std::ofstream(fname, std::ios::binary);
        fout.exceptions(std::ofstream::failbit); // fail fast on write errors
        const size_t meta_size = gguf_get_meta_size(ctx_outs[cur_split].get());
        // placeholder for the meta data
        ::zeros(fout, meta_size);
    };
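
    // Illustrative note (added for clarity, not upstream code): each output file is written in two
    // passes. new_ofstream() reserves the metadata region with zero bytes, the loop below appends the
    // tensor data, and once the final tensor types and sizes are known close_ofstream() seeks back to
    // offset 0 and overwrites the placeholder with the real GGUF metadata.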
    const auto tn = LLM_TN(model.arch);
    new_ofstream(0);
    for (const auto * it : tensors) {
        const auto & weight = *it;
        ggml_tensor * tensor = weight.tensor;
        if (weight.idx != cur_split && params->keep_split) {
            close_ofstream();
            new_ofstream(weight.idx);
        }

        const std::string name = ggml_get_name(tensor);

        if (!ml.use_mmap) {
            if (read_data.size() < ggml_nbytes(tensor)) {
                read_data.resize(ggml_nbytes(tensor));
            }
            tensor->data = read_data.data();
        }
        ml.load_data_for(tensor);

        LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ",
               ++idx, ml.n_tensors,
               ggml_get_name(tensor),
               llama_format_tensor_shape(tensor).c_str(),
               ggml_type_name(tensor->type));

        // This used to be a regex, but <regex> has an extreme cost to compile times.
        bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?

        // quantize only 2D and 3D tensors (experts)
        quantize &= (ggml_n_dims(tensor) >= 2);

        // do not quantize norm tensors
        quantize &= name.find("_norm.weight") == std::string::npos;

        quantize &= params->quantize_output_tensor || name != "output.weight";
        quantize &= !params->only_copy;

        // do not quantize expert gating tensors
        // NOTE: can't use LLM_TN here because the layer number is not known
        quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;

        // these are very small (e.g. 4x4)
        quantize &= name.find("altup") == std::string::npos;
        quantize &= name.find("laurel") == std::string::npos;

        // these are not too big so keep them as it is
        quantize &= name.find("per_layer_model_proj") == std::string::npos;

        // do not quantize positional embeddings and token types (BERT)
        quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD,    "weight");
        quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");

        // do not quantize Mamba's small yet 2D weights
        // NOTE: can't use LLM_TN here because the layer number is not known
        quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
        quantize &= name.find("shortconv.conv.weight") == std::string::npos;

        // do not quantize RWKV's small yet 2D weights
        quantize &= name.find("time_mix_first.weight") == std::string::npos;
        quantize &= name.find("time_mix_w0.weight") == std::string::npos;
        quantize &= name.find("time_mix_w1.weight") == std::string::npos;
        quantize &= name.find("time_mix_w2.weight") == std::string::npos;
        quantize &= name.find("time_mix_v0.weight") == std::string::npos;
        quantize &= name.find("time_mix_v1.weight") == std::string::npos;
        quantize &= name.find("time_mix_v2.weight") == std::string::npos;
        quantize &= name.find("time_mix_a0.weight") == std::string::npos;
        quantize &= name.find("time_mix_a1.weight") == std::string::npos;
        quantize &= name.find("time_mix_a2.weight") == std::string::npos;
        quantize &= name.find("time_mix_g1.weight") == std::string::npos;
        quantize &= name.find("time_mix_g2.weight") == std::string::npos;
        quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
        quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
        quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;

        // do not quantize relative position bias (T5)
        quantize &= name.find("attn_rel_b.weight") == std::string::npos;

        ggml_type new_type;
        void * new_data;
        size_t new_size;

        if (quantize) {
            new_type = default_type;

            // get more optimal quantization type based on the tensor shape, layer, etc.
            if (!params->pure && ggml_is_quantized(default_type)) {
                int fallback = qs.n_fallback;
                new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
                // unless the user specifies a type, and the tensor geometry will not require fallback quantisation
                if (params->tensor_types && qs.n_fallback - fallback == 0) {
                    const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
                    const std::string tensor_name(tensor->name);
                    for (const auto & [tname, qtype] : tensor_types) {
                        if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
                            if (qtype != new_type) {
                                LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
                                new_type = qtype; // if two or more types are specified for the same tensor, the last match wins
                            }
                        }
                    }
                }
            }
            if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
                new_type = params->token_embedding_type;
            }
            if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) {
                new_type = params->output_tensor_type;
            }

            // If we've decided to quantize to the same type the tensor is already
            // in then there's nothing to do.
            quantize = tensor->type != new_type;
        }
        if (!quantize) {
            new_type = tensor->type;
            new_data = tensor->data;
            new_size = ggml_nbytes(tensor);
            LLAMA_LOG_INFO("size = %8.3f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0);
        } else {
            const int64_t nelements = ggml_nelements(tensor);

            const float * imatrix = nullptr;
            if (imatrix_data) {
                auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped));
                if (it == imatrix_data->end()) {
                    LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
                } else {
                    if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) {
                        imatrix = it->second.data();
                    } else {
                        LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__,
                                int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name);

                        // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix
                        // this is a significant error and it may be good idea to abort the process if this happens,
                        // since many people will miss the error and not realize that most of the model is being quantized without an imatrix
                        // tok_embd should be ignored in this case, since it always causes this warning
                        if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) {
                            throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s",
                                    int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name));
                        }
                    }
                }
            }
            if ((new_type == GGML_TYPE_IQ2_XXS ||
                 new_type == GGML_TYPE_IQ2_XS  ||
                 new_type == GGML_TYPE_IQ2_S   ||
                 new_type == GGML_TYPE_IQ1_S   ||
                (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight")) ||
                (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) {
                LLAMA_LOG_ERROR("\n\n============================================================\n");
                LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name);
                LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n");
                LLAMA_LOG_ERROR("============================================================\n\n");
                throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name));
            }

            float * f32_data;

            if (tensor->type == GGML_TYPE_F32) {
                f32_data = (float *) tensor->data;
            } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) {
                throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type)));
            } else {
                llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread);
                f32_data = (float *) f32_conv_buf.data();
            }

            LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type));
            fflush(stdout);

            if (work.size() < (size_t)nelements * 4) {
                work.resize(nelements * 4); // upper bound on size
            }
            new_data = work.data();

            const int64_t n_per_row = tensor->ne[0];
            const int64_t nrows     = tensor->ne[1];

            static const int64_t min_chunk_size = 32 * 512;
            const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row));

            const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1];
            const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
            const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1;

            // quantize each expert separately since they have different importance matrices
            new_size = 0;
            for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) {
                const float * f32_data_03 = f32_data + i03 * nelements_matrix;
                void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows;
                const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;

                new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);

                // TODO: temporary sanity check that the F16 -> MXFP4 is lossless
#if 0
                if (new_type == GGML_TYPE_MXFP4) {
                    auto * x = f32_data_03;

                    //LLAMA_LOG_INFO("nrows = %d, n_per_row = %d\n", nrows, n_per_row);
                    std::vector<float> deq(nrows*n_per_row);
                    const ggml_type_traits * qtype = ggml_get_type_traits(new_type);
                    qtype->to_float(new_data_03, deq.data(), deq.size());

                    double err = 0.0f;
                    for (int i = 0; i < (int) deq.size(); ++i) {
                        err += fabsf(deq[i] - x[i]);
                        //if (fabsf(deq[i] - x[i]) > 0.00001 && i < 256) {
                        if (deq[i] != x[i]) {
                            LLAMA_LOG_INFO("deq[%d] = %f, x[%d] = %f\n", i, deq[i], i, x[i]);
                        }
                    }
                    //LLAMA_LOG_INFO("err = %f\n", err);
                    GGML_ASSERT(err == 0.00000);
                }
#endif
            }
            LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
        }
        total_size_org += ggml_nbytes(tensor);
        total_size_new += new_size;

        // update the gguf meta data as we go
        gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type);
        GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size);
        gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data);

        // write tensor data + padding
        fout.write((const char *) new_data, new_size);
        zeros(fout, GGML_PAD(new_size, align) - new_size);
    }
    close_ofstream();

    LLAMA_LOG_INFO("%s: model size = %8.2f MiB\n", __func__, total_size_org/1024.0/1024.0);
    LLAMA_LOG_INFO("%s: quant size = %8.2f MiB\n", __func__, total_size_new/1024.0/1024.0);

    if (qs.n_fallback > 0) {
        LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n",
                __func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback);
    }
}
//
// interface implementation
//

llama_model_quantize_params llama_model_quantize_default_params() {
    llama_model_quantize_params result = {
        /*.nthread                 =*/ 0,
        /*.ftype                   =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
        /*.output_tensor_type      =*/ GGML_TYPE_COUNT,
        /*.token_embedding_type    =*/ GGML_TYPE_COUNT,
        /*.allow_requantize        =*/ false,
        /*.quantize_output_tensor  =*/ true,
        /*.only_copy               =*/ false,
        /*.pure                    =*/ false,
        /*.keep_split              =*/ false,
        /*.imatrix                 =*/ nullptr,
        /*.kv_overrides            =*/ nullptr,
        /*.tensor_types            =*/ nullptr,
        /*.prune_layers            =*/ nullptr
    };

    return result;
}
uint32_t llama_model_quantize(
        const char * fname_inp,
        const char * fname_out,
        const llama_model_quantize_params * params) {
    try {
        llama_model_quantize_impl(fname_inp, fname_out, params);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what());
        return 1;
    }

    return 0;
}
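
// Illustrative usage sketch (added for clarity, not upstream code). It relies only on the entry
// points defined above: llama_model_quantize_default_params() and llama_model_quantize(). The file
// names and the example_quantize() wrapper are placeholders.
#if 0
static int example_quantize() {
    llama_model_quantize_params params = llama_model_quantize_default_params();
    params.nthread = 8;                          // 0 would mean "use hardware_concurrency()"
    params.ftype   = LLAMA_FTYPE_MOSTLY_Q4_K_M;  // target file type; per-tensor types are refined by llama_tensor_get_type()

    // returns 0 on success, 1 if llama_model_quantize_impl() threw
    return (int) llama_model_quantize("model-f16.gguf", "model-q4_k_m.gguf", &params);
}
#endif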