convert-hf-to-gguf.py 62 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435
  1. #!/usr/bin/env python3
  2. from __future__ import annotations
  3. import argparse
  4. import contextlib
  5. import json
  6. import os
  7. import re
  8. import sys
  9. from enum import IntEnum
  10. from pathlib import Path
  11. from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast
  12. import numpy as np
  13. import torch
  14. if TYPE_CHECKING:
  15. from torch import Tensor
  16. if 'NO_LOCAL_GGUF' not in os.environ:
  17. sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
  18. import gguf
  19. # check for any of the given keys in the dictionary and return the value of the first key found
  20. def get_key_opts(d, keys):
  21. for k in keys:
  22. if k in d:
  23. return d[k]
  24. print(f"Could not find any of {keys}")
  25. sys.exit()
  26. ###### MODEL DEFINITIONS ######
  27. class SentencePieceTokenTypes(IntEnum):
  28. NORMAL = 1
  29. UNKNOWN = 2
  30. CONTROL = 3
  31. USER_DEFINED = 4
  32. UNUSED = 5
  33. BYTE = 6
  34. class Model:
  35. def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool):
  36. self.dir_model = dir_model
  37. self.ftype = ftype
  38. self.fname_out = fname_out
  39. self.is_big_endian = is_big_endian
  40. self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
  41. self.is_safetensors = self._is_model_safetensors()
  42. self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin")
  43. self.part_names = self._get_part_names()
  44. self.hparams = Model.load_hparams(self.dir_model)
  45. self.model_arch = self._get_model_architecture()
  46. self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=False)
  47. def set_vocab(self):
  48. self._set_vocab_gpt2()
  49. def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
  50. for part_name in self.part_names:
  51. print(f"gguf: loading model part '{part_name}'")
  52. ctx: ContextManager[Any]
  53. if self.is_safetensors:
  54. from safetensors import safe_open
  55. ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
  56. else:
  57. ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
  58. with ctx as model_part:
  59. for name in model_part.keys():
  60. data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
  61. yield name, data
  62. def set_gguf_parameters(self):
  63. self.gguf_writer.add_name(self.dir_model.name)
  64. self.gguf_writer.add_block_count(self.hparams.get(
  65. "n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")),
  66. ))
  67. if (n_ctx := self.hparams.get("max_position_embeddings")) is not None:
  68. self.gguf_writer.add_context_length(n_ctx)
  69. if (n_embd := self.hparams.get("hidden_size")) is not None:
  70. self.gguf_writer.add_embedding_length(n_embd)
  71. if (n_ff := self.hparams.get("intermediate_size")) is not None:
  72. self.gguf_writer.add_feed_forward_length(n_ff)
  73. if (n_head := self.hparams.get("num_attention_heads")) is not None:
  74. self.gguf_writer.add_head_count(n_head)
  75. if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
  76. self.gguf_writer.add_head_count_kv(n_head_kv)
  77. if (n_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
  78. self.gguf_writer.add_layer_norm_rms_eps(n_rms_eps)
  79. if (n_experts := self.hparams.get("num_local_experts")) is not None:
  80. self.gguf_writer.add_expert_count(n_experts)
  81. if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
  82. self.gguf_writer.add_expert_used_count(n_experts_used)
  83. self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
  84. def write_tensors(self):
  85. block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
  86. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  87. for name, data_torch in self.get_tensors():
  88. # we don't need these
  89. if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
  90. continue
  91. old_dtype = data_torch.dtype
  92. # convert any unsupported data types to float32
  93. if data_torch.dtype not in (torch.float16, torch.float32):
  94. data_torch = data_torch.to(torch.float32)
  95. data = data_torch.squeeze().numpy()
  96. # map tensor names
  97. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  98. if new_name is None:
  99. print(f"Can not map tensor {name!r}")
  100. sys.exit()
  101. n_dims = len(data.shape)
  102. data_dtype = data.dtype
  103. # if f32 desired, convert any float16 to float32
  104. if self.ftype == 0 and data_dtype == np.float16:
  105. data = data.astype(np.float32)
  106. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  107. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  108. data = data.astype(np.float32)
  109. # if f16 desired, convert any float32 2-dim weight tensors to float16
  110. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  111. data = data.astype(np.float16)
  112. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  113. self.gguf_writer.add_tensor(new_name, data)
  114. def write(self):
  115. self.write_tensors()
  116. self.gguf_writer.write_header_to_file()
  117. self.gguf_writer.write_kv_data_to_file()
  118. self.gguf_writer.write_tensors_to_file()
  119. self.gguf_writer.close()
  120. def write_vocab(self):
  121. self.gguf_writer.write_header_to_file()
  122. self.gguf_writer.write_kv_data_to_file()
  123. self.gguf_writer.close()
  124. @staticmethod
  125. def count_model_parts(dir_model: Path, prefix: str) -> int:
  126. num_parts = 0
  127. for filename in os.listdir(dir_model):
  128. if filename.endswith(prefix):
  129. num_parts += 1
  130. return num_parts
  131. @staticmethod
  132. def load_hparams(dir_model):
  133. with open(dir_model / "config.json", "r", encoding="utf-8") as f:
  134. return json.load(f)
  135. @staticmethod
  136. def from_model_architecture(model_architecture):
  137. if model_architecture == "GPTNeoXForCausalLM":
  138. return GPTNeoXModel
  139. if model_architecture == "BloomForCausalLM":
  140. return BloomModel
  141. if model_architecture == "MPTForCausalLM":
  142. return MPTModel
  143. if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
  144. return BaichuanModel
  145. if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
  146. return FalconModel
  147. if model_architecture == "GPTBigCodeForCausalLM":
  148. return StarCoderModel
  149. if model_architecture == "GPTRefactForCausalLM":
  150. return RefactModel
  151. if model_architecture == "PersimmonForCausalLM":
  152. return PersimmonModel
  153. if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
  154. return StableLMModel
  155. if model_architecture == "QWenLMHeadModel":
  156. return QwenModel
  157. if model_architecture == "Qwen2ForCausalLM":
  158. return Model
  159. if model_architecture == "MixtralForCausalLM":
  160. return MixtralModel
  161. if model_architecture == "GPT2LMHeadModel":
  162. return GPT2Model
  163. if model_architecture == "PhiForCausalLM":
  164. return Phi2Model
  165. if model_architecture == "PlamoForCausalLM":
  166. return PlamoModel
  167. if model_architecture == "CodeShellForCausalLM":
  168. return CodeShellModel
  169. if model_architecture == "OrionForCausalLM":
  170. return OrionModel
  171. return Model
  172. def _is_model_safetensors(self) -> bool:
  173. return Model.count_model_parts(self.dir_model, ".safetensors") > 0
  174. def _get_part_names(self):
  175. if self.is_safetensors:
  176. if self.num_parts == 1: # there's only one .safetensors file
  177. return ("model.safetensors",)
  178. return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1))
  179. if self.num_parts == 1: # there's only one .bin file
  180. return ("pytorch_model.bin",)
  181. return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))
  182. def _get_model_architecture(self) -> gguf.MODEL_ARCH:
  183. arch = self.hparams["architectures"][0]
  184. if arch == "GPTNeoXForCausalLM":
  185. return gguf.MODEL_ARCH.GPTNEOX
  186. if arch == "BloomForCausalLM":
  187. return gguf.MODEL_ARCH.BLOOM
  188. if arch == "MPTForCausalLM":
  189. return gguf.MODEL_ARCH.MPT
  190. if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
  191. return gguf.MODEL_ARCH.BAICHUAN
  192. if arch in ("FalconForCausalLM", "RWForCausalLM"):
  193. return gguf.MODEL_ARCH.FALCON
  194. if arch == "GPTBigCodeForCausalLM":
  195. return gguf.MODEL_ARCH.STARCODER
  196. if arch == "GPTRefactForCausalLM":
  197. return gguf.MODEL_ARCH.REFACT
  198. if arch == "PersimmonForCausalLM":
  199. return gguf.MODEL_ARCH.PERSIMMON
  200. if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
  201. return gguf.MODEL_ARCH.STABLELM
  202. if arch == "QWenLMHeadModel":
  203. return gguf.MODEL_ARCH.QWEN
  204. if arch == "Qwen2ForCausalLM":
  205. return gguf.MODEL_ARCH.QWEN2
  206. if arch == "MixtralForCausalLM":
  207. return gguf.MODEL_ARCH.LLAMA
  208. if arch == "GPT2LMHeadModel":
  209. return gguf.MODEL_ARCH.GPT2
  210. if arch == "PhiForCausalLM":
  211. return gguf.MODEL_ARCH.PHI2
  212. if arch == "PlamoForCausalLM":
  213. return gguf.MODEL_ARCH.PLAMO
  214. if arch == "CodeShellForCausalLM":
  215. return gguf.MODEL_ARCH.CODESHELL
  216. if arch == "OrionForCausalLM":
  217. return gguf.MODEL_ARCH.ORION
  218. raise NotImplementedError(f'Architecture "{arch}" not supported!')
  219. def _set_vocab_gpt2(self):
  220. dir_model = self.dir_model
  221. hparams = self.hparams
  222. tokens: list[bytearray] = []
  223. toktypes: list[int] = []
  224. from transformers import AutoTokenizer
  225. tokenizer = AutoTokenizer.from_pretrained(dir_model)
  226. vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
  227. assert max(tokenizer.vocab.values()) < vocab_size
  228. reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
  229. added_vocab = tokenizer.get_added_vocab()
  230. for i in range(vocab_size):
  231. if i not in reverse_vocab:
  232. pad_token = f"[PAD{i}]".encode('utf-8')
  233. tokens.append(bytearray(pad_token))
  234. toktypes.append(gguf.TokenType.USER_DEFINED)
  235. elif reverse_vocab[i] in added_vocab:
  236. tokens.append(reverse_vocab[i])
  237. if tokenizer.added_tokens_decoder[i].special:
  238. toktypes.append(gguf.TokenType.CONTROL)
  239. else:
  240. toktypes.append(gguf.TokenType.USER_DEFINED)
  241. else:
  242. tokens.append(reverse_vocab[i])
  243. toktypes.append(gguf.TokenType.NORMAL)
  244. self.gguf_writer.add_tokenizer_model("gpt2")
  245. self.gguf_writer.add_token_list(tokens)
  246. self.gguf_writer.add_token_types(toktypes)
  247. special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
  248. special_vocab.add_to_gguf(self.gguf_writer)
  249. def _set_vocab_qwen(self):
  250. dir_model = self.dir_model
  251. hparams = self.hparams
  252. tokens: list[bytearray] = []
  253. toktypes: list[int] = []
  254. from transformers import AutoTokenizer
  255. tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
  256. vocab_size = hparams["vocab_size"]
  257. assert max(tokenizer.get_vocab().values()) < vocab_size
  258. merges = []
  259. vocab = {}
  260. mergeable_ranks = tokenizer.mergeable_ranks
  261. for token, rank in mergeable_ranks.items():
  262. vocab[QwenModel.token_bytes_to_string(token)] = rank
  263. if len(token) == 1:
  264. continue
  265. merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
  266. assert len(merged) == 2
  267. merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
  268. # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
  269. added_vocab = tokenizer.special_tokens
  270. reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in (vocab | added_vocab).items()}
  271. for i in range(vocab_size):
  272. if i not in reverse_vocab:
  273. pad_token = f"[PAD{i}]".encode("utf-8")
  274. tokens.append(bytearray(pad_token))
  275. toktypes.append(gguf.TokenType.USER_DEFINED)
  276. elif reverse_vocab[i] in added_vocab:
  277. tokens.append(reverse_vocab[i])
  278. toktypes.append(gguf.TokenType.CONTROL)
  279. else:
  280. tokens.append(reverse_vocab[i])
  281. toktypes.append(gguf.TokenType.NORMAL)
  282. self.gguf_writer.add_tokenizer_model("gpt2")
  283. self.gguf_writer.add_token_list(tokens)
  284. self.gguf_writer.add_token_types(toktypes)
  285. special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
  286. special_vocab.merges = merges
  287. # only add special tokens when they were not already loaded from config.json
  288. if len(special_vocab.special_token_ids) == 0:
  289. special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"])
  290. special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"])
  291. # this one is usually not in config.json anyway
  292. special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"])
  293. special_vocab.add_to_gguf(self.gguf_writer)
  294. def _set_vocab_sentencepiece(self):
  295. from sentencepiece import SentencePieceProcessor
  296. tokenizer_path = self.dir_model / 'tokenizer.model'
  297. tokens: list[bytes] = []
  298. scores: list[float] = []
  299. toktypes: list[int] = []
  300. if not tokenizer_path.is_file():
  301. print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
  302. sys.exit(1)
  303. tokenizer = SentencePieceProcessor(str(tokenizer_path))
  304. vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
  305. for token_id in range(vocab_size):
  306. piece = tokenizer.id_to_piece(token_id)
  307. text = piece.encode("utf-8")
  308. score = tokenizer.get_score(token_id)
  309. toktype = SentencePieceTokenTypes.NORMAL
  310. if tokenizer.is_unknown(token_id):
  311. toktype = SentencePieceTokenTypes.UNKNOWN
  312. elif tokenizer.is_control(token_id):
  313. toktype = SentencePieceTokenTypes.CONTROL
  314. elif tokenizer.is_unused(token_id):
  315. toktype = SentencePieceTokenTypes.UNUSED
  316. elif tokenizer.is_byte(token_id):
  317. toktype = SentencePieceTokenTypes.BYTE
  318. tokens.append(text)
  319. scores.append(score)
  320. toktypes.append(toktype)
  321. added_tokens_file = self.dir_model / 'added_tokens.json'
  322. if added_tokens_file.is_file():
  323. with open(added_tokens_file, "r", encoding="utf-8") as f:
  324. added_tokens_json = json.load(f)
  325. for key in added_tokens_json:
  326. tokens.append(key.encode("utf-8"))
  327. scores.append(-1000.0)
  328. toktypes.append(SentencePieceTokenTypes.USER_DEFINED)
  329. self.gguf_writer.add_tokenizer_model("llama")
  330. self.gguf_writer.add_token_list(tokens)
  331. self.gguf_writer.add_token_scores(scores)
  332. self.gguf_writer.add_token_types(toktypes)
  333. special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
  334. special_vocab.add_to_gguf(self.gguf_writer)
  335. class GPTNeoXModel(Model):
  336. def set_gguf_parameters(self):
  337. block_count = self.hparams["num_hidden_layers"]
  338. self.gguf_writer.add_name(self.dir_model.name)
  339. self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
  340. self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
  341. self.gguf_writer.add_block_count(block_count)
  342. self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
  343. self.gguf_writer.add_rope_dimension_count(
  344. int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
  345. )
  346. self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
  347. self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
  348. self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
  349. class BloomModel(Model):
  350. def set_gguf_parameters(self):
  351. self.gguf_writer.add_name("Bloom")
  352. n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
  353. n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
  354. self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
  355. self.gguf_writer.add_embedding_length(n_embed)
  356. self.gguf_writer.add_feed_forward_length(4 * n_embed)
  357. self.gguf_writer.add_block_count(self.hparams["n_layer"])
  358. self.gguf_writer.add_head_count(n_head)
  359. self.gguf_writer.add_head_count_kv(n_head)
  360. self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
  361. self.gguf_writer.add_file_type(self.ftype)
  362. def write_tensors(self):
  363. block_count = self.hparams["n_layer"]
  364. tensors = dict(self.get_tensors())
  365. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  366. has_lm_head = True
  367. n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
  368. n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
  369. for name, data_torch in tensors.items():
  370. if "lm_head.weight" not in tensors.keys() and "output.weight" not in tensors.keys():
  371. has_lm_head = False
  372. name = re.sub(r'transformer\.', '', name)
  373. old_dtype = data_torch.dtype
  374. # convert any unsupported data types to float32
  375. if data_torch.dtype not in (torch.float16, torch.float32):
  376. data_torch = data_torch.to(torch.float32)
  377. data = data_torch.squeeze().numpy()
  378. if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name):
  379. # Map bloom-style qkv_linear to gpt-style qkv_linear
  380. # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252 # noqa
  381. # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312 # noqa
  382. qkv_weights = data.reshape((n_head, 3, n_embed // n_head, n_embed))
  383. data = np.concatenate(
  384. (
  385. qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
  386. qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
  387. qkv_weights[:, 2, :, :].reshape((-1, n_embed)),
  388. ),
  389. axis=0,
  390. )
  391. print("re-format attention.linear_qkv.weight")
  392. elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name):
  393. qkv_bias = data.reshape((n_head, 3, n_embed // n_head))
  394. data = np.concatenate(
  395. (
  396. qkv_bias[:, 0, :].reshape((n_embed,)),
  397. qkv_bias[:, 1, :].reshape((n_embed,)),
  398. qkv_bias[:, 2, :].reshape((n_embed,)),
  399. ),
  400. axis=0,
  401. )
  402. print("re-format attention.linear_qkv.bias")
  403. # map tensor names
  404. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  405. if new_name is None:
  406. print(f"Can not map tensor {name!r}")
  407. sys.exit()
  408. n_dims = len(data.shape)
  409. data_dtype = data.dtype
  410. # if f32 desired, convert any float16 to float32
  411. if self.ftype == 0 and data_dtype == np.float16:
  412. data = data.astype(np.float32)
  413. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  414. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  415. data = data.astype(np.float32)
  416. # if f16 desired, convert any float32 2-dim weight tensors to float16
  417. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  418. data = data.astype(np.float16)
  419. print(f"=> {new_name}, shape = {data.shape}, {old_dtype} --> {data.dtype}")
  420. self.gguf_writer.add_tensor(new_name, data)
  421. if not has_lm_head and name == "word_embeddings.weight":
  422. self.gguf_writer.add_tensor("output.weight", data)
  423. print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
  424. class MPTModel(Model):
  425. def set_gguf_parameters(self):
  426. block_count = self.hparams["n_layers"]
  427. self.gguf_writer.add_name(self.dir_model.name)
  428. self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
  429. self.gguf_writer.add_embedding_length(self.hparams["d_model"])
  430. self.gguf_writer.add_block_count(block_count)
  431. self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"])
  432. self.gguf_writer.add_head_count(self.hparams["n_heads"])
  433. if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"):
  434. self.gguf_writer.add_head_count_kv(kv_n_heads)
  435. self.gguf_writer.add_layer_norm_eps(1e-5)
  436. if self.hparams["attn_config"]["clip_qkv"] is not None:
  437. self.gguf_writer.add_clamp_kqv(self.hparams["attn_config"]["clip_qkv"])
  438. self.gguf_writer.add_max_alibi_bias(self.hparams["attn_config"]["alibi_bias_max"])
  439. def write_tensors(self):
  440. block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers"))
  441. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  442. for name, data_torch in self.get_tensors():
  443. # we don't need these
  444. if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
  445. continue
  446. old_dtype = data_torch.dtype
  447. # convert any unsupported data types to float32
  448. if data_torch.dtype not in (torch.float16, torch.float32):
  449. data_torch = data_torch.to(torch.float32)
  450. data = data_torch.squeeze().numpy()
  451. # map tensor names
  452. if "scales" in name:
  453. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias", ".scales"))
  454. if new_name is not None:
  455. new_name = new_name.replace("scales", "act.scales")
  456. else:
  457. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  458. if new_name is None:
  459. print(f"Can not map tensor {name!r}")
  460. sys.exit()
  461. n_dims = len(data.shape)
  462. data_dtype = data.dtype
  463. # if f32 desired, convert any float16 to float32
  464. if self.ftype == 0 and data_dtype == np.float16:
  465. data = data.astype(np.float32)
  466. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  467. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  468. data = data.astype(np.float32)
  469. # if f16 desired, convert any float32 2-dim weight tensors to float16
  470. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  471. data = data.astype(np.float16)
  472. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  473. self.gguf_writer.add_tensor(new_name, data)
  474. # note: MPT output is tied to (same as) wte in original model;
  475. # for easier implementation in llama.cpp it's duplicated in GGUF, though :/
  476. if new_name == "token_embd.weight":
  477. self.gguf_writer.add_tensor("output.weight", data)
  478. class OrionModel(Model):
  479. def set_vocab(self):
  480. self._set_vocab_sentencepiece()
  481. def set_gguf_parameters(self):
  482. block_count = self.hparams["num_hidden_layers"]
  483. head_count = self.hparams["num_attention_heads"]
  484. head_count_kv = self.hparams.get("num_key_value_heads", head_count)
  485. hf_repo = self.hparams.get("_name_or_path", "")
  486. ctx_length = 0
  487. if "max_sequence_length" in self.hparams:
  488. ctx_length = self.hparams["max_sequence_length"]
  489. elif "max_position_embeddings" in self.hparams:
  490. ctx_length = self.hparams["max_position_embeddings"]
  491. elif "model_max_length" in self.hparams:
  492. ctx_length = self.hparams["model_max_length"]
  493. else:
  494. print("gguf: can not find ctx length parameter.")
  495. sys.exit()
  496. self.gguf_writer.add_file_type(self.ftype)
  497. self.gguf_writer.add_name(self.dir_model.name)
  498. self.gguf_writer.add_source_hf_repo(hf_repo)
  499. self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
  500. self.gguf_writer.add_context_length(ctx_length)
  501. self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
  502. self.gguf_writer.add_block_count(block_count)
  503. self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
  504. self.gguf_writer.add_head_count(head_count)
  505. self.gguf_writer.add_head_count_kv(head_count_kv)
  506. self.gguf_writer.add_layer_norm_eps(self.hparams["rms_norm_eps"])
  507. def write_tensors(self):
  508. # Collect tensors from generator object
  509. model_kv = dict(self.get_tensors())
  510. block_count = self.hparams["num_hidden_layers"]
  511. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  512. for name, data_torch in model_kv.items():
  513. # we don't need these
  514. if name.endswith(".rotary_emb.inv_freq"):
  515. continue
  516. old_dtype = data_torch.dtype
  517. # convert any unsupported data types to float32
  518. if data_torch.dtype not in (torch.float16, torch.float32):
  519. data_torch = data_torch.to(torch.float32)
  520. data = data_torch.squeeze().numpy()
  521. # map tensor names
  522. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  523. if new_name is None:
  524. print(f"Can not map tensor {name!r}")
  525. sys.exit()
  526. n_dims = len(data.shape)
  527. data_dtype = data.dtype
  528. # if f32 desired, convert any float16 to float32
  529. if self.ftype == 0 and data_dtype == np.float16:
  530. data = data.astype(np.float32)
  531. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  532. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  533. data = data.astype(np.float32)
  534. # if f16 desired, convert any float32 2-dim weight tensors to float16
  535. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  536. data = data.astype(np.float16)
  537. print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  538. self.gguf_writer.add_tensor(new_name, data)
  539. class BaichuanModel(Model):
  540. def set_vocab(self):
  541. self._set_vocab_sentencepiece()
  542. def set_gguf_parameters(self):
  543. block_count = self.hparams["num_hidden_layers"]
  544. head_count = self.hparams["num_attention_heads"]
  545. head_count_kv = self.hparams.get("num_key_value_heads", head_count)
  546. hf_repo = self.hparams.get("_name_or_path", "")
  547. ctx_length = 0
  548. if "max_sequence_length" in self.hparams:
  549. ctx_length = self.hparams["max_sequence_length"]
  550. elif "max_position_embeddings" in self.hparams:
  551. ctx_length = self.hparams["max_position_embeddings"]
  552. elif "model_max_length" in self.hparams:
  553. ctx_length = self.hparams["model_max_length"]
  554. else:
  555. print("gguf: can not find ctx length parameter.")
  556. sys.exit()
  557. self.gguf_writer.add_name(self.dir_model.name)
  558. self.gguf_writer.add_source_hf_repo(hf_repo)
  559. self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
  560. self.gguf_writer.add_context_length(ctx_length)
  561. self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
  562. self.gguf_writer.add_block_count(block_count)
  563. self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
  564. self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
  565. self.gguf_writer.add_head_count(head_count)
  566. self.gguf_writer.add_head_count_kv(head_count_kv)
  567. self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
  568. if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
  569. if self.hparams["rope_scaling"].get("type") == "linear":
  570. self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
  571. self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
  572. def write_tensors(self):
  573. # Collect tensors from generator object
  574. model_kv = dict(self.get_tensors())
  575. block_count = self.hparams["num_hidden_layers"]
  576. head_count = self.hparams["num_attention_heads"]
  577. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  578. head_count_kv = self.hparams.get("num_key_value_heads", head_count)
  579. for i in range(block_count):
  580. if (w := model_kv.get(f"model.layers.{i}.self_attn.W_pack.weight")) is not None:
  581. print(f"Unpacking and permuting layer {i}")
  582. model_kv[f"model.layers.{i}.self_attn.q_proj.weight"] = \
  583. self._reverse_hf_permute_part(w, 0, head_count, head_count)
  584. model_kv[f"model.layers.{i}.self_attn.k_proj.weight"] = \
  585. self._reverse_hf_permute_part(w, 1, head_count, head_count_kv)
  586. model_kv[f"model.layers.{i}.self_attn.v_proj.weight"] = \
  587. self._reverse_hf_part(w, 2)
  588. del model_kv[f"model.layers.{i}.self_attn.W_pack.weight"]
  589. for name, data_torch in model_kv.items():
  590. # we don't need these
  591. if name.endswith(".rotary_emb.inv_freq"):
  592. continue
  593. old_dtype = data_torch.dtype
  594. # convert any unsupported data types to float32
  595. if data_torch.dtype not in (torch.float16, torch.float32):
  596. data_torch = data_torch.to(torch.float32)
  597. data = data_torch.squeeze().numpy()
  598. # map tensor names
  599. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  600. if new_name is None:
  601. print(f"Can not map tensor {name!r}")
  602. sys.exit()
  603. n_dims = len(data.shape)
  604. data_dtype = data.dtype
  605. # if f32 desired, convert any float16 to float32
  606. if self.ftype == 0 and data_dtype == np.float16:
  607. data = data.astype(np.float32)
  608. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  609. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  610. data = data.astype(np.float32)
  611. # if f16 desired, convert any float32 2-dim weight tensors to float16
  612. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  613. data = data.astype(np.float16)
  614. print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  615. self.gguf_writer.add_tensor(new_name, data)
  616. def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
  617. if n_kv_head is not None and n_head != n_kv_head:
  618. n_head //= n_kv_head
  619. return (
  620. weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
  621. .swapaxes(1, 2)
  622. .reshape(weights.shape)
  623. )
  624. def _reverse_hf_permute_part(
  625. self, weights: Tensor, n_part: int, n_head: int, n_head_kv: int | None = None,
  626. ) -> Tensor:
  627. r = weights.shape[0] // 3
  628. return self._reverse_hf_permute(weights[r * n_part:r * n_part + r, ...], n_head, n_head_kv)
  629. def _reverse_hf_part(self, weights: Tensor, n_part: int) -> Tensor:
  630. r = weights.shape[0] // 3
  631. return weights[r * n_part:r * n_part + r, ...]
  632. class FalconModel(Model):
  633. def set_gguf_parameters(self):
  634. block_count = self.hparams.get("num_hidden_layers")
  635. if block_count is None:
  636. block_count = self.hparams["n_layer"] # old name
  637. n_head = self.hparams.get("num_attention_heads")
  638. if n_head is None:
  639. n_head = self.hparams["n_head"] # old name
  640. n_head_kv = self.hparams.get("num_kv_heads")
  641. if n_head_kv is None:
  642. n_head_kv = self.hparams.get("n_head_kv", 1) # old name
  643. self.gguf_writer.add_name("Falcon")
  644. self.gguf_writer.add_context_length(2048) # not in config.json
  645. self.gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform
  646. self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
  647. self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"])
  648. self.gguf_writer.add_block_count(block_count)
  649. self.gguf_writer.add_head_count(n_head)
  650. self.gguf_writer.add_head_count_kv(n_head_kv)
  651. self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
  652. self.gguf_writer.add_file_type(self.ftype)
  653. def write_tensors(self):
  654. block_count = self.hparams.get("num_hidden_layers")
  655. if block_count is None:
  656. block_count = self.hparams["n_layer"] # old name
  657. n_head = self.hparams.get("num_attention_heads")
  658. if n_head is None:
  659. n_head = self.hparams["n_head"] # old name
  660. n_head_kv = self.hparams.get("num_kv_heads")
  661. if n_head_kv is None:
  662. n_head_kv = self.hparams.get("n_head_kv", 1) # old name
  663. head_dim = self.hparams["hidden_size"] // n_head
  664. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  665. for name, data_torch in self.get_tensors():
  666. old_dtype = data_torch.dtype
  667. # convert any unsupported data types to float32
  668. if data_torch.dtype not in (torch.float16, torch.float32):
  669. data_torch = data_torch.to(torch.float32)
  670. # QKV tensor transform
  671. # The original query_key_value tensor contains n_head_kv "kv groups",
  672. # each consisting of n_head/n_head_kv query weights followed by one key
  673. # and one value weight (shared by all query heads in the kv group).
  674. # This layout makes it a big pain to work with in GGML.
  675. # So we rearrange them here,, so that we have n_head query weights
  676. # followed by n_head_kv key weights followed by n_head_kv value weights,
  677. # in contiguous fashion.
  678. # ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py
  679. if "query_key_value" in name:
  680. qkv = data_torch.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
  681. q = qkv[:, :-2].reshape(n_head * head_dim, head_dim * n_head)
  682. k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
  683. v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
  684. data_torch = torch.cat((q, k, v)).reshape_as(data_torch)
  685. data = data_torch.squeeze().numpy()
  686. # map tensor names
  687. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  688. if new_name is None:
  689. print(f"Can not map tensor {name!r}")
  690. sys.exit()
  691. n_dims = len(data.shape)
  692. data_dtype = data.dtype
  693. # if f32 desired, convert any float16 to float32
  694. if self.ftype == 0 and data_dtype == np.float16:
  695. data = data.astype(np.float32)
  696. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  697. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  698. data = data.astype(np.float32)
  699. # if f16 desired, convert any float32 2-dim weight tensors to float16
  700. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  701. data = data.astype(np.float16)
  702. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  703. self.gguf_writer.add_tensor(new_name, data)
  704. class StarCoderModel(Model):
  705. def set_gguf_parameters(self):
  706. block_count = self.hparams["n_layer"]
  707. self.gguf_writer.add_name("StarCoder")
  708. self.gguf_writer.add_context_length(self.hparams["n_positions"])
  709. self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
  710. self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
  711. self.gguf_writer.add_block_count(block_count)
  712. self.gguf_writer.add_head_count(self.hparams["n_head"])
  713. self.gguf_writer.add_head_count_kv(1)
  714. self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
  715. self.gguf_writer.add_file_type(self.ftype)
  716. class RefactModel(Model):
  717. def set_gguf_parameters(self):
  718. hidden_dim = self.hparams["n_embd"]
  719. inner_dim = 4 * hidden_dim
  720. hidden_dim = int(2 * inner_dim / 3)
  721. multiple_of = 256
  722. ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
  723. block_count = self.hparams["n_layer"]
  724. self.gguf_writer.add_name("Refact")
  725. # refact uses Alibi. So this is from config.json which might be used by training.
  726. self.gguf_writer.add_context_length(self.hparams["n_positions"])
  727. self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
  728. self.gguf_writer.add_feed_forward_length(ff_dim)
  729. self.gguf_writer.add_block_count(block_count)
  730. self.gguf_writer.add_head_count(self.hparams["n_head"])
  731. self.gguf_writer.add_head_count_kv(1)
  732. self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
  733. self.gguf_writer.add_file_type(self.ftype)
  734. def write_tensors(self):
  735. hidden_dim = self.hparams["n_embd"]
  736. inner_dim = 4 * hidden_dim
  737. hidden_dim = int(2 * inner_dim / 3)
  738. multiple_of = 256
  739. ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
  740. n_head = self.hparams["n_head"]
  741. n_head_kv = 1
  742. head_dim = self.hparams["n_embd"] // n_head
  743. block_count = self.hparams["n_layer"]
  744. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  745. tensors = dict(self.get_tensors())
  746. for i in range(block_count):
  747. if (w := tensors.get(f"transformer.h.{i}.attn.kv.weight")) is not None:
  748. tensors[f"model.layers.{i}.self_attn.k_proj.weight"] = w[:n_head_kv * head_dim]
  749. tensors[f"model.layers.{i}.self_attn.v_proj.weight"] = w[n_head_kv * head_dim:]
  750. del tensors[f"transformer.h.{i}.attn.kv.weight"]
  751. if (w := tensors.get(f"transformer.h.{i}.attn.q.weight")) is not None:
  752. tensors[f"model.layers.{i}.self_attn.q_proj.weight"] = w
  753. del tensors[f"transformer.h.{i}.attn.q.weight"]
  754. if (w := tensors.get(f"transformer.h.{i}.mlp.gate_up_proj.weight")) is not None:
  755. tensors[f"model.layers.{i}.mlp.gate_proj.weight"] = w[:ff_dim]
  756. tensors[f"model.layers.{i}.mlp.up_proj.weight"] = w[ff_dim:]
  757. del tensors[f"transformer.h.{i}.mlp.gate_up_proj.weight"]
  758. for name, data_torch in tensors.items():
  759. old_dtype = data_torch.dtype
  760. # convert any unsupported data types to float32
  761. if data_torch.dtype not in (torch.float16, torch.float32):
  762. data_torch = data_torch.to(torch.float32)
  763. data = data_torch.squeeze().numpy()
  764. # map tensor names
  765. new_name = tensor_map.get_name(name, try_suffixes=(".weight",))
  766. if new_name is None:
  767. print(f"Can not map tensor {name!r}")
  768. sys.exit()
  769. n_dims = len(data.shape)
  770. data_dtype = data.dtype
  771. # if f32 desired, convert any float16 to float32
  772. if self.ftype == 0 and data_dtype == np.float16:
  773. data = data.astype(np.float32)
  774. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  775. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  776. data = data.astype(np.float32)
  777. # if f16 desired, convert any float32 2-dim weight tensors to float16
  778. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  779. data = data.astype(np.float16)
  780. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  781. self.gguf_writer.add_tensor(new_name, data)
  782. class PersimmonModel(Model):
  783. def set_gguf_parameters(self):
  784. block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
  785. head_count = self.hparams["num_attention_heads"]
  786. head_count_kv = head_count
  787. hidden_size = self.hparams["hidden_size"]
  788. self.gguf_writer.add_name('persimmon-8b-chat')
  789. self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
  790. self.gguf_writer.add_embedding_length(hidden_size)
  791. self.gguf_writer.add_block_count(block_count)
  792. self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
  793. # NOTE: not sure about this change - why does the model not have a rope dimension count when it is smaller
  794. # than the head size?
  795. # ref: https://github.com/ggerganov/llama.cpp/pull/4889
  796. # self.gguf_writer.add_rope_dimension_count(hidden_size // head_count)
  797. self.gguf_writer.add_rope_dimension_count(hidden_size // head_count // 2)
  798. self.gguf_writer.add_head_count(head_count)
  799. self.gguf_writer.add_head_count_kv(head_count_kv)
  800. self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
  801. self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
  802. self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
  803. def set_vocab(self):
  804. self._set_vocab_sentencepiece()
  805. # self.gguf_writer.add_bos_token_id(71013)
  806. # self.gguf_writer.add_eos_token_id(71013)
  807. def write_tensors(self):
  808. block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
  809. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  810. for name, data_torch in self.get_tensors():
  811. if name.endswith(".self_attention.rotary_emb.inv_freq"):
  812. continue
  813. old_dtype = data_torch.dtype
  814. # TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?)
  815. data = data_torch.to(torch.float32).squeeze().numpy()
  816. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  817. if new_name is None:
  818. print(f"Can not map tensor {name!r}")
  819. sys.exit()
  820. n_dims = len(data.shape)
  821. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  822. self.gguf_writer.add_tensor(new_name, data)
  823. class StableLMModel(Model):
  824. def set_vocab(self):
  825. if (self.dir_model / "tokenizer.json").is_file():
  826. self._set_vocab_gpt2()
  827. else:
  828. # StableLM 2 1.6B uses a vocab in a similar format to Qwen's vocab
  829. self._set_vocab_qwen()
  830. def set_gguf_parameters(self):
  831. hparams = self.hparams
  832. block_count = hparams["num_hidden_layers"]
  833. self.gguf_writer.add_name(self.dir_model.name)
  834. self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
  835. self.gguf_writer.add_embedding_length(hparams["hidden_size"])
  836. self.gguf_writer.add_block_count(block_count)
  837. self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
  838. self.gguf_writer.add_rope_dimension_count(int(hparams["rope_pct"] * (hparams["hidden_size"] // hparams["num_attention_heads"])))
  839. self.gguf_writer.add_head_count(hparams["num_attention_heads"])
  840. self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
  841. self.gguf_writer.add_layer_norm_eps(1e-5)
  842. class MixtralModel(Model):
  843. def set_vocab(self):
  844. self._set_vocab_sentencepiece()
  845. class QwenModel(Model):
  846. @staticmethod
  847. def token_bytes_to_string(b):
  848. from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
  849. byte_encoder = bytes_to_unicode()
  850. return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
  851. @staticmethod
  852. def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
  853. parts = [bytes([b]) for b in token]
  854. while True:
  855. min_idx = None
  856. min_rank = None
  857. for i, pair in enumerate(zip(parts[:-1], parts[1:])):
  858. rank = mergeable_ranks.get(pair[0] + pair[1])
  859. if rank is not None and (min_rank is None or rank < min_rank):
  860. min_idx = i
  861. min_rank = rank
  862. if min_rank is None or (max_rank is not None and min_rank >= max_rank):
  863. break
  864. assert min_idx is not None
  865. parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
  866. return parts
  867. def set_vocab(self):
  868. self._set_vocab_qwen()
  869. def set_gguf_parameters(self):
  870. self.gguf_writer.add_name("Qwen")
  871. self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
  872. self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
  873. self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
  874. self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
  875. self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
  876. self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
  877. self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
  878. self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
  879. def write_tensors(self):
  880. block_count = self.hparams["num_hidden_layers"]
  881. model_kv = dict(self.get_tensors())
  882. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  883. for name, data_torch in model_kv.items():
  884. # we don't need these
  885. if name.endswith(".rotary_emb.inv_freq"):
  886. continue
  887. old_dtype = data_torch.dtype
  888. # convert any unsupported data types to float32
  889. if data_torch.dtype not in (torch.float16, torch.float32):
  890. data_torch = data_torch.to(torch.float32)
  891. data = data_torch.squeeze().numpy()
  892. # map tensor names
  893. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  894. if new_name is None:
  895. print(f"Can not map tensor {name!r}")
  896. sys.exit()
  897. n_dims = len(data.shape)
  898. data_dtype = data.dtype
  899. # if f32 desired, convert any float16 to float32
  900. if self.ftype == 0 and data_dtype == np.float16:
  901. data = data.astype(np.float32)
  902. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  903. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  904. data = data.astype(np.float32)
  905. # if f16 desired, convert any float32 2-dim weight tensors to float16
  906. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  907. data = data.astype(np.float16)
  908. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  909. self.gguf_writer.add_tensor(new_name, data)
  910. class GPT2Model(Model):
  911. def set_gguf_parameters(self):
  912. self.gguf_writer.add_name(self.dir_model.name)
  913. self.gguf_writer.add_block_count(self.hparams["n_layer"])
  914. self.gguf_writer.add_context_length(self.hparams["n_ctx"])
  915. self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
  916. self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
  917. self.gguf_writer.add_head_count(self.hparams["n_head"])
  918. self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
  919. self.gguf_writer.add_file_type(self.ftype)
  920. def write_tensors(self):
  921. block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
  922. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  923. for name, data_torch in self.get_tensors():
  924. # we don't need these
  925. if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq", ".attn.bias")):
  926. continue
  927. if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_proj.weight")):
  928. data_torch = data_torch.transpose(1, 0)
  929. old_dtype = data_torch.dtype
  930. # convert any unsupported data types to float32
  931. if data_torch.dtype not in (torch.float16, torch.float32):
  932. data_torch = data_torch.to(torch.float32)
  933. data = data_torch.squeeze().numpy()
  934. # map tensor names
  935. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  936. if new_name is None:
  937. print(f"Can not map tensor {name!r}")
  938. sys.exit()
  939. n_dims = len(data.shape)
  940. data_dtype = data.dtype
  941. # if f32 desired, convert any float16 to float32
  942. if self.ftype == 0 and data_dtype == np.float16:
  943. data = data.astype(np.float32)
  944. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  945. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  946. data = data.astype(np.float32)
  947. # if f16 desired, convert any float32 2-dim weight tensors to float16
  948. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  949. data = data.astype(np.float16)
  950. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  951. self.gguf_writer.add_tensor(new_name, data)
  952. # note: GPT2 output is tied to (same as) wte in original model
  953. if new_name == "token_embd.weight":
  954. print(f"output.weight, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  955. self.gguf_writer.add_tensor("output.weight", data)
  956. class Phi2Model(Model):
  957. def set_gguf_parameters(self):
  958. block_count = get_key_opts(self.hparams, ["num_hidden_layers", "n_layer"])
  959. rot_pct = get_key_opts(self.hparams, ["partial_rotary_factor"])
  960. n_embd = get_key_opts(self.hparams, ["hidden_size", "n_embd"])
  961. n_head = get_key_opts(self.hparams, ["num_attention_heads", "n_head"])
  962. self.gguf_writer.add_name("Phi2")
  963. self.gguf_writer.add_context_length(get_key_opts(self.hparams, ["n_positions", "max_position_embeddings"]))
  964. self.gguf_writer.add_embedding_length(n_embd)
  965. self.gguf_writer.add_feed_forward_length(4 * n_embd)
  966. self.gguf_writer.add_block_count(block_count)
  967. self.gguf_writer.add_head_count(n_head)
  968. self.gguf_writer.add_head_count_kv(n_head)
  969. self.gguf_writer.add_layer_norm_eps(get_key_opts(self.hparams, ["layer_norm_epsilon", "layer_norm_eps"]))
  970. self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
  971. self.gguf_writer.add_file_type(self.ftype)
  972. self.gguf_writer.add_add_bos_token(False)
  973. class PlamoModel(Model):
  974. def set_vocab(self):
  975. self._set_vocab_sentencepiece()
  976. def set_gguf_parameters(self):
  977. hparams = self.hparams
  978. block_count = hparams["num_hidden_layers"]
  979. self.gguf_writer.add_name("PLaMo")
  980. self.gguf_writer.add_context_length(4096) # not in config.json
  981. self.gguf_writer.add_embedding_length(hparams["hidden_size"])
  982. self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
  983. self.gguf_writer.add_block_count(block_count)
  984. self.gguf_writer.add_head_count(hparams["num_attention_heads"])
  985. self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
  986. self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
  987. def shuffle_attn_q_weight(self, data_torch):
  988. assert data_torch.size() == (5120, 5120)
  989. data_torch = data_torch.reshape(8, 5, 128, 5120)
  990. data_torch = torch.permute(data_torch, (1, 0, 2, 3))
  991. data_torch = torch.reshape(data_torch, (5120, 5120))
  992. return data_torch
  993. def shuffle_attn_output_weight(self, data_torch):
  994. assert data_torch.size() == (5120, 5120)
  995. data_torch = data_torch.reshape(5120, 8, 5, 128)
  996. data_torch = torch.permute(data_torch, (0, 2, 1, 3))
  997. data_torch = torch.reshape(data_torch, (5120, 5120))
  998. return data_torch
  999. def write_tensors(self):
  1000. block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
  1001. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  1002. for name, data_torch in self.get_tensors():
  1003. if "self_attn.rotary_emb.inv_freq" in name:
  1004. continue
  1005. # map tensor names
  1006. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  1007. if new_name is None:
  1008. print(f"Can not map tensor {name!r}")
  1009. sys.exit()
  1010. # shuffle for broadcasting of gqa in ggml_mul_mat
  1011. if new_name.endswith("attn_q.weight"):
  1012. data_torch = self.shuffle_attn_q_weight(data_torch)
  1013. elif new_name.endswith("attn_output.weight"):
  1014. data_torch = self.shuffle_attn_output_weight(data_torch)
  1015. old_dtype = data_torch.dtype
  1016. # convert any unsupported data types to float32
  1017. if data_torch.dtype not in (torch.float16, torch.float32):
  1018. data_torch = data_torch.to(torch.float32)
  1019. data = data_torch.squeeze().numpy()
  1020. n_dims = len(data.shape)
  1021. data_dtype = data.dtype
  1022. # if f32 desired, convert any float16 to float32
  1023. if self.ftype == 0 and data_dtype == np.float16:
  1024. data = data.astype(np.float32)
  1025. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  1026. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  1027. data = data.astype(np.float32)
  1028. # if f16 desired, convert any float32 2-dim weight tensors to float16
  1029. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  1030. data = data.astype(np.float16)
  1031. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  1032. self.gguf_writer.add_tensor(new_name, data)
  1033. class CodeShellModel(Model):
  1034. def set_gguf_parameters(self):
  1035. block_count = self.hparams["n_layer"]
  1036. self.gguf_writer.add_name("CodeShell")
  1037. self.gguf_writer.add_context_length(self.hparams["n_positions"])
  1038. self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
  1039. self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
  1040. self.gguf_writer.add_block_count(block_count)
  1041. self.gguf_writer.add_head_count(self.hparams["n_head"])
  1042. self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"])
  1043. self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
  1044. self.gguf_writer.add_file_type(self.ftype)
  1045. self.gguf_writer.add_rope_freq_base(10000.0)
  1046. self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
  1047. self.gguf_writer.add_rope_scaling_factor(1.0)
  1048. def write_tensors(self):
  1049. block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
  1050. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  1051. tensors = dict(self.get_tensors())
  1052. has_lm_head = "lm_head.weight" in tensors.keys() or "output.weight" in tensors.keys()
  1053. for name, data_torch in tensors.items():
  1054. # we don't need these
  1055. if name.endswith((".attn.rotary_emb.inv_freq")):
  1056. continue
  1057. old_dtype = data_torch.dtype
  1058. # convert any unsupported data types to float32
  1059. if data_torch.dtype not in (torch.float16, torch.float32):
  1060. data_torch = data_torch.to(torch.float32)
  1061. data = data_torch.squeeze().numpy()
  1062. # map tensor names
  1063. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  1064. if new_name is None:
  1065. print(f"Can not map tensor {name!r}")
  1066. sys.exit()
  1067. n_dims = len(data.shape)
  1068. data_dtype = data.dtype
  1069. # if f32 desired, convert any float16 to float32
  1070. if self.ftype == 0 and data_dtype == np.float16:
  1071. data = data.astype(np.float32)
  1072. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  1073. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  1074. data = data.astype(np.float32)
  1075. # if f16 desired, convert any float32 2-dim weight tensors to float16
  1076. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  1077. data = data.astype(np.float16)
  1078. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  1079. self.gguf_writer.add_tensor(new_name, data)
  1080. if not has_lm_head and name == "transformer.wte.weight":
  1081. self.gguf_writer.add_tensor("output.weight", data)
  1082. print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
  1083. ###### CONVERSION LOGIC ######
  1084. def parse_args() -> argparse.Namespace:
  1085. parser = argparse.ArgumentParser(
  1086. description="Convert a huggingface model to a GGML compatible file")
  1087. parser.add_argument(
  1088. "--vocab-only", action="store_true",
  1089. help="extract only the vocab",
  1090. )
  1091. parser.add_argument(
  1092. "--awq-path", type=Path, default=None,
  1093. help="Path to scale awq cache file")
  1094. parser.add_argument(
  1095. "--outfile", type=Path,
  1096. help="path to write to; default: based on input",
  1097. )
  1098. parser.add_argument(
  1099. "--outtype", type=str, choices=["f32", "f16"], default="f16",
  1100. help="output format - use f32 for float32, f16 for float16",
  1101. )
  1102. parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
  1103. parser.add_argument(
  1104. "model", type=Path,
  1105. help="directory containing model file",
  1106. )
  1107. return parser.parse_args()
  1108. def main() -> None:
  1109. args = parse_args()
  1110. dir_model = args.model
  1111. if args.awq_path:
  1112. sys.path.insert(1, str(Path(__file__).parent / 'awq-py'))
  1113. from awq.apply_awq import add_scale_weights # type: ignore[import-not-found]
  1114. tmp_model_path = args.model / "weighted_model"
  1115. dir_model = tmp_model_path
  1116. if tmp_model_path.is_dir():
  1117. print(f"{tmp_model_path} exists as a weighted model.")
  1118. else:
  1119. tmp_model_path.mkdir(parents=True, exist_ok=True)
  1120. print("Saving new weighted model ...")
  1121. add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path))
  1122. print(f"Saved weighted model at {tmp_model_path}.")
  1123. if not dir_model.is_dir():
  1124. print(f'Error: {args.model} is not a directory', file=sys.stderr)
  1125. sys.exit(1)
  1126. ftype_map = {
  1127. "f32": gguf.GGMLQuantizationType.F32,
  1128. "f16": gguf.GGMLQuantizationType.F16,
  1129. }
  1130. if args.outfile is not None:
  1131. fname_out = args.outfile
  1132. else:
  1133. # output in the same directory as the model by default
  1134. fname_out = dir_model / f'ggml-model-{args.outtype}.gguf'
  1135. print(f"Loading model: {dir_model.name}")
  1136. hparams = Model.load_hparams(dir_model)
  1137. with torch.inference_mode():
  1138. model_class = Model.from_model_architecture(hparams["architectures"][0])
  1139. model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian)
  1140. print("Set model parameters")
  1141. model_instance.set_gguf_parameters()
  1142. print("Set model tokenizer")
  1143. model_instance.set_vocab()
  1144. if args.vocab_only:
  1145. print(f"Exporting model vocab to '{fname_out}'")
  1146. model_instance.write_vocab()
  1147. else:
  1148. print(f"Exporting model to '{fname_out}'")
  1149. model_instance.write()
  1150. print(f"Model successfully exported to '{fname_out}'")
  1151. if __name__ == '__main__':
  1152. main()