llama-quant.cpp 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049
  1. #include "llama-quant.h"
  2. #include "llama-impl.h"
  3. #include "llama-model.h"
  4. #include "llama-model-loader.h"
  5. #include <algorithm>
  6. #include <cmath>
  7. #include <cstring>
  8. #include <cinttypes>
  9. #include <fstream>
  10. #include <mutex>
  11. #include <regex>
  12. #include <thread>
  13. #include <unordered_map>
  // Quantization types. Changes to this struct must be replicated in quantize.cpp
  //
  // A (tensor name, ggml type) pair. NOTE(review): presumably a per-tensor type override whose
  // `name` is matched by the consumer of this struct (not visible here) - confirm how matching works.
  // `quant` defaults to GGML_TYPE_COUNT; elsewhere in this file `< GGML_TYPE_COUNT` is the test
  // for "a type was actually chosen" (output/token-embedding overrides).
  // Keep the member list token-for-token identical to the copy in quantize.cpp.
  struct tensor_quantization {
      std::string name;
      ggml_type quant = GGML_TYPE_COUNT;
  };
  // Appends `n` zero bytes to `file`, issuing one single-byte write per byte.
  // Stream errors are reported through the stream state, exactly as std::ofstream::write does.
  static void zeros(std::ofstream & file, size_t n) {
      const char nul = '\0';
      while (n > 0) {
          file.write(&nul, 1);
          --n;
      }
  }
  // Computes the output name of a tensor when whole layers ("blk.N.") are pruned from the model.
  //
  // orig_name : tensor name as stored in the input model
  // prune     : indices of the layers to drop; any order, duplicates and negative entries are tolerated
  // mapped    : cache filled here, original layer index -> new layer index as a decimal string,
  //             or "" for a pruned layer
  // next_id   : raised to one past the highest new layer index handed out so far; once every tensor
  //             has been visited this is the block count of the pruned model
  //
  // Returns orig_name unchanged when `prune` is empty or the name has no "blk.N." component,
  // "" when the tensor belongs to a pruned layer, and otherwise orig_name with N replaced by the
  // layer's new index.
  //
  // The new index of a kept layer is its old index minus the number of distinct pruned layers below
  // it. Unlike handing out sequential ids, this does not depend on `prune` being sorted nor on the
  // order in which tensor names are visited (e.g. "blk.10." before "blk.2." in lexicographic order).
  static std::string remap_layer(const std::string & orig_name, const std::vector<int> & prune, std::map<int, std::string> & mapped, int & next_id) {
      if (prune.empty()) {
          return orig_name;
      }

      static const std::regex pattern(R"(blk\.(\d+)\.)");
      std::smatch match;
      if (!std::regex_search(orig_name, match, pattern)) {
          return orig_name;
      }

      const int blk = std::stoi(match[1]);

      auto it = mapped.find(blk);
      if (it == mapped.end()) {
          // first time this layer is seen: normalise the prune list (small) and classify the layer
          std::vector<int> pruned(prune);
          std::sort(pruned.begin(), pruned.end());
          pruned.erase(std::unique(pruned.begin(), pruned.end()), pruned.end());

          std::string new_id; // stays "" for a pruned layer
          if (!std::binary_search(pruned.begin(), pruned.end(), blk)) {
              // only non-negative pruned indices below blk shift this layer down
              const auto non_negative = std::lower_bound(pruned.begin(), pruned.end(), 0);
              const int  n_below      = (int) (std::lower_bound(non_negative, pruned.end(), blk) - non_negative);
              const int  id           = blk - n_below;
              next_id = std::max(next_id, id + 1);
              new_id  = std::to_string(id);
          }
          it = mapped.emplace(blk, std::move(new_id)).first;
      }

      if (it->second.empty()) {
          return std::string();
      }

      std::string new_name = orig_name;
      return new_name.replace(match.position(1), match.length(1), it->second);
  }
  48. static std::string remap_imatrix (const std::string & orig_name, const std::map<int, std::string> & mapped) {
  49. if (mapped.empty()) {
  50. return orig_name;
  51. }
  52. static const std::regex pattern(R"(blk\.(\d+)\.)");
  53. if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
  54. const std::string blk(match[1]);
  55. std::string new_name = orig_name;
  56. for (const auto & p : mapped) {
  57. if (p.second == blk) {
  58. LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first);
  59. return new_name.replace(match.position(1), match.length(1), std::to_string(p.first));
  60. }
  61. }
  62. GGML_ABORT("\n%s: imatrix mapping error for %s\n", __func__, orig_name.c_str());
  63. }
  64. return orig_name;
  65. }
  66. struct quantize_state_impl {
  67. const llama_model & model;
  68. const llama_model_quantize_params * params;
  69. int n_attention_wv = 0;
  70. int n_ffn_down = 0;
  71. int n_ffn_gate = 0;
  72. int n_ffn_up = 0;
  73. int i_attention_wv = 0;
  74. int i_ffn_down = 0;
  75. int i_ffn_gate = 0;
  76. int i_ffn_up = 0;
  77. int n_k_quantized = 0;
  78. int n_fallback = 0;
  79. bool has_imatrix = false;
  80. // used to figure out if a model shares tok_embd with the output weight
  81. bool has_output = false;
  82. quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params)
  83. : model(model)
  84. , params(params)
  85. {}
  86. };
  87. static void llama_tensor_dequantize_impl(
  88. ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
  89. const size_t nelements, const int nthread
  90. ) {
  91. if (output.size() < nelements) {
  92. output.resize(nelements);
  93. }
  94. float * f32_output = (float *) output.data();
  95. const ggml_type_traits * qtype = ggml_get_type_traits(tensor->type);
  96. if (ggml_is_quantized(tensor->type)) {
  97. if (qtype->to_float == NULL) {
  98. throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type)));
  99. }
  100. } else if (tensor->type != GGML_TYPE_F16 &&
  101. tensor->type != GGML_TYPE_BF16) {
  102. throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type)));
  103. }
  104. if (nthread < 2) {
  105. if (tensor->type == GGML_TYPE_F16) {
  106. ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements);
  107. } else if (tensor->type == GGML_TYPE_BF16) {
  108. ggml_bf16_to_fp32_row((ggml_bf16_t *)tensor->data, f32_output, nelements);
  109. } else if (ggml_is_quantized(tensor->type)) {
  110. qtype->to_float(tensor->data, f32_output, nelements);
  111. } else {
  112. GGML_ABORT("fatal error"); // unreachable
  113. }
  114. return;
  115. }
  116. size_t block_size;
  117. if (tensor->type == GGML_TYPE_F16 ||
  118. tensor->type == GGML_TYPE_BF16) {
  119. block_size = 1;
  120. } else {
  121. block_size = (size_t)ggml_blck_size(tensor->type);
  122. }
  123. size_t block_size_bytes = ggml_type_size(tensor->type);
  124. GGML_ASSERT(nelements % block_size == 0);
  125. size_t nblocks = nelements / block_size;
  126. size_t blocks_per_thread = nblocks / nthread;
  127. size_t spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count
  128. size_t in_buff_offs = 0;
  129. size_t out_buff_offs = 0;
  130. for (int tnum = 0; tnum < nthread; tnum++) {
  131. size_t thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread
  132. size_t thr_elems = thr_blocks * block_size; // number of elements for this thread
  133. size_t thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread
  134. auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) {
  135. if (typ == GGML_TYPE_F16) {
  136. ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels);
  137. } else if (typ == GGML_TYPE_BF16) {
  138. ggml_bf16_to_fp32_row((ggml_bf16_t *)inbuf, outbuf, nels);
  139. } else {
  140. qtype->to_float(inbuf, outbuf, nels);
  141. }
  142. };
  143. workers.emplace_back(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems);
  144. in_buff_offs += thr_block_bytes;
  145. out_buff_offs += thr_elems;
  146. }
  147. for (auto & w : workers) { w.join(); }
  148. workers.clear();
  149. }
  // Chooses the ggml type used to store one tensor of the model being quantized.
  //
  // qs       : quantization state. Reads the model arch/hparams, the user params, has_imatrix and
  //            has_output; advances the visit counters (i_attention_wv, i_ffn_down, i_ffn_gate,
  //            i_ffn_up) for the tensor kinds that use them, plus n_k_quantized / n_fallback.
  // new_type : the caller's default type; returned unchanged unless a name/ftype rule below applies
  // tensor   : the tensor being quantized; its name selects the rule, ne[0] is the row length that
  //            must be a multiple of the chosen type's block size
  // ftype    : the requested overall file type (LLAMA_FTYPE_*) the rules key off
  //
  // Returns the chosen type. When the row length is not a multiple of the chosen type's block
  // size, a fallback type is substituted (GGML_TYPE_F16 as the last resort). Throws
  // std::runtime_error for an incompatible type with no fallback, or when the layer index of an
  // MoE tensor cannot be parsed from its name.
  //
  // NOTE: the if/else-if chain below is order-sensitive - a tensor matches at most one branch, and
  // several branches advance per-kind counters, so reordering branches changes results.
  static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
      const std::string name = ggml_get_name(tensor);

      // TODO: avoid hardcoded tensor names - use the TN_* constants
      const llm_arch arch = qs.model.arch;
      const auto       tn = LLM_TN(arch);

      // True for the first n_layers/8 layers, the last n_layers/8 (i_layer >= 7*n_layers/8), and every
      // third layer in between; several ftypes give these layers a higher-precision type.
      auto use_more_bits = [](int i_layer, int n_layers) -> bool {
          return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2;
      };
      const int n_expert = std::max(1, (int)qs.model.hparams.n_expert);
      // Returns (layer index, layer count). With a single expert the caller's running counter is used
      // as the layer index; with experts the index is parsed from the "blk.%d." name prefix and
      // checked to lie in [0, n_layer).
      auto layer_info = [n_expert] (int i_layer, int n_layer, const char * name) {
          if (n_expert > 1) {
              // Believe it or not, "experts" in the FFN of Mixtral-8x7B are not consecutive, but occasionally randomly
              // sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work
              // for getting the current layer as I initially thought, and we need to resort to parsing the
              // tensor name.
              if (sscanf(name, "blk.%d.", &i_layer) != 1) {
                  throw std::runtime_error(format("Failed to determine layer for tensor %s", name));
              }
              if (i_layer < 0 || i_layer >= n_layer) {
                  throw std::runtime_error(format("Bad layer %d for tensor %s. Must be in [0, %d)", i_layer, name, n_layer));
              }
          }
          return std::make_pair(i_layer, n_layer);
      };

      // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings
      // with the quantization of the output tensor
      if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) {
          if (qs.params->output_tensor_type < GGML_TYPE_COUNT) {
              // explicit user override for the output tensor
              new_type = qs.params->output_tensor_type;
          } else {
              const int64_t nx = tensor->ne[0];
              const int64_t qk_k = ggml_blck_size(new_type);
              // Falcon, or a row length the default type cannot represent: use Q8_0
              if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) {
                  new_type = GGML_TYPE_Q8_0;
              }
              else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
                       ftype == LLAMA_FTYPE_MOSTLY_IQ1_S   || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S  || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M   ||
                       ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
                  new_type = GGML_TYPE_Q5_K;
              }
              else if (new_type != GGML_TYPE_Q8_0) {
                  new_type = GGML_TYPE_Q6_K;
              }
          }
      } else if (name == "token_embd.weight" || name == "per_layer_token_embd.weight") {
          // token_embd.weight only reaches this branch when the model has a separate output tensor
          // (qs.has_output); otherwise the branch above already handled it.
          if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
              // explicit user override for the token embeddings
              new_type = qs.params->token_embedding_type;
          } else {
              if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS ||
                  ftype == LLAMA_FTYPE_MOSTLY_IQ1_S   || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
                  new_type = GGML_TYPE_Q2_K;
              }
              else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) {
                  new_type = GGML_TYPE_IQ3_S;
              }
              else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
                  new_type = GGML_TYPE_IQ3_S;
              }
              else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) {
                  new_type = GGML_TYPE_Q4_K;
              }
          }
      } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
                 ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M    || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
          // IQ1_* / IQ2_* ftypes have their own rule set for the remaining tensors
          if (name.find("attn_v.weight") != std::string::npos) {
              if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K;
              else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
              ++qs.i_attention_wv;
          }
          else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) {
              new_type = GGML_TYPE_Q4_K;
          }
          else if (name.find("ffn_down") != std::string::npos) {
              // NOTE(review): unlike the generic ffn_down branch below, this uses the running counter rather
              // than layer_info(), so for MoE models the position is counted in tensors, not layers - confirm intended.
              if (qs.i_ffn_down < qs.n_ffn_down/8) {
                  new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
              }
              ++qs.i_ffn_down;
          }
          else if (name.find("attn_output.weight") != std::string::npos) {
              if (qs.model.hparams.n_expert == 8) {
                  new_type = GGML_TYPE_Q5_K;
              } else {
                  if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_XXS;
                  else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S;
              }
          }
      } else if (name.find("attn_v.weight") != std::string::npos) {
          // attn_v: raised for most ftypes; i_attention_wv / n_attention_wv give the relative position
          if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
              new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
          }
          else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) {
              new_type = GGML_TYPE_Q4_K;
          }
          else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
              new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS;
          }
          else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S) && qs.model.hparams.n_gqa() >= 4) {
              new_type = GGML_TYPE_Q4_K;
          }
          else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
              new_type = GGML_TYPE_Q4_K;
          }
          else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
              new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
          }
          else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
          else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 4) {
              new_type = GGML_TYPE_Q5_K;
          }
          else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
                  use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K;
          else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
          if (qs.model.type == LLM_TYPE_70B) {
              // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
              // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
              // nearly negligible increase in model size by quantizing this tensor with more bits:
              if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K;
          }
          if (qs.model.hparams.n_expert == 8) {
              // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
              // TODO: explore better strategies
              new_type = GGML_TYPE_Q8_0;
          }
          ++qs.i_attention_wv;
      } else if (name.find("attn_k.weight") != std::string::npos) {
          if (qs.model.hparams.n_expert == 8) {
              // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
              // TODO: explore better strategies
              new_type = GGML_TYPE_Q8_0;
          }
          else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
              new_type = GGML_TYPE_IQ3_XXS;
          }
          else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
              new_type = GGML_TYPE_IQ2_S;
          }
      } else if (name.find("attn_q.weight") != std::string::npos) {
          // attn_q is lowered (not raised) for the IQ3_XS / IQ3_XXS ftypes
          if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
              new_type = GGML_TYPE_IQ3_XXS;
          }
          else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
              new_type = GGML_TYPE_IQ2_S;
          }
      } else if (name.find("ffn_down") != std::string::npos) {
          // ffn_down: mostly position-dependent; layer_info() resolves the layer for MoE models
          auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str());
          int i_layer = info.first, n_layer = info.second;
          if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
          else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) {
              if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K;
          }
          else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && !qs.has_imatrix) {
              new_type = i_layer < n_layer/8 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
          }
          else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
              new_type = i_layer < n_layer/16 ? GGML_TYPE_Q5_K
                       : arch != LLM_ARCH_FALCON || use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q4_K
                       : GGML_TYPE_Q3_K;
          }
          else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M && (i_layer < n_layer/8 ||
                      (qs.model.hparams.n_expert == 8 && use_more_bits(i_layer, n_layer)))) {
              new_type = GGML_TYPE_Q4_K;
          }
          else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) {
              new_type = arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K;
          }
          else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
              if (arch == LLM_ARCH_FALCON) {
                  new_type = i_layer < n_layer/16 ? GGML_TYPE_Q6_K :
                             use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
              } else {
                  if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
              }
          }
          else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && !qs.has_imatrix) {
              new_type = GGML_TYPE_Q5_K;
          }
          else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
          else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) {
              new_type = GGML_TYPE_Q5_K;
          }
          else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || ftype == LLAMA_FTYPE_MOSTLY_Q5_0)
                  && qs.has_imatrix && i_layer < n_layer/8) {
              // Guard against craziness in the first few ffn_down layers that can happen even with imatrix for Q4_0/Q5_0.
              // We only do it when an imatrix is provided because a) we want to make sure that one can always get the
              // same quantization as before imatrix stuff, and b) Q4_1/Q5_1 do go crazy on ffn_down without an imatrix.
              new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1;
          }
          ++qs.i_ffn_down;
      } else if (name.find("attn_output.weight") != std::string::npos) {
          if (arch != LLM_ARCH_FALCON) {
              if (qs.model.hparams.n_expert == 8) {
                  if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K   || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
                      ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL  ||
                      ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S   ||
                      ftype == LLAMA_FTYPE_MOSTLY_IQ3_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) {
                      new_type = GGML_TYPE_Q5_K;
                  }
              } else {
                  if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K   ) new_type = GGML_TYPE_Q3_K;
                  else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_IQ3_S;
                  else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K;
                  else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K;
                  else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M  ) new_type = GGML_TYPE_Q4_K;
              }
          } else {
              if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K;
          }
      }
      else if (name.find("attn_qkv.weight") != std::string::npos) {
          if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
              new_type = GGML_TYPE_Q4_K;
          }
          else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
          else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
      }
      else if (name.find("ffn_gate") != std::string::npos) {
          // IQ3_XS: middle layers (between n_layer/8 and 7*n_layer/8) are lowered to IQ3_XXS
          auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str());
          int i_layer = info.first, n_layer = info.second;
          if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
              new_type = GGML_TYPE_IQ3_XXS;
          }
          ++qs.i_ffn_gate;
      }
      else if (name.find("ffn_up") != std::string::npos) {
          // same rule as ffn_gate
          auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str());
          int i_layer = info.first, n_layer = info.second;
          if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
              new_type = GGML_TYPE_IQ3_XXS;
          }
          ++qs.i_ffn_up;
      }

      //    if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
      //}
      // IK: let's remove this, else Q2_K is almost the same as Q3_K_S
      //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) {
      //    if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
      //}
      // This can be used to reduce the size of the Q5_K_S model.
      // The associated PPL increase is fully in line with the size reduction
      //else {
      //    if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
      //}

      // The chosen type must be able to represent a row: ne[0] has to be a multiple of its block size.
      // n_k_quantized counts the tensors for which that holds.
      bool convert_incompatible_tensor = false;
      {
          const int64_t nx = tensor->ne[0];
          const int64_t ny = tensor->ne[1];
          const int64_t qk_k = ggml_blck_size(new_type);

          if (nx % qk_k != 0) {
              LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type));
              convert_incompatible_tensor = true;
          } else {
              ++qs.n_k_quantized;
          }
      }

      // Otherwise switch to a fallback type; if even the fallback's block size does not divide ne[0],
      // store the tensor as F16.
      if (convert_incompatible_tensor) {
          switch (new_type) {
              case GGML_TYPE_TQ1_0:
              case GGML_TYPE_TQ2_0:  new_type = GGML_TYPE_Q4_0; break;  // TODO: use a symmetric type instead
              case GGML_TYPE_IQ2_XXS:
              case GGML_TYPE_IQ2_XS:
              case GGML_TYPE_IQ2_S:
              case GGML_TYPE_IQ3_XXS:
              case GGML_TYPE_IQ3_S:
              case GGML_TYPE_IQ1_S:
              case GGML_TYPE_IQ1_M:
              case GGML_TYPE_Q2_K:
              case GGML_TYPE_Q3_K:
              case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
              case GGML_TYPE_Q4_K:   new_type = GGML_TYPE_Q5_0;   break;
              case GGML_TYPE_Q5_K:   new_type = GGML_TYPE_Q5_1;   break;
              case GGML_TYPE_Q6_K:   new_type = GGML_TYPE_Q8_0;   break;
              default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
          }
          if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
              new_type = GGML_TYPE_F16;
          }
          LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
          ++qs.n_fallback;
      }

      return new_type;
  }
  431. static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
  432. if (nthread < 2) {
  433. // single-thread
  434. size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix);
  435. if (!ggml_validate_row_data(new_type, new_data, new_size)) {
  436. throw std::runtime_error("quantized data validation failed");
  437. }
  438. return new_size;
  439. }
  440. std::mutex mutex;
  441. int64_t counter = 0;
  442. size_t new_size = 0;
  443. bool valid = true;
  444. auto compute = [&mutex, &counter, &new_size, &valid, new_type, f32_data, new_data, chunk_size,
  445. nrows, n_per_row, imatrix]() {
  446. const int64_t nrows_per_chunk = chunk_size / n_per_row;
  447. size_t local_size = 0;
  448. while (true) {
  449. std::unique_lock<std::mutex> lock(mutex);
  450. int64_t first_row = counter; counter += nrows_per_chunk;
  451. if (first_row >= nrows) {
  452. if (local_size > 0) {
  453. new_size += local_size;
  454. }
  455. break;
  456. }
  457. lock.unlock();
  458. const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk);
  459. size_t this_size = ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix);
  460. local_size += this_size;
  461. // validate the quantized data
  462. const size_t row_size = ggml_row_size(new_type, n_per_row);
  463. void * this_data = (char *) new_data + first_row * row_size;
  464. if (!ggml_validate_row_data(new_type, this_data, this_size)) {
  465. std::unique_lock<std::mutex> lock(mutex);
  466. valid = false;
  467. break;
  468. }
  469. }
  470. };
  471. for (int it = 0; it < nthread - 1; ++it) {
  472. workers.emplace_back(compute);
  473. }
  474. compute();
  475. for (auto & w : workers) { w.join(); }
  476. workers.clear();
  477. if (!valid) {
  478. throw std::runtime_error("quantized data validation failed");
  479. }
  480. return new_size;
  481. }
  482. static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
  483. ggml_type default_type;
  484. llama_ftype ftype = params->ftype;
  485. switch (params->ftype) {
  486. case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break;
  487. case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break;
  488. case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break;
  489. case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
  490. case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
  491. case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break;
  492. case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
  493. case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break;
  494. // K-quants
  495. case LLAMA_FTYPE_MOSTLY_Q2_K_S:
  496. case LLAMA_FTYPE_MOSTLY_Q2_K: default_type = GGML_TYPE_Q2_K; break;
  497. case LLAMA_FTYPE_MOSTLY_IQ3_XS: default_type = GGML_TYPE_IQ3_S; break;
  498. case LLAMA_FTYPE_MOSTLY_Q3_K_S:
  499. case LLAMA_FTYPE_MOSTLY_Q3_K_M:
  500. case LLAMA_FTYPE_MOSTLY_Q3_K_L: default_type = GGML_TYPE_Q3_K; break;
  501. case LLAMA_FTYPE_MOSTLY_Q4_K_S:
  502. case LLAMA_FTYPE_MOSTLY_Q4_K_M: default_type = GGML_TYPE_Q4_K; break;
  503. case LLAMA_FTYPE_MOSTLY_Q5_K_S:
  504. case LLAMA_FTYPE_MOSTLY_Q5_K_M: default_type = GGML_TYPE_Q5_K; break;
  505. case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break;
  506. case LLAMA_FTYPE_MOSTLY_TQ1_0: default_type = GGML_TYPE_TQ1_0; break;
  507. case LLAMA_FTYPE_MOSTLY_TQ2_0: default_type = GGML_TYPE_TQ2_0; break;
  508. case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break;
  509. case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break;
  510. case LLAMA_FTYPE_MOSTLY_IQ2_S: default_type = GGML_TYPE_IQ2_XS; break;
  511. case LLAMA_FTYPE_MOSTLY_IQ2_M: default_type = GGML_TYPE_IQ2_S; break;
  512. case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break;
  513. case LLAMA_FTYPE_MOSTLY_IQ1_S: default_type = GGML_TYPE_IQ1_S; break;
  514. case LLAMA_FTYPE_MOSTLY_IQ1_M: default_type = GGML_TYPE_IQ1_M; break;
  515. case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break;
  516. case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;
  517. case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break;
  518. case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break;
  519. default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
  520. }
  521. int nthread = params->nthread;
  522. if (nthread <= 0) {
  523. nthread = std::thread::hardware_concurrency();
  524. }
  525. // mmap consistently increases speed on Linux, and also increases speed on Windows with
  526. // hot cache. It may cause a slowdown on macOS, possibly related to free memory.
  527. #if defined(__linux__) || defined(_WIN32)
  528. constexpr bool use_mmap = true;
  529. #else
  530. constexpr bool use_mmap = false;
  531. #endif
  532. llama_model_kv_override * kv_overrides = nullptr;
  533. if (params->kv_overrides) {
  534. auto * v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
  535. kv_overrides = v->data();
  536. }
  537. std::vector<std::string> splits = {};
  538. llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides, nullptr);
  539. ml.init_mappings(false); // no prefetching
  540. llama_model model(llama_model_default_params());
  541. model.load_arch (ml);
  542. model.load_hparams(ml);
  543. model.load_stats (ml);
  544. quantize_state_impl qs(model, params);
  545. if (params->only_copy) {
  546. ftype = ml.ftype;
  547. }
  548. const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr;
  549. if (params->imatrix) {
  550. imatrix_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
  551. if (imatrix_data) {
  552. LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size()));
  553. qs.has_imatrix = true;
  554. // check imatrix for nans or infs
  555. for (const auto & kv : *imatrix_data) {
  556. for (float f : kv.second) {
  557. if (!std::isfinite(f)) {
  558. throw std::runtime_error(format("imatrix contains non-finite value %f\n", f));
  559. }
  560. }
  561. }
  562. }
  563. }
  564. const size_t align = GGUF_DEFAULT_ALIGNMENT;
  565. gguf_context_ptr ctx_out { gguf_init_empty() };
  566. std::vector<int> prune_list = {};
  567. if (params->prune_layers) {
  568. prune_list = *static_cast<const std::vector<int> *>(params->prune_layers);
  569. }
  570. // copy the KV pairs from the input file
  571. gguf_set_kv (ctx_out.get(), ml.meta.get());
  572. gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
  573. gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV
  574. // Remove split metadata
  575. gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str());
  576. gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str());
  577. gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str());
  578. if (params->kv_overrides) {
  579. const std::vector<llama_model_kv_override> & overrides = *(const std::vector<llama_model_kv_override> *)params->kv_overrides;
  580. for (const auto & o : overrides) {
  581. if (o.key[0] == 0) break;
  582. if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
  583. gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64);
  584. } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
  585. // Setting type to UINT32. See https://github.com/ggml-org/llama.cpp/pull/14182 for context
  586. gguf_set_val_u32(ctx_out.get(), o.key, (uint32_t)abs(o.val_i64));
  587. } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
  588. gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool);
  589. } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
  590. gguf_set_val_str(ctx_out.get(), o.key, o.val_str);
  591. } else {
  592. LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key);
  593. }
  594. }
  595. }
  596. std::map<int, std::string> mapped;
  597. int blk_id = 0;
  598. int pruned_attention_w = 0;
  599. // make a list of weights
  600. std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
  601. tensors.reserve(ml.weights_map.size());
  602. for (const auto & it : ml.weights_map) {
  603. const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id));
  604. if (remapped_name.empty()) {
  605. if (it.first.find("attn_v.weight") != std::string::npos ||
  606. it.first.find("attn_qkv.weight") != std::string::npos ||
  607. it.first.find("attn_kv_b.weight") != std::string::npos) {
  608. pruned_attention_w++;
  609. }
  610. LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
  611. continue;
  612. } else if (remapped_name != it.first) {
  613. ggml_set_name(it.second.tensor, remapped_name.c_str());
  614. LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
  615. }
  616. tensors.push_back(&it.second);
  617. }
  618. if (!prune_list.empty()) {
  619. gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_BLOCK_COUNT).c_str(), blk_id);
  620. }
  621. // keep_split requires that the weights are sorted by split index
  622. if (params->keep_split) {
  623. std::sort(tensors.begin(), tensors.end(), [](const llama_model_loader::llama_tensor_weight * a, const llama_model_loader::llama_tensor_weight * b) {
  624. if (a->idx == b->idx) {
  625. return a->offs < b->offs;
  626. }
  627. return a->idx < b->idx;
  628. });
  629. }
  630. for (const auto * it : tensors) {
  631. const struct ggml_tensor * tensor = it->tensor;
  632. const std::string name = ggml_get_name(tensor);
  633. // TODO: avoid hardcoded tensor names - use the TN_* constants
  634. if (name.find("attn_v.weight") != std::string::npos ||
  635. name.find("attn_qkv.weight") != std::string::npos ||
  636. name.find("attn_kv_b.weight")!= std::string::npos) {
  637. ++qs.n_attention_wv;
  638. } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
  639. qs.has_output = true;
  640. }
  641. }
  642. qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
  643. // sanity checks for models that have attention layers
  644. if (qs.n_attention_wv != 0)
  645. {
  646. const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
  647. // attention layers have a non-zero number of kv heads
  648. int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
  649. if (llama_model_has_encoder(&model)) {
  650. n_attn_layer *= 3;
  651. }
  652. GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
  653. }
  654. size_t total_size_org = 0;
  655. size_t total_size_new = 0;
  656. std::vector<std::thread> workers;
  657. workers.reserve(nthread);
  658. int idx = 0;
  659. std::vector<no_init<uint8_t>> read_data;
  660. std::vector<no_init<uint8_t>> work;
  661. std::vector<no_init<float>> f32_conv_buf;
  662. uint16_t n_split = 1;
  663. // Assume split index is continuous
  664. if (params->keep_split) {
  665. for (const auto * it : tensors) {
  666. n_split = std::max(uint16_t(it->idx + 1), n_split);
  667. }
  668. }
  669. std::vector<gguf_context_ptr> ctx_outs(n_split);
  670. ctx_outs[0] = std::move(ctx_out);
  671. // populate the original tensors so we get an initial meta data
  672. for (const auto * it : tensors) {
  673. uint16_t i_split = params->keep_split ? it->idx : 0;
  674. ggml_tensor * tensor = it->tensor;
  675. if (!ctx_outs[i_split]) {
  676. ctx_outs[i_split].reset(gguf_init_empty());
  677. }
  678. gguf_add_tensor(ctx_outs[i_split].get(), tensor);
  679. }
  680. // Set split info if needed
  681. if (n_split > 1) {
  682. for (size_t i = 0; i < ctx_outs.size(); ++i) {
  683. gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
  684. gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
  685. gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), (int32_t)tensors.size());
  686. }
  687. }
  688. int cur_split = -1;
  689. std::ofstream fout;
  690. auto close_ofstream = [&]() {
  691. // Write metadata and close file handler
  692. if (fout.is_open()) {
  693. fout.seekp(0);
  694. std::vector<uint8_t> data(gguf_get_meta_size(ctx_outs[cur_split].get()));
  695. gguf_get_meta_data(ctx_outs[cur_split].get(), data.data());
  696. fout.write((const char *) data.data(), data.size());
  697. fout.close();
  698. }
  699. };
  700. auto new_ofstream = [&](int index) {
  701. cur_split = index;
  702. GGML_ASSERT(ctx_outs[cur_split] && "Find uninitialized gguf_context");
  703. std::string fname = fname_out;
  704. if (params->keep_split) {
  705. std::vector<char> split_path(llama_path_max(), 0);
  706. llama_split_path(split_path.data(), split_path.size(), fname_out.c_str(), cur_split, n_split);
  707. fname = std::string(split_path.data());
  708. }
  709. fout = std::ofstream(fname, std::ios::binary);
  710. fout.exceptions(std::ofstream::failbit); // fail fast on write errors
  711. const size_t meta_size = gguf_get_meta_size(ctx_outs[cur_split].get());
  712. // placeholder for the meta data
  713. ::zeros(fout, meta_size);
  714. };
  715. const auto tn = LLM_TN(model.arch);
  716. new_ofstream(0);
  717. for (const auto * it : tensors) {
  718. const auto & weight = *it;
  719. ggml_tensor * tensor = weight.tensor;
  720. if (weight.idx != cur_split && params->keep_split) {
  721. close_ofstream();
  722. new_ofstream(weight.idx);
  723. }
  724. const std::string name = ggml_get_name(tensor);
  725. if (!ml.use_mmap) {
  726. if (read_data.size() < ggml_nbytes(tensor)) {
  727. read_data.resize(ggml_nbytes(tensor));
  728. }
  729. tensor->data = read_data.data();
  730. }
  731. ml.load_data_for(tensor);
  732. LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ",
  733. ++idx, ml.n_tensors,
  734. ggml_get_name(tensor),
  735. llama_format_tensor_shape(tensor).c_str(),
  736. ggml_type_name(tensor->type));
  737. // This used to be a regex, but <regex> has an extreme cost to compile times.
  738. bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
  739. // quantize only 2D and 3D tensors (experts)
  740. quantize &= (ggml_n_dims(tensor) >= 2);
  741. // do not quantize norm tensors
  742. quantize &= name.find("_norm.weight") == std::string::npos;
  743. quantize &= params->quantize_output_tensor || name != "output.weight";
  744. quantize &= !params->only_copy;
  745. // do not quantize expert gating tensors
  746. // NOTE: can't use LLM_TN here because the layer number is not known
  747. quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
  748. // these are very small (e.g. 4x4)
  749. quantize &= name.find("altup") == std::string::npos;
  750. quantize &= name.find("laurel") == std::string::npos;
  751. // these are not too big so keep them as it is
  752. quantize &= name.find("per_layer_model_proj") == std::string::npos;
  753. // do not quantize positional embeddings and token types (BERT)
  754. quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight");
  755. quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
  756. // do not quantize Mamba's small yet 2D weights
  757. // NOTE: can't use LLM_TN here because the layer number is not known
  758. quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
  759. quantize &= name.find("shortconv.conv.weight") == std::string::npos;
  760. // do not quantize RWKV's small yet 2D weights
  761. quantize &= name.find("time_mix_first.weight") == std::string::npos;
  762. quantize &= name.find("time_mix_w0.weight") == std::string::npos;
  763. quantize &= name.find("time_mix_w1.weight") == std::string::npos;
  764. quantize &= name.find("time_mix_w2.weight") == std::string::npos;
  765. quantize &= name.find("time_mix_v0.weight") == std::string::npos;
  766. quantize &= name.find("time_mix_v1.weight") == std::string::npos;
  767. quantize &= name.find("time_mix_v2.weight") == std::string::npos;
  768. quantize &= name.find("time_mix_a0.weight") == std::string::npos;
  769. quantize &= name.find("time_mix_a1.weight") == std::string::npos;
  770. quantize &= name.find("time_mix_a2.weight") == std::string::npos;
  771. quantize &= name.find("time_mix_g1.weight") == std::string::npos;
  772. quantize &= name.find("time_mix_g2.weight") == std::string::npos;
  773. quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
  774. quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
  775. quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
  776. // do not quantize relative position bias (T5)
  777. quantize &= name.find("attn_rel_b.weight") == std::string::npos;
  778. ggml_type new_type;
  779. void * new_data;
  780. size_t new_size;
  781. if (quantize) {
  782. new_type = default_type;
  783. // get more optimal quantization type based on the tensor shape, layer, etc.
  784. if (!params->pure && ggml_is_quantized(default_type)) {
  785. new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
  786. // unless the user specifies a type
  787. if (params->tensor_types) {
  788. const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
  789. const std::string tensor_name(tensor->name);
  790. for (const auto & [tname, qtype] : tensor_types) {
  791. if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
  792. if (qtype != new_type) {
  793. LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
  794. new_type = qtype; // if two or more types are specified for the same tensor, the last match wins
  795. }
  796. }
  797. }
  798. }
  799. }
  800. if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
  801. new_type = params->token_embedding_type;
  802. }
  803. if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) {
  804. new_type = params->output_tensor_type;
  805. }
  806. // If we've decided to quantize to the same type the tensor is already
  807. // in then there's nothing to do.
  808. quantize = tensor->type != new_type;
  809. }
  810. if (!quantize) {
  811. new_type = tensor->type;
  812. new_data = tensor->data;
  813. new_size = ggml_nbytes(tensor);
  814. LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0);
  815. } else {
  816. const int64_t nelements = ggml_nelements(tensor);
  817. const float * imatrix = nullptr;
  818. if (imatrix_data) {
  819. auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped));
  820. if (it == imatrix_data->end()) {
  821. LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
  822. } else {
  823. if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) {
  824. imatrix = it->second.data();
  825. } else {
  826. LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__,
  827. int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name);
  828. // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix
  829. // this is a significant error and it may be good idea to abort the process if this happens,
  830. // since many people will miss the error and not realize that most of the model is being quantized without an imatrix
  831. // tok_embd should be ignored in this case, since it always causes this warning
  832. if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) {
  833. throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s",
  834. int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name));
  835. }
  836. }
  837. }
  838. }
  839. if ((new_type == GGML_TYPE_IQ2_XXS ||
  840. new_type == GGML_TYPE_IQ2_XS ||
  841. new_type == GGML_TYPE_IQ2_S ||
  842. new_type == GGML_TYPE_IQ1_S ||
  843. (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight")) ||
  844. (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) {
  845. LLAMA_LOG_ERROR("\n\n============================================================\n");
  846. LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name);
  847. LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n");
  848. LLAMA_LOG_ERROR("============================================================\n\n");
  849. throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name));
  850. }
  851. float * f32_data;
  852. if (tensor->type == GGML_TYPE_F32) {
  853. f32_data = (float *) tensor->data;
  854. } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) {
  855. throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type)));
  856. } else {
  857. llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread);
  858. f32_data = (float *) f32_conv_buf.data();
  859. }
  860. LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type));
  861. fflush(stdout);
  862. if (work.size() < (size_t)nelements * 4) {
  863. work.resize(nelements * 4); // upper bound on size
  864. }
  865. new_data = work.data();
  866. const int64_t n_per_row = tensor->ne[0];
  867. const int64_t nrows = tensor->ne[1];
  868. static const int64_t min_chunk_size = 32 * 512;
  869. const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row));
  870. const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1];
  871. const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
  872. const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1;
  873. // quantize each expert separately since they have different importance matrices
  874. new_size = 0;
  875. for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) {
  876. const float * f32_data_03 = f32_data + i03 * nelements_matrix;
  877. void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows;
  878. const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;
  879. new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
  880. }
  881. LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
  882. }
  883. total_size_org += ggml_nbytes(tensor);
  884. total_size_new += new_size;
  885. // update the gguf meta data as we go
  886. gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type);
  887. GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size);
  888. gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data);
  889. // write tensor data + padding
  890. fout.write((const char *) new_data, new_size);
  891. zeros(fout, GGML_PAD(new_size, align) - new_size);
  892. }
  893. close_ofstream();
  894. LLAMA_LOG_INFO("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
  895. LLAMA_LOG_INFO("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
  896. if (qs.n_fallback > 0) {
  897. LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n",
  898. __func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback);
  899. }
  900. }
  901. //
  902. // interface implementation
  903. //
  904. llama_model_quantize_params llama_model_quantize_default_params() {
  905. llama_model_quantize_params result = {
  906. /*.nthread =*/ 0,
  907. /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
  908. /*.output_tensor_type =*/ GGML_TYPE_COUNT,
  909. /*.token_embedding_type =*/ GGML_TYPE_COUNT,
  910. /*.allow_requantize =*/ false,
  911. /*.quantize_output_tensor =*/ true,
  912. /*.only_copy =*/ false,
  913. /*.pure =*/ false,
  914. /*.keep_split =*/ false,
  915. /*.imatrix =*/ nullptr,
  916. /*.kv_overrides =*/ nullptr,
  917. /*.tensor_type =*/ nullptr,
  918. /*.prune_layers =*/ nullptr
  919. };
  920. return result;
  921. }
  922. uint32_t llama_model_quantize(
  923. const char * fname_inp,
  924. const char * fname_out,
  925. const llama_model_quantize_params * params) {
  926. try {
  927. llama_model_quantize_impl(fname_inp, fname_out, params);
  928. } catch (const std::exception & err) {
  929. LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what());
  930. return 1;
  931. }
  932. return 0;
  933. }