convert-hf-to-gguf.py

#!/usr/bin/env python3

from __future__ import annotations

import argparse
import contextlib
import json
import os
import re
import sys
from enum import IntEnum
from pathlib import Path
from typing import TYPE_CHECKING, Any, ContextManager, Iterator, Sequence, cast

import numpy as np
import torch

if TYPE_CHECKING:
    from torch import Tensor

if 'NO_LOCAL_GGUF' not in os.environ:
    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf

from convert import HfVocab


###### MODEL DEFINITIONS ######

class SentencePieceTokenTypes(IntEnum):
    NORMAL = 1
    UNKNOWN = 2
    CONTROL = 3
    USER_DEFINED = 4
    UNUSED = 5
    BYTE = 6
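
# NOTE: the SentencePieceTokenTypes values above mirror the GGUF token-type IDs
# (gguf.TokenType), which is why _set_vocab_sentencepiece() below can pass them
# straight to add_token_types().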


class Model:
    def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool):
        self.dir_model = dir_model
        self.ftype = ftype
        self.fname_out = fname_out
        self.is_big_endian = is_big_endian
        self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
        self.is_safetensors = self._is_model_safetensors()
        self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin")
        self.part_names = self._get_part_names()
        self.hparams = Model.load_hparams(self.dir_model)
        self.model_arch = self._get_model_architecture()
        self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=False)
        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])

    def find_hparam(self, keys: Sequence[str], optional: bool = False) -> Any:
        key = next((k for k in keys if k in self.hparams), None)
        if key is not None:
            return self.hparams[key]
        if optional:
            return None
        raise KeyError(f"could not find any of: {keys}")
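
    # find_hparam() lets each architecture list the config.json spellings it accepts for
    # one hyperparameter (different HF configs name the same value differently), e.g.
    # find_hparam(["hidden_size", "n_embd"]) as used in set_gguf_parameters() below.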

    def set_vocab(self):
        self._set_vocab_gpt2()

    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
        for part_name in self.part_names:
            print(f"gguf: loading model part '{part_name}'")
            ctx: ContextManager[Any]
            if self.is_safetensors:
                from safetensors import safe_open
                ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
            else:
                ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))

            with ctx as model_part:
                for name in model_part.keys():
                    data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
                    yield name, data

    def set_gguf_parameters(self):
        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_block_count(self.block_count)

        if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
            self.gguf_writer.add_context_length(n_ctx)

        n_embd = self.find_hparam(["hidden_size", "n_embd"])
        self.gguf_writer.add_embedding_length(n_embd)

        if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
            self.gguf_writer.add_feed_forward_length(n_ff)

        n_head = self.find_hparam(["num_attention_heads", "n_head"])
        self.gguf_writer.add_head_count(n_head)

        if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
            self.gguf_writer.add_head_count_kv(n_head_kv)

        if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
            self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
        if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon"], optional=True)) is not None:
            self.gguf_writer.add_layer_norm_eps(f_norm_eps)
        if (n_experts := self.hparams.get("num_local_experts")) is not None:
            self.gguf_writer.add_expert_count(n_experts)
        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
            self.gguf_writer.add_expert_used_count(n_experts_used)

        self.gguf_writer.add_file_type(self.ftype)
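
    # NOTE: write_tensors() below follows the output-type convention used throughout this
    # script: ftype == 0 keeps all tensors in float32, while ftype == 1 stores 2-D
    # ".weight" tensors as float16 and leaves 1-D tensors (norms, biases) in float32.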

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

    def write(self):
        self.write_tensors()
        self.gguf_writer.write_header_to_file()
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.write_tensors_to_file()
        self.gguf_writer.close()

    def write_vocab(self):
        self.gguf_writer.write_header_to_file()
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.close()

    @staticmethod
    def count_model_parts(dir_model: Path, prefix: str) -> int:
        num_parts = 0
        for filename in os.listdir(dir_model):
            if filename.endswith(prefix):
                num_parts += 1

        return num_parts

    @staticmethod
    def load_hparams(dir_model):
        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def from_model_architecture(model_architecture):
        if model_architecture == "GPTNeoXForCausalLM":
            return GPTNeoXModel
        if model_architecture == "BloomForCausalLM":
            return BloomModel
        if model_architecture == "MPTForCausalLM":
            return MPTModel
        if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
            return BaichuanModel
        if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
            return FalconModel
        if model_architecture == "GPTBigCodeForCausalLM":
            return StarCoderModel
        if model_architecture == "GPTRefactForCausalLM":
            return RefactModel
        if model_architecture == "PersimmonForCausalLM":
            return PersimmonModel
        if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
            return StableLMModel
        if model_architecture == "QWenLMHeadModel":
            return QwenModel
        if model_architecture == "Qwen2ForCausalLM":
            return Model
        if model_architecture == "MixtralForCausalLM":
            return MixtralModel
        if model_architecture == "GPT2LMHeadModel":
            return GPT2Model
        if model_architecture == "PhiForCausalLM":
            return Phi2Model
        if model_architecture == "PlamoForCausalLM":
            return PlamoModel
        if model_architecture == "CodeShellForCausalLM":
            return CodeShellModel
        if model_architecture == "OrionForCausalLM":
            return OrionModel
        if model_architecture == "InternLM2ForCausalLM":
            return InternLM2Model
        if model_architecture == "MiniCPMForCausalLM":
            return MiniCPMModel
        if model_architecture == "BertModel":
            return BertModel
        if model_architecture == "NomicBertModel":
            return NomicBertModel
        if model_architecture == "GemmaForCausalLM":
            return GemmaModel
        return Model

    def _is_model_safetensors(self) -> bool:
        return Model.count_model_parts(self.dir_model, ".safetensors") > 0

    def _get_part_names(self):
        if self.is_safetensors:
            if self.num_parts == 1:  # there's only one .safetensors file
                return ("model.safetensors",)
            return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1))

        if self.num_parts == 1:  # there's only one .bin file
            return ("pytorch_model.bin",)
        return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))

    def _get_model_architecture(self) -> gguf.MODEL_ARCH:
        arch = self.hparams["architectures"][0]
        if arch == "GPTNeoXForCausalLM":
            return gguf.MODEL_ARCH.GPTNEOX
        if arch == "BloomForCausalLM":
            return gguf.MODEL_ARCH.BLOOM
        if arch == "MPTForCausalLM":
            return gguf.MODEL_ARCH.MPT
        if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
            return gguf.MODEL_ARCH.BAICHUAN
        if arch in ("FalconForCausalLM", "RWForCausalLM"):
            return gguf.MODEL_ARCH.FALCON
        if arch == "GPTBigCodeForCausalLM":
            return gguf.MODEL_ARCH.STARCODER
        if arch == "GPTRefactForCausalLM":
            return gguf.MODEL_ARCH.REFACT
        if arch == "PersimmonForCausalLM":
            return gguf.MODEL_ARCH.PERSIMMON
        if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
            return gguf.MODEL_ARCH.STABLELM
        if arch == "QWenLMHeadModel":
            return gguf.MODEL_ARCH.QWEN
        if arch == "Qwen2ForCausalLM":
            return gguf.MODEL_ARCH.QWEN2
        if arch == "MixtralForCausalLM":
            return gguf.MODEL_ARCH.LLAMA
        if arch == "GPT2LMHeadModel":
            return gguf.MODEL_ARCH.GPT2
        if arch == "PhiForCausalLM":
            return gguf.MODEL_ARCH.PHI2
        if arch == "PlamoForCausalLM":
            return gguf.MODEL_ARCH.PLAMO
        if arch == "CodeShellForCausalLM":
            return gguf.MODEL_ARCH.CODESHELL
        if arch == "OrionForCausalLM":
            return gguf.MODEL_ARCH.ORION
        if arch == "InternLM2ForCausalLM":
            return gguf.MODEL_ARCH.INTERNLM2
        if arch == "MiniCPMForCausalLM":
            return gguf.MODEL_ARCH.MINICPM
        if arch == "BertModel":
            return gguf.MODEL_ARCH.BERT
        if arch == "NomicBertModel":
            return gguf.MODEL_ARCH.NOMIC_BERT
        if arch == "GemmaForCausalLM":
            return gguf.MODEL_ARCH.GEMMA

        raise NotImplementedError(f'Architecture "{arch}" not supported!')
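
    # The _set_vocab_* helpers below are shared by the subclasses' set_vocab()
    # implementations: _set_vocab_gpt2() loads a BPE vocab via AutoTokenizer,
    # _set_vocab_qwen() rebuilds merges from Qwen's tiktoken-style mergeable_ranks,
    # _set_vocab_sentencepiece() reads tokenizer.model, and _set_vocab_hf() goes
    # through convert.HfVocab.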

    def _set_vocab_gpt2(self):
        dir_model = self.dir_model
        hparams = self.hparams
        tokens: list[bytearray] = []
        toktypes: list[int] = []

        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(dir_model)
        vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
        assert max(tokenizer.vocab.values()) < vocab_size

        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
        added_vocab = tokenizer.get_added_vocab()

        for i in range(vocab_size):
            if i not in reverse_vocab:
                pad_token = f"[PAD{i}]".encode('utf-8')
                tokens.append(bytearray(pad_token))
                toktypes.append(gguf.TokenType.USER_DEFINED)
            elif reverse_vocab[i] in added_vocab:
                tokens.append(reverse_vocab[i])
                if tokenizer.added_tokens_decoder[i].special:
                    toktypes.append(gguf.TokenType.CONTROL)
                else:
                    toktypes.append(gguf.TokenType.USER_DEFINED)
            else:
                tokens.append(reverse_vocab[i])
                toktypes.append(gguf.TokenType.NORMAL)

        self.gguf_writer.add_tokenizer_model("gpt2")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
        special_vocab.add_to_gguf(self.gguf_writer)

    def _set_vocab_qwen(self):
        dir_model = self.dir_model
        hparams = self.hparams
        tokens: list[bytearray] = []
        toktypes: list[int] = []

        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
        vocab_size = hparams["vocab_size"]
        assert max(tokenizer.get_vocab().values()) < vocab_size

        merges = []
        vocab = {}
        mergeable_ranks = tokenizer.mergeable_ranks
        for token, rank in mergeable_ranks.items():
            vocab[QwenModel.token_bytes_to_string(token)] = rank
            if len(token) == 1:
                continue
            merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
            assert len(merged) == 2
            merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))

        # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
        added_vocab = tokenizer.special_tokens
        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in (vocab | added_vocab).items()}

        for i in range(vocab_size):
            if i not in reverse_vocab:
                pad_token = f"[PAD{i}]".encode("utf-8")
                tokens.append(bytearray(pad_token))
                toktypes.append(gguf.TokenType.USER_DEFINED)
            elif reverse_vocab[i] in added_vocab:
                tokens.append(reverse_vocab[i])
                toktypes.append(gguf.TokenType.CONTROL)
            else:
                tokens.append(reverse_vocab[i])
                toktypes.append(gguf.TokenType.NORMAL)

        self.gguf_writer.add_tokenizer_model("gpt2")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
        special_vocab.merges = merges
        # only add special tokens when they were not already loaded from config.json
        if len(special_vocab.special_token_ids) == 0:
            special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"])
            special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"])
        # this one is usually not in config.json anyway
        special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"])
        special_vocab.add_to_gguf(self.gguf_writer)

    def _set_vocab_sentencepiece(self):
        from sentencepiece import SentencePieceProcessor

        tokenizer_path = self.dir_model / 'tokenizer.model'

        tokens: list[bytes] = []
        scores: list[float] = []
        toktypes: list[int] = []

        if not tokenizer_path.is_file():
            print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
            sys.exit(1)

        tokenizer = SentencePieceProcessor(str(tokenizer_path))
        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

        for token_id in range(vocab_size):
            piece = tokenizer.id_to_piece(token_id)
            text = piece.encode("utf-8")
            score = tokenizer.get_score(token_id)

            toktype = SentencePieceTokenTypes.NORMAL
            if tokenizer.is_unknown(token_id):
                toktype = SentencePieceTokenTypes.UNKNOWN
            elif tokenizer.is_control(token_id):
                toktype = SentencePieceTokenTypes.CONTROL
            elif tokenizer.is_unused(token_id):
                toktype = SentencePieceTokenTypes.UNUSED
            elif tokenizer.is_byte(token_id):
                toktype = SentencePieceTokenTypes.BYTE

            tokens.append(text)
            scores.append(score)
            toktypes.append(toktype)

        added_tokens_file = self.dir_model / 'added_tokens.json'
        if added_tokens_file.is_file():
            with open(added_tokens_file, "r", encoding="utf-8") as f:
                added_tokens_json = json.load(f)

                for key in added_tokens_json:
                    tokens.append(key.encode("utf-8"))
                    scores.append(-1000.0)
                    toktypes.append(SentencePieceTokenTypes.USER_DEFINED)

        self.gguf_writer.add_tokenizer_model("llama")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        special_vocab.add_to_gguf(self.gguf_writer)

    def _set_vocab_hf(self):
        path = self.dir_model
        added_tokens_path = self.dir_model
        vocab = HfVocab(
            path, added_tokens_path if added_tokens_path.exists() else None
        )
        tokens = []
        scores = []
        toktypes = []

        for text, score, toktype in vocab.all_tokens():
            tokens.append(text)
            scores.append(score)
            toktypes.append(toktype)

        assert len(tokens) == vocab.vocab_size

        self.gguf_writer.add_tokenizer_model("llama")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        special_vocab.add_to_gguf(self.gguf_writer)


class GPTNeoXModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["num_hidden_layers"]

        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(
            int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
        )
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])


class BloomModel(Model):
    def set_gguf_parameters(self):
        self.gguf_writer.add_name("Bloom")
        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
        self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
        self.gguf_writer.add_embedding_length(n_embed)
        self.gguf_writer.add_feed_forward_length(4 * n_embed)
        self.gguf_writer.add_block_count(self.hparams["n_layer"])
        self.gguf_writer.add_head_count(n_head)
        self.gguf_writer.add_head_count_kv(n_head)
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        block_count = self.hparams["n_layer"]
        tensors = dict(self.get_tensors())
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        has_lm_head = True
        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))

        for name, data_torch in tensors.items():
            if "lm_head.weight" not in tensors.keys() and "output.weight" not in tensors.keys():
                has_lm_head = False

            name = re.sub(r'transformer\.', '', name)

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name):
                # Map bloom-style qkv_linear to gpt-style qkv_linear
                # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252  # noqa
                # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312  # noqa
                qkv_weights = data.reshape((n_head, 3, n_embed // n_head, n_embed))
                data = np.concatenate(
                    (
                        qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
                        qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
                        qkv_weights[:, 2, :, :].reshape((-1, n_embed)),
                    ),
                    axis=0,
                )
                print("re-format attention.linear_qkv.weight")
            elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name):
                qkv_bias = data.reshape((n_head, 3, n_embed // n_head))
                data = np.concatenate(
                    (
                        qkv_bias[:, 0, :].reshape((n_embed,)),
                        qkv_bias[:, 1, :].reshape((n_embed,)),
                        qkv_bias[:, 2, :].reshape((n_embed,)),
                    ),
                    axis=0,
                )
                print("re-format attention.linear_qkv.bias")

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"=> {new_name}, shape = {data.shape}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

            if not has_lm_head and name == "word_embeddings.weight":
                self.gguf_writer.add_tensor("output.weight", data)
                print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")


class MPTModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["n_layers"]
        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
        self.gguf_writer.add_embedding_length(self.hparams["d_model"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"])
        self.gguf_writer.add_head_count(self.hparams["n_heads"])
        if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"):
            self.gguf_writer.add_head_count_kv(kv_n_heads)
        self.gguf_writer.add_layer_norm_eps(1e-5)
        if self.hparams["attn_config"]["clip_qkv"] is not None:
            self.gguf_writer.add_clamp_kqv(self.hparams["attn_config"]["clip_qkv"])
        self.gguf_writer.add_max_alibi_bias(self.hparams["attn_config"]["alibi_bias_max"])

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers"))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            if "scales" in name:
                new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias", ".scales"))
                if new_name is not None:
                    new_name = new_name.replace("scales", "act.scales")
            else:
                new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


class OrionModel(Model):
    def set_vocab(self):
        self._set_vocab_sentencepiece()

    def set_gguf_parameters(self):
        block_count = self.hparams["num_hidden_layers"]
        head_count = self.hparams["num_attention_heads"]
        head_count_kv = self.hparams.get("num_key_value_heads", head_count)
        hf_repo = self.hparams.get("_name_or_path", "")

        ctx_length = 0
        if "max_sequence_length" in self.hparams:
            ctx_length = self.hparams["max_sequence_length"]
        elif "max_position_embeddings" in self.hparams:
            ctx_length = self.hparams["max_position_embeddings"]
        elif "model_max_length" in self.hparams:
            ctx_length = self.hparams["model_max_length"]
        else:
            print("gguf: can not find ctx length parameter.")
            sys.exit()

        self.gguf_writer.add_file_type(self.ftype)
        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_source_hf_repo(hf_repo)
        self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
        self.gguf_writer.add_context_length(ctx_length)
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_head_count(head_count)
        self.gguf_writer.add_head_count_kv(head_count_kv)
        # note: config provides rms norm but it is actually layer norm
        # ref: https://huggingface.co/OrionStarAI/Orion-14B-Chat/blob/276a17221ce42beb45f66fac657a41540e71f4f5/modeling_orion.py#L570-L571
        self.gguf_writer.add_layer_norm_eps(self.hparams["rms_norm_eps"])

    def write_tensors(self):
        # Collect tensors from generator object
        model_kv = dict(self.get_tensors())
        block_count = self.hparams["num_hidden_layers"]
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in model_kv.items():
            # we don't need these
            if name.endswith(".rotary_emb.inv_freq"):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
            self.gguf_writer.add_tensor(new_name, data)


class BaichuanModel(Model):
    def set_vocab(self):
        self._set_vocab_sentencepiece()

    def set_gguf_parameters(self):
        block_count = self.hparams["num_hidden_layers"]
        head_count = self.hparams["num_attention_heads"]
        head_count_kv = self.hparams.get("num_key_value_heads", head_count)
        hf_repo = self.hparams.get("_name_or_path", "")

        ctx_length = 0
        if "max_sequence_length" in self.hparams:
            ctx_length = self.hparams["max_sequence_length"]
        elif "max_position_embeddings" in self.hparams:
            ctx_length = self.hparams["max_position_embeddings"]
        elif "model_max_length" in self.hparams:
            ctx_length = self.hparams["model_max_length"]
        else:
            print("gguf: can not find ctx length parameter.")
            sys.exit()

        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_source_hf_repo(hf_repo)
        self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
        self.gguf_writer.add_context_length(ctx_length)
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count(head_count)
        self.gguf_writer.add_head_count_kv(head_count_kv)
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])

        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
            if self.hparams["rope_scaling"].get("type") == "linear":
                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])

    def write_tensors(self):
        # Collect tensors from generator object
        model_kv = dict(self.get_tensors())
        block_count = self.hparams["num_hidden_layers"]
        head_count = self.hparams["num_attention_heads"]
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        head_count_kv = self.hparams.get("num_key_value_heads", head_count)

        for i in range(block_count):
            if (w := model_kv.get(f"model.layers.{i}.self_attn.W_pack.weight")) is not None:
                print(f"Unpacking and permuting layer {i}")
                model_kv[f"model.layers.{i}.self_attn.q_proj.weight"] = \
                    self._reverse_hf_permute_part(w, 0, head_count, head_count)
                model_kv[f"model.layers.{i}.self_attn.k_proj.weight"] = \
                    self._reverse_hf_permute_part(w, 1, head_count, head_count_kv)
                model_kv[f"model.layers.{i}.self_attn.v_proj.weight"] = \
                    self._reverse_hf_part(w, 2)
                del model_kv[f"model.layers.{i}.self_attn.W_pack.weight"]

        for name, data_torch in model_kv.items():
            # we don't need these
            if name.endswith(".rotary_emb.inv_freq"):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
            self.gguf_writer.add_tensor(new_name, data)
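
    # The helpers below undo the rotary head permutation that the HF checkpoint applies to
    # llama-style q/k projection weights and slice one third (q, k or v) out of the fused
    # W_pack tensor, so the pieces can be written as separate q/k/v projections.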

    def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
        if n_kv_head is not None and n_head != n_kv_head:
            n_head //= n_kv_head

        return (
            weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape)
        )

    def _reverse_hf_permute_part(
        self, weights: Tensor, n_part: int, n_head: int, n_head_kv: int | None = None,
    ) -> Tensor:
        r = weights.shape[0] // 3
        return self._reverse_hf_permute(weights[r * n_part:r * n_part + r, ...], n_head, n_head_kv)

    def _reverse_hf_part(self, weights: Tensor, n_part: int) -> Tensor:
        r = weights.shape[0] // 3
        return weights[r * n_part:r * n_part + r, ...]


class FalconModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams.get("num_hidden_layers")
        if block_count is None:
            block_count = self.hparams["n_layer"]  # old name

        n_head = self.hparams.get("num_attention_heads")
        if n_head is None:
            n_head = self.hparams["n_head"]  # old name

        n_head_kv = self.hparams.get("num_kv_heads")
        if n_head_kv is None:
            n_head_kv = self.hparams.get("n_head_kv", 1)  # old name

        self.gguf_writer.add_name("Falcon")
        self.gguf_writer.add_context_length(2048)  # not in config.json
        self.gguf_writer.add_tensor_data_layout("jploski")  # qkv tensor transform
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(n_head)
        self.gguf_writer.add_head_count_kv(n_head_kv)
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        block_count = self.hparams.get("num_hidden_layers")
        if block_count is None:
            block_count = self.hparams["n_layer"]  # old name

        n_head = self.hparams.get("num_attention_heads")
        if n_head is None:
            n_head = self.hparams["n_head"]  # old name

        n_head_kv = self.hparams.get("num_kv_heads")
        if n_head_kv is None:
            n_head_kv = self.hparams.get("n_head_kv", 1)  # old name

        head_dim = self.hparams["hidden_size"] // n_head
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in self.get_tensors():
            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            # QKV tensor transform
            # The original query_key_value tensor contains n_head_kv "kv groups",
            # each consisting of n_head/n_head_kv query weights followed by one key
            # and one value weight (shared by all query heads in the kv group).
            # This layout makes it a big pain to work with in GGML.
            # So we rearrange them here, so that we have n_head query weights
            # followed by n_head_kv key weights followed by n_head_kv value weights,
            # in contiguous fashion.
            # ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py
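            #
            # Illustrative shapes (made-up numbers, not from any particular checkpoint):
            # with n_head = 8, n_head_kv = 2 and head_dim = 64 (hidden_size = 512), the
            # fused weight is viewed as (2, 8/2 + 2, 64, 512); q takes the first 4 rows of
            # each kv group -> (512, 512), k and v take one row each -> (128, 512), and
            # the concatenation is reshaped back to the original (768, 512) layout.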
            if "query_key_value" in name:
                qkv = data_torch.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
                q = qkv[:, :-2].reshape(n_head * head_dim, head_dim * n_head)
                k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
                v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
                data_torch = torch.cat((q, k, v)).reshape_as(data_torch)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


class StarCoderModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["n_layer"]

        self.gguf_writer.add_name("StarCoder")
        self.gguf_writer.add_context_length(self.hparams["n_positions"])
        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(self.hparams["n_head"])
        self.gguf_writer.add_head_count_kv(1)
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)


class RefactModel(Model):
    def set_gguf_parameters(self):
        hidden_dim = self.hparams["n_embd"]
        inner_dim = 4 * hidden_dim
        hidden_dim = int(2 * inner_dim / 3)
        multiple_of = 256
        ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        block_count = self.hparams["n_layer"]

        self.gguf_writer.add_name("Refact")
        # Refact uses ALiBi, so n_positions here comes from config.json and likely only reflects the training context length
        self.gguf_writer.add_context_length(self.hparams["n_positions"])
        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
        self.gguf_writer.add_feed_forward_length(ff_dim)
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(self.hparams["n_head"])
        self.gguf_writer.add_head_count_kv(1)
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        hidden_dim = self.hparams["n_embd"]
        inner_dim = 4 * hidden_dim
        hidden_dim = int(2 * inner_dim / 3)
        multiple_of = 256
        ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        n_head = self.hparams["n_head"]
        n_head_kv = 1
        head_dim = self.hparams["n_embd"] // n_head
        block_count = self.hparams["n_layer"]

        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        tensors = dict(self.get_tensors())
        for i in range(block_count):
            if (w := tensors.get(f"transformer.h.{i}.attn.kv.weight")) is not None:
                tensors[f"model.layers.{i}.self_attn.k_proj.weight"] = w[:n_head_kv * head_dim]
                tensors[f"model.layers.{i}.self_attn.v_proj.weight"] = w[n_head_kv * head_dim:]
                del tensors[f"transformer.h.{i}.attn.kv.weight"]
            if (w := tensors.get(f"transformer.h.{i}.attn.q.weight")) is not None:
                tensors[f"model.layers.{i}.self_attn.q_proj.weight"] = w
                del tensors[f"transformer.h.{i}.attn.q.weight"]
            if (w := tensors.get(f"transformer.h.{i}.mlp.gate_up_proj.weight")) is not None:
                tensors[f"model.layers.{i}.mlp.gate_proj.weight"] = w[:ff_dim]
                tensors[f"model.layers.{i}.mlp.up_proj.weight"] = w[ff_dim:]
                del tensors[f"transformer.h.{i}.mlp.gate_up_proj.weight"]

        for name, data_torch in tensors.items():
            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight",))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


class PersimmonModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
        head_count = self.hparams["num_attention_heads"]
        head_count_kv = head_count
        hidden_size = self.hparams["hidden_size"]

        self.gguf_writer.add_name('persimmon-8b-chat')
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(hidden_size)
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])

        # NOTE: not sure about this change - why does the model not have a rope dimension count when it is smaller
        # than the head size?
        # ref: https://github.com/ggerganov/llama.cpp/pull/4889
        # self.gguf_writer.add_rope_dimension_count(hidden_size // head_count)
        self.gguf_writer.add_rope_dimension_count(hidden_size // head_count // 2)

        self.gguf_writer.add_head_count(head_count)
        self.gguf_writer.add_head_count_kv(head_count_kv)
        self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])

    def set_vocab(self):
        self._set_vocab_sentencepiece()
        # self.gguf_writer.add_bos_token_id(71013)
        # self.gguf_writer.add_eos_token_id(71013)

    def write_tensors(self):
        block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in self.get_tensors():
            if name.endswith(".self_attention.rotary_emb.inv_freq"):
                continue
            old_dtype = data_torch.dtype
            # TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?)
            data = data_torch.to(torch.float32).squeeze().numpy()
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()
            n_dims = len(data.shape)
            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
            self.gguf_writer.add_tensor(new_name, data)


class StableLMModel(Model):
    def set_vocab(self):
        if (self.dir_model / "tokenizer.json").is_file():
            self._set_vocab_gpt2()
        else:
            # StableLM 2 1.6B uses a vocab in a similar format to Qwen's vocab
            self._set_vocab_qwen()

    def set_gguf_parameters(self):
        hparams = self.hparams
        block_count = hparams["num_hidden_layers"]

        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(int(hparams["rope_pct"] * (hparams["hidden_size"] // hparams["num_attention_heads"])))
        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
        self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
        self.gguf_writer.add_layer_norm_eps(1e-5)


class MixtralModel(Model):
    def set_vocab(self):
        self._set_vocab_sentencepiece()


class MiniCPMModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["num_hidden_layers"]
        self.gguf_writer.add_name("MiniCPM")
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
        self.gguf_writer.add_file_type(self.ftype)

    def set_vocab(self):
        self._set_vocab_hf()

    def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
        if n_kv_head is not None and n_head != n_kv_head:
            n_head //= n_kv_head

        return (
            weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape)
        )

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        n_head = self.hparams.get("num_attention_heads")
        n_kv_head = self.hparams.get("num_key_value_heads")
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            # HF models permute some of the tensors, so we need to undo that
            if name.endswith(("q_proj.weight")):
                data_torch = self._reverse_hf_permute(data_torch, n_head, n_head)
            if name.endswith(("k_proj.weight")):
                data_torch = self._reverse_hf_permute(data_torch, n_head, n_kv_head)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


class QwenModel(Model):
    @staticmethod
    def token_bytes_to_string(b):
        from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
        byte_encoder = bytes_to_unicode()
        return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
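
    # token_bytes_to_string() maps raw token bytes to GPT-2's printable byte alphabet,
    # e.g. the space byte b" " becomes "Ġ"; this keeps the merges written out by
    # _set_vocab_qwen() representable as plain strings.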

    @staticmethod
    def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
        parts = [bytes([b]) for b in token]
        while True:
            min_idx = None
            min_rank = None
            for i, pair in enumerate(zip(parts[:-1], parts[1:])):
                rank = mergeable_ranks.get(pair[0] + pair[1])
                if rank is not None and (min_rank is None or rank < min_rank):
                    min_idx = i
                    min_rank = rank
            if min_rank is None or (max_rank is not None and min_rank >= max_rank):
                break
            assert min_idx is not None
            parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
        return parts
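
    # bpe() replays the BPE merges whose rank is below max_rank, which is how
    # _set_vocab_qwen() recovers the two pieces that merge into each ranked token.
    # A made-up example:
    #   ranks = {b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"abc": 4}
    #   QwenModel.bpe(ranks, b"abc", max_rank=4)  ->  [b"ab", b"c"]
    # i.e. the token of rank 4 is produced by merging b"ab" and b"c".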

    def set_vocab(self):
        self._set_vocab_qwen()

    def set_gguf_parameters(self):
        self.gguf_writer.add_name("Qwen")
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])

    def write_tensors(self):
        block_count = self.hparams["num_hidden_layers"]
        model_kv = dict(self.get_tensors())
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        for name, data_torch in model_kv.items():
            # we don't need these
            if name.endswith(".rotary_emb.inv_freq"):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
            self.gguf_writer.add_tensor(new_name, data)


class GPT2Model(Model):
    def set_gguf_parameters(self):
        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_block_count(self.hparams["n_layer"])
        self.gguf_writer.add_context_length(self.hparams["n_ctx"])
        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
        self.gguf_writer.add_head_count(self.hparams["n_head"])
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)
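
    # Note (added for clarity): HF GPT-2 checkpoints use Conv1D modules, which store their
    # weights transposed relative to a regular nn.Linear, so write_tensors() below
    # transposes the c_attn/c_proj/c_fc weights back before exporting them.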

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq", ".attn.bias", ".attn.masked_bias")):
                continue

            if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight")):
                data_torch = data_torch.transpose(1, 0)

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

            # note: GPT2 output is tied to (same as) wte in original model
            if new_name == "token_embd.weight":
                print(f"output.weight, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
                self.gguf_writer.add_tensor("output.weight", data)


class Phi2Model(Model):
    def set_gguf_parameters(self):
        block_count = self.find_hparam(["num_hidden_layers", "n_layer"])

        rot_pct = self.find_hparam(["partial_rotary_factor"])
        n_embd = self.find_hparam(["hidden_size", "n_embd"])
        n_head = self.find_hparam(["num_attention_heads", "n_head"])

        self.gguf_writer.add_name("Phi2")
        self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"]))

        self.gguf_writer.add_embedding_length(n_embd)
        self.gguf_writer.add_feed_forward_length(4 * n_embd)
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(n_head)
        self.gguf_writer.add_head_count_kv(n_head)
        self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_epsilon", "layer_norm_eps"]))
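
        # Note (added for clarity): only a fraction of each head's dimensions is rotated
        # (partial RoPE). For example, with n_embd = 2560, n_head = 32 and a partial
        # rotary factor of 0.4 (illustrative values), int(0.4 * 2560) // 32 = 32 rotary dims.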
        self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)

        self.gguf_writer.add_file_type(self.ftype)
        self.gguf_writer.add_add_bos_token(False)


class PlamoModel(Model):
    def set_vocab(self):
        self._set_vocab_sentencepiece()

    def set_gguf_parameters(self):
        hparams = self.hparams
        block_count = hparams["num_hidden_layers"]

        self.gguf_writer.add_name("PLaMo")
        self.gguf_writer.add_context_length(4096)  # not in config.json
        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
        self.gguf_writer.add_head_count_kv(5)  # hparams["num_key_value_heads"]) is wrong
        self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
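
    # Note (added for clarity): the two helpers below reorder the rows/columns of the
    # 5120x5120 attention projections, regrouping the 40 head blocks (viewed as 8 x 5
    # blocks of 128 dims each) so that grouped-query attention broadcasts correctly via
    # ggml_mul_mat in llama.cpp (see the "shuffle for broadcasting of gqa" comment below).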

    def shuffle_attn_q_weight(self, data_torch):
        assert data_torch.size() == (5120, 5120)
        data_torch = data_torch.reshape(8, 5, 128, 5120)
        data_torch = torch.permute(data_torch, (1, 0, 2, 3))
        data_torch = torch.reshape(data_torch, (5120, 5120))
        return data_torch

    def shuffle_attn_output_weight(self, data_torch):
        assert data_torch.size() == (5120, 5120)
        data_torch = data_torch.reshape(5120, 8, 5, 128)
        data_torch = torch.permute(data_torch, (0, 2, 1, 3))
        data_torch = torch.reshape(data_torch, (5120, 5120))
        return data_torch

    def write_tensors(self):
        block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in self.get_tensors():
            if "self_attn.rotary_emb.inv_freq" in name:
                continue

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            # shuffle for broadcasting of gqa in ggml_mul_mat
            if new_name.endswith("attn_q.weight"):
                data_torch = self.shuffle_attn_q_weight(data_torch)
            elif new_name.endswith("attn_output.weight"):
                data_torch = self.shuffle_attn_output_weight(data_torch)

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


class CodeShellModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["n_layer"]

        self.gguf_writer.add_name("CodeShell")
        self.gguf_writer.add_context_length(self.hparams["n_positions"])
        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(self.hparams["n_head"])
        self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"])
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)
        self.gguf_writer.add_rope_freq_base(10000.0)
        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
        self.gguf_writer.add_rope_scaling_factor(1.0)

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        tensors = dict(self.get_tensors())
        has_lm_head = "lm_head.weight" in tensors.keys() or "output.weight" in tensors.keys()
        for name, data_torch in tensors.items():
            # we don't need these
            if name.endswith((".attn.rotary_emb.inv_freq")):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

            if not has_lm_head and name == "transformer.wte.weight":
                self.gguf_writer.add_tensor("output.weight", data)
                print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")


class InternLM2Model(Model):
    def set_vocab(self):
        # (TODO): Is there a better way?
        # Copied from _set_vocab_sentencepiece; the only difference is that we treat the
        # character \x00 specially and convert it into an emoji character, to prevent it
        # from being mistakenly recognized as an empty string in C++.
        from sentencepiece import SentencePieceProcessor
        from sentencepiece import sentencepiece_model_pb2 as model

        tokenizer_path = self.dir_model / 'tokenizer.model'

        tokens: list[bytes] = []
        scores: list[float] = []
        toktypes: list[int] = []

        if not tokenizer_path.is_file():
            print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
            sys.exit(1)

        sentencepiece_model = model.ModelProto()
        sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
        add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix

        tokenizer = SentencePieceProcessor(str(tokenizer_path))
        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

        for token_id in range(vocab_size):
            piece = tokenizer.id_to_piece(token_id)
            text = piece.encode("utf-8")
            score = tokenizer.get_score(token_id)
            if text == b"\x00":
                # (TODO): fixme
                # Hack here and replace the \x00 characters.
                print(f"InternLM2 convert token '{text}' to '🐉'!")
                text = "🐉"

            toktype = SentencePieceTokenTypes.NORMAL
            if tokenizer.is_unknown(token_id):
                toktype = SentencePieceTokenTypes.UNKNOWN
            elif tokenizer.is_control(token_id):
                toktype = SentencePieceTokenTypes.CONTROL
            elif tokenizer.is_unused(token_id):
                toktype = SentencePieceTokenTypes.UNUSED
            elif tokenizer.is_byte(token_id):
                toktype = SentencePieceTokenTypes.BYTE

            tokens.append(text)
            scores.append(score)
            toktypes.append(toktype)

        added_tokens_file = self.dir_model / 'added_tokens.json'
        if added_tokens_file.is_file():
            with open(added_tokens_file, "r", encoding="utf-8") as f:
                added_tokens_json = json.load(f)

                for key in added_tokens_json:
                    tokens.append(key.encode("utf-8"))
                    scores.append(-1000.0)
                    toktypes.append(SentencePieceTokenTypes.USER_DEFINED)

        self.gguf_writer.add_tokenizer_model("llama")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)
        self.gguf_writer.add_add_space_prefix(add_prefix)

        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        old_eos = special_vocab.special_token_ids["eos"]
        if "chat" in os.path.basename(self.dir_model.absolute()):
            # For the chat model, we replace the eos with '<|im_end|>'.
            special_vocab.special_token_ids["eos"] = self._try_get_sft_eos(tokenizer)
            print(f"Replace eos:{old_eos} with a special token:{special_vocab.special_token_ids['eos']} \
in chat mode so that the conversation can end normally.")

        special_vocab.add_to_gguf(self.gguf_writer)
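
    # Note (added for clarity): chat/SFT checkpoints may name the end-of-turn token either
    # '[UNUSED_TOKEN_145]' or '<|im_end|>'. _try_get_sft_eos() encodes both candidates,
    # asserts that exactly one of them maps to a single token id, and returns that id as eos.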

    def _try_get_sft_eos(self, tokenizer):
        unused_145_list = tokenizer.encode('[UNUSED_TOKEN_145]')
        im_end_list = tokenizer.encode('<|im_end|>')
        assert (len(unused_145_list) == 1) ^ (len(im_end_list) == 1)
        if len(unused_145_list) == 1:
            eos_token = unused_145_list[0]
        if len(im_end_list) == 1:
            eos_token = im_end_list[0]
        return eos_token
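
    # Note (added for clarity): _hf_permute_qk() below reorders the rows of the Q/K
    # projection weights to convert between the Hugging Face rotary-embedding layout and
    # the layout llama.cpp expects, analogous to the permutation applied to LLaMA-style
    # checkpoints elsewhere in this file.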

    def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int):
        if n_head_kv is not None and n_head != n_head_kv:
            n_head = n_head_kv
        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
                .swapaxes(1, 2)
                .reshape(weights.shape))

    def set_gguf_parameters(self):
        self.gguf_writer.add_name("InternLM2")
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])

    def post_write_tensors(self, tensor_map, name, data_torch):
        old_dtype = data_torch.dtype

        # convert any unsupported data types to float32
        if data_torch.dtype not in (torch.float16, torch.float32):
            data_torch = data_torch.to(torch.float32)

        data = data_torch.squeeze().numpy()

        # map tensor names
        new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
        if new_name is None:
            print(f"Can not map tensor {name!r}")
            sys.exit()

        n_dims = len(data.shape)
        data_dtype = data.dtype

        # if f32 desired, convert any float16 to float32
        if self.ftype == 0 and data_dtype == np.float16:
            data = data.astype(np.float32)

        # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
        if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
            data = data.astype(np.float32)

        # if f16 desired, convert any float32 2-dim weight tensors to float16
        if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
            data = data.astype(np.float16)

        print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
        self.gguf_writer.add_tensor(new_name, data)

    def write_tensors(self):
        from einops import rearrange

        num_heads = self.hparams.get("num_attention_heads")
        num_kv_heads = self.hparams.get("num_key_value_heads")
        hidden_size = self.hparams.get("hidden_size")
        q_per_kv = num_heads // num_kv_heads
        head_dim = hidden_size // num_heads
        num_groups = num_heads // q_per_kv

        block_count = self.hparams["num_hidden_layers"]
        model_kv = dict(self.get_tensors())
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
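
        # Note (added for clarity): InternLM2 stores Q, K and V fused in a single wqkv
        # tensor. In the loop below, each wqkv tensor is reshaped into num_groups groups of
        # (q_per_kv query heads + 1 key head + 1 value head), split into separate wq/wk/wv
        # tensors, and the Q/K parts are permuted for llama.cpp's RoPE layout.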

        for name, data_torch in model_kv.items():
            # we don't need these
            if name.endswith(".rotary_emb.inv_freq"):
                continue

            if re.match(qkv_pattern, name):
                bid = re.findall(qkv_pattern, name)[0]
                qkv = data_torch
                qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
                q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
                # The model weights of q and k require additional reshape.
                q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
                k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
                v = rearrange(v, " o g n i -> o (g n i)").T
                self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wq.weight", q)
                self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wk.weight", k)
                self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wv.weight", v)
            else:
                self.post_write_tensors(tensor_map, name, data_torch)


class BertModel(Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.vocab_size = None

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        self.gguf_writer.add_causal_attention(False)

        # get pooling path
        with open(self.dir_model / "modules.json", encoding="utf-8") as f:
            modules = json.load(f)
        pooling_path = None
        for mod in modules:
            if mod["type"] == "sentence_transformers.models.Pooling":
                pooling_path = mod["path"]
                break

        # get pooling type
        pooling_type = gguf.PoolingType.NONE
        if pooling_path is not None:
            with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
                pooling = json.load(f)
            if pooling["pooling_mode_mean_tokens"]:
                pooling_type = gguf.PoolingType.MEAN
            elif pooling["pooling_mode_cls_token"]:
                pooling_type = gguf.PoolingType.CLS
            else:
                raise NotImplementedError("Only MEAN and CLS pooling types supported")

        self.gguf_writer.add_pooling_type(pooling_type.value)

    def set_vocab(self):
        path = self.dir_model
        added_tokens_path = self.dir_model if self.dir_model.exists() else None

        # use huggingface vocab to get all tokens
        vocab = HfVocab(path, added_tokens_path)
        tokens, scores, toktypes = zip(*vocab.all_tokens())
        assert len(tokens) == vocab.vocab_size
        self.vocab_size = vocab.vocab_size

        # we need this to validate the size of the token_type embeddings
        # though currently we are passing all zeros to the token_type embeddings
        n_token_types = len(set(toktypes))
        self.gguf_writer.add_token_type_count(n_token_types)
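
        # Note (added for clarity): BERT's WordPiece vocab marks continuation pieces with a
        # "##" prefix and leaves word-initial pieces unmarked, whereas llama.cpp expects the
        # SentencePiece-style metaspace convention. The phantom() helper below therefore
        # strips "##", prepends U+2581 ("\xe2\x96\x81" in UTF-8) to word-initial tokens, and
        # leaves bracketed special tokens like [CLS] / [SEP] untouched.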

        # convert to phantom space vocab
        def phantom(tok, typ):
            if tok.startswith(b"[") and tok.endswith(b"]"):
                return tok
            if tok.startswith(b"##"):
                return tok[2:]
            return b"\xe2\x96\x81" + tok
        tokens = tuple(phantom(t, y) for t, y in zip(tokens, toktypes))

        # set up bos and eos tokens (cls and sep)
        self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id)
        self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id)

        # add vocab to gguf
        self.gguf_writer.add_tokenizer_model("bert")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)

        # handle special tokens
        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        special_vocab.add_to_gguf(self.gguf_writer)

    def write_tensors(self):
        tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
        tensors = dict(self.get_tensors())
        for name, data_torch in tensors.items():
            # we are only using BERT for embeddings so we don't need the pooling layer
            if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
                continue  # we don't need these

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            data = data_torch.squeeze().numpy()
            n_dims = len(data.shape)
            new_dtype: type[np.floating[Any]]

            if (
                self.ftype == 1 and name.endswith(".weight") and n_dims == 2
                and name != "embeddings.token_type_embeddings.weight"  # not used with get_rows, must be F32
            ):
                # if f16 desired, convert any float32 2-dim weight tensors to float16
                new_dtype = np.float16
            else:
                # if f32 desired, convert any float16 to float32
                new_dtype = np.float32

            print(f"{new_name}, n_dims = {n_dims}, {data_torch.dtype} --> {new_dtype}")

            if data.dtype != new_dtype:
                data = data.astype(new_dtype)

            self.gguf_writer.add_tensor(new_name, data)


class NomicBertModel(BertModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # the HF config claims n_ctx=8192, but it uses RoPE scaling
        self.hparams["n_ctx"] = 2048

        # SwiGLU activation
        assert self.hparams["activation_function"] == "swiglu"
        # this doesn't do anything in the HF version
        assert self.hparams["causal"] is False
        # no bias tensors
        assert self.hparams["qkv_proj_bias"] is False
        assert self.hparams["mlp_fc1_bias"] is False
        assert self.hparams["mlp_fc2_bias"] is False
        # norm at end of layer
        assert self.hparams["prenorm"] is False
        # standard RoPE
        assert self.hparams["rotary_emb_fraction"] == 1.0
        assert self.hparams["rotary_emb_interleaved"] is False
        assert self.hparams["rotary_emb_scale_base"] is None

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])

    def get_tensors(self):
        assert self.vocab_size is not None
        for name, data in super().get_tensors():
            # Nomic Embed's token embeddings tensor is padded, but llama.cpp wants tensor sizes to match exactly.
            if name == 'embeddings.word_embeddings.weight' and data.shape[1] != self.vocab_size:
                rounded_vocab_size = (self.vocab_size + 63) // 64 * 64
                assert data.shape == (rounded_vocab_size, self.hparams["n_embd"])
                data = data[:self.vocab_size, :]
            yield name, data


class GemmaModel(Model):
    def set_vocab(self):
        self._set_vocab_sentencepiece()

    def set_gguf_parameters(self):
        hparams = self.hparams
        block_count = hparams["num_hidden_layers"]

        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
        self.gguf_writer.add_key_length(hparams["head_dim"])
        self.gguf_writer.add_value_length(hparams["head_dim"])
        self.gguf_writer.add_file_type(self.ftype)
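
    # Note (added for clarity): HF Gemma stores RMSNorm weights centered around zero and
    # computes with (1 + weight), so write_tensors() below adds 1 to every "*norm.weight"
    # tensor to match the layout llama.cpp expects (see the modeling_gemma.py reference below).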

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in self.get_tensors():
            # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
            if name.endswith("norm.weight"):
                data_torch = data_torch + 1
            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


###### CONVERSION LOGIC ######


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Convert a huggingface model to a GGML compatible file")
    parser.add_argument(
        "--vocab-only", action="store_true",
        help="extract only the vocab",
    )
    parser.add_argument(
        "--awq-path", type=Path, default=None,
        help="Path to scale awq cache file")
    parser.add_argument(
        "--outfile", type=Path,
        help="path to write to; default: based on input",
    )
    parser.add_argument(
        "--outtype", type=str, choices=["f32", "f16"], default="f16",
        help="output format - use f32 for float32, f16 for float16",
    )
    parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
    parser.add_argument(
        "model", type=Path,
        help="directory containing model file",
    )

    return parser.parse_args()
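

# Example usage (added for clarity; the model path below is illustrative):
#   python convert-hf-to-gguf.py /path/to/hf-model --outtype f16 --outfile model-f16.gguf
#   python convert-hf-to-gguf.py /path/to/hf-model --vocab-only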


def main() -> None:
    args = parse_args()

    dir_model = args.model

    if args.awq_path:
        sys.path.insert(1, str(Path(__file__).parent / 'awq-py'))
        from awq.apply_awq import add_scale_weights  # type: ignore[import-not-found]
        tmp_model_path = args.model / "weighted_model"
        dir_model = tmp_model_path
        if tmp_model_path.is_dir():
            print(f"{tmp_model_path} exists as a weighted model.")
        else:
            tmp_model_path.mkdir(parents=True, exist_ok=True)
            print("Saving new weighted model ...")
            add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path))
            print(f"Saved weighted model at {tmp_model_path}.")

    if not dir_model.is_dir():
        print(f'Error: {args.model} is not a directory', file=sys.stderr)
        sys.exit(1)

    ftype_map = {
        "f32": gguf.GGMLQuantizationType.F32,
        "f16": gguf.GGMLQuantizationType.F16,
    }

    if args.outfile is not None:
        fname_out = args.outfile
    else:
        # output in the same directory as the model by default
        fname_out = dir_model / f'ggml-model-{args.outtype}.gguf'

    print(f"Loading model: {dir_model.name}")

    hparams = Model.load_hparams(dir_model)

    with torch.inference_mode():
        model_class = Model.from_model_architecture(hparams["architectures"][0])
        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian)

        print("Set model parameters")
        model_instance.set_gguf_parameters()

        print("Set model tokenizer")
        model_instance.set_vocab()

        if args.vocab_only:
            print(f"Exporting model vocab to '{fname_out}'")
            model_instance.write_vocab()
        else:
            print(f"Exporting model to '{fname_out}'")
            model_instance.write()

        print(f"Model successfully exported to '{fname_out}'")


if __name__ == '__main__':
    main()