#!/usr/bin/env python3

from __future__ import annotations

import argparse
import contextlib
import json
import os
import re
import sys
from enum import IntEnum
from pathlib import Path
from typing import TYPE_CHECKING, Any, ContextManager, Iterator, Sequence, cast

import numpy as np
import torch

if TYPE_CHECKING:
    from torch import Tensor

if 'NO_LOCAL_GGUF' not in os.environ:
    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf

from convert import HfVocab


###### MODEL DEFINITIONS ######

class SentencePieceTokenTypes(IntEnum):
    NORMAL = 1
    UNKNOWN = 2
    CONTROL = 3
    USER_DEFINED = 4
    UNUSED = 5
    BYTE = 6
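
# NOTE: the values above are assumed to mirror gguf.TokenType (NORMAL = 1 ... BYTE = 6),
# so a SentencePieceTokenTypes member can be passed straight to add_token_types().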

class Model:
    def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool):
        self.dir_model = dir_model
        self.ftype = ftype
        self.fname_out = fname_out
        self.is_big_endian = is_big_endian
        self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
        self.is_safetensors = self._is_model_safetensors()
        self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin")
        self.part_names = self._get_part_names()
        self.hparams = Model.load_hparams(self.dir_model)
        self.model_arch = self._get_model_architecture()
        self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=False)
        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])

    def find_hparam(self, keys: Sequence[str], optional: bool = False) -> Any:
        key = next((k for k in keys if k in self.hparams), None)
        if key is not None:
            return self.hparams[key]
        if optional:
            return None
        raise KeyError(f"could not find any of: {keys}")

    def set_vocab(self):
        self._set_vocab_gpt2()

    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
        for part_name in self.part_names:
            print(f"gguf: loading model part '{part_name}'")
            ctx: ContextManager[Any]
            if self.is_safetensors:
                from safetensors import safe_open
                ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
            else:
                ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))

            with ctx as model_part:
                for name in model_part.keys():
                    data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
                    yield name, data

    def set_gguf_parameters(self):
        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_block_count(self.block_count)

        if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
            self.gguf_writer.add_context_length(n_ctx)

        n_embd = self.find_hparam(["hidden_size", "n_embd"])
        self.gguf_writer.add_embedding_length(n_embd)

        if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
            self.gguf_writer.add_feed_forward_length(n_ff)

        n_head = self.find_hparam(["num_attention_heads", "n_head"])
        self.gguf_writer.add_head_count(n_head)

        if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
            self.gguf_writer.add_head_count_kv(n_head_kv)

        if (rope_theta := self.hparams.get("rope_theta")) is not None:
            self.gguf_writer.add_rope_freq_base(rope_theta)
        if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
            self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
        if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
            self.gguf_writer.add_layer_norm_eps(f_norm_eps)
        if (n_experts := self.hparams.get("num_local_experts")) is not None:
            self.gguf_writer.add_expert_count(n_experts)
        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
            self.gguf_writer.add_expert_used_count(n_experts_used)

        self.gguf_writer.add_file_type(self.ftype)
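
    # Note on the dtype policy of write_tensors() below (as read from how self.ftype is
    # used in this file): ftype == 0 keeps everything as float32, while ftype == 1 stores
    # 2-D ".weight" tensors as float16 and leaves 1-D tensors (norms, biases) as float32.
    # Any other dtype (e.g. bfloat16) is first converted to float32.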

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

    def write(self):
        self.write_tensors()
        self.gguf_writer.write_header_to_file()
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.write_tensors_to_file()
        self.gguf_writer.close()

    def write_vocab(self):
        self.gguf_writer.write_header_to_file()
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.close()

    @staticmethod
    def count_model_parts(dir_model: Path, prefix: str) -> int:
        num_parts = 0
        for filename in os.listdir(dir_model):
            if filename.endswith(prefix):
                num_parts += 1
        return num_parts

    @staticmethod
    def load_hparams(dir_model):
        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def from_model_architecture(model_architecture):
        if model_architecture == "GPTNeoXForCausalLM":
            return GPTNeoXModel
        if model_architecture == "BloomForCausalLM":
            return BloomModel
        if model_architecture == "MPTForCausalLM":
            return MPTModel
        if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
            return BaichuanModel
        if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
            return FalconModel
        if model_architecture == "GPTBigCodeForCausalLM":
            return StarCoderModel
        if model_architecture == "GPTRefactForCausalLM":
            return RefactModel
        if model_architecture == "PersimmonForCausalLM":
            return PersimmonModel
        if model_architecture in ("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
            return StableLMModel
        if model_architecture == "QWenLMHeadModel":
            return QwenModel
        if model_architecture == "Qwen2ForCausalLM":
            return Model
        if model_architecture == "MixtralForCausalLM":
            return MixtralModel
        if model_architecture == "GPT2LMHeadModel":
            return GPT2Model
        if model_architecture == "PhiForCausalLM":
            return Phi2Model
        if model_architecture == "PlamoForCausalLM":
            return PlamoModel
        if model_architecture == "CodeShellForCausalLM":
            return CodeShellModel
        if model_architecture == "OrionForCausalLM":
            return OrionModel
        if model_architecture == "InternLM2ForCausalLM":
            return InternLM2Model
        if model_architecture == "MiniCPMForCausalLM":
            return MiniCPMModel
        if model_architecture == "BertModel":
            return BertModel
        if model_architecture == "NomicBertModel":
            return NomicBertModel
        if model_architecture == "GemmaForCausalLM":
            return GemmaModel
        if model_architecture == "Starcoder2ForCausalLM":
            return Model
        return Model

    def _is_model_safetensors(self) -> bool:
        return Model.count_model_parts(self.dir_model, ".safetensors") > 0

    def _get_part_names(self):
        if self.is_safetensors:
            if self.num_parts == 1:  # there's only one .safetensors file
                return ("model.safetensors",)
            return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1))

        if self.num_parts == 1:  # there's only one .bin file
            return ("pytorch_model.bin",)
        return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))

    def _get_model_architecture(self) -> gguf.MODEL_ARCH:
        arch = self.hparams["architectures"][0]
        if arch == "GPTNeoXForCausalLM":
            return gguf.MODEL_ARCH.GPTNEOX
        if arch == "BloomForCausalLM":
            return gguf.MODEL_ARCH.BLOOM
        if arch == "MPTForCausalLM":
            return gguf.MODEL_ARCH.MPT
        if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
            return gguf.MODEL_ARCH.BAICHUAN
        if arch in ("FalconForCausalLM", "RWForCausalLM"):
            return gguf.MODEL_ARCH.FALCON
        if arch == "GPTBigCodeForCausalLM":
            return gguf.MODEL_ARCH.STARCODER
        if arch == "GPTRefactForCausalLM":
            return gguf.MODEL_ARCH.REFACT
        if arch == "PersimmonForCausalLM":
            return gguf.MODEL_ARCH.PERSIMMON
        if arch in ("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
            return gguf.MODEL_ARCH.STABLELM
        if arch == "QWenLMHeadModel":
            return gguf.MODEL_ARCH.QWEN
        if arch == "Qwen2ForCausalLM":
            return gguf.MODEL_ARCH.QWEN2
        if arch == "MixtralForCausalLM":
            return gguf.MODEL_ARCH.LLAMA
        if arch == "GPT2LMHeadModel":
            return gguf.MODEL_ARCH.GPT2
        if arch == "PhiForCausalLM":
            return gguf.MODEL_ARCH.PHI2
        if arch == "PlamoForCausalLM":
            return gguf.MODEL_ARCH.PLAMO
        if arch == "CodeShellForCausalLM":
            return gguf.MODEL_ARCH.CODESHELL
        if arch == "OrionForCausalLM":
            return gguf.MODEL_ARCH.ORION
        if arch == "InternLM2ForCausalLM":
            return gguf.MODEL_ARCH.INTERNLM2
        if arch == "MiniCPMForCausalLM":
            return gguf.MODEL_ARCH.MINICPM
        if arch == "BertModel":
            return gguf.MODEL_ARCH.BERT
        if arch == "NomicBertModel":
            return gguf.MODEL_ARCH.NOMIC_BERT
        if arch == "GemmaForCausalLM":
            return gguf.MODEL_ARCH.GEMMA
        if arch == "Starcoder2ForCausalLM":
            return gguf.MODEL_ARCH.STARCODER2

        raise NotImplementedError(f'Architecture "{arch}" not supported!')

    def _set_vocab_gpt2(self):
        dir_model = self.dir_model
        hparams = self.hparams
        tokens: list[bytearray] = []
        toktypes: list[int] = []

        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(dir_model)
        vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
        assert max(tokenizer.vocab.values()) < vocab_size

        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
        added_vocab = tokenizer.get_added_vocab()

        for i in range(vocab_size):
            if i not in reverse_vocab:
                pad_token = f"[PAD{i}]".encode('utf-8')
                tokens.append(bytearray(pad_token))
                toktypes.append(gguf.TokenType.USER_DEFINED)
            elif reverse_vocab[i] in added_vocab:
                tokens.append(reverse_vocab[i])
                if tokenizer.added_tokens_decoder[i].special:
                    toktypes.append(gguf.TokenType.CONTROL)
                else:
                    toktypes.append(gguf.TokenType.USER_DEFINED)
            else:
                tokens.append(reverse_vocab[i])
                toktypes.append(gguf.TokenType.NORMAL)

        self.gguf_writer.add_tokenizer_model("gpt2")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
        special_vocab.add_to_gguf(self.gguf_writer)
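
    # Token-type policy of _set_vocab_gpt2() above: ids missing from the tokenizer vocab
    # are emitted as "[PAD{i}]" placeholders (USER_DEFINED), added tokens flagged as
    # "special" become CONTROL, other added tokens USER_DEFINED, everything else NORMAL.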

    def _set_vocab_qwen(self):
        dir_model = self.dir_model
        hparams = self.hparams
        tokens: list[bytearray] = []
        toktypes: list[int] = []

        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
        vocab_size = hparams["vocab_size"]
        assert max(tokenizer.get_vocab().values()) < vocab_size

        merges = []
        vocab = {}
        mergeable_ranks = tokenizer.mergeable_ranks
        for token, rank in mergeable_ranks.items():
            vocab[QwenModel.token_bytes_to_string(token)] = rank
            if len(token) == 1:
                continue
            merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
            assert len(merged) == 2
            merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))

        # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
        added_vocab = tokenizer.special_tokens
        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in (vocab | added_vocab).items()}

        for i in range(vocab_size):
            if i not in reverse_vocab:
                pad_token = f"[PAD{i}]".encode("utf-8")
                tokens.append(bytearray(pad_token))
                toktypes.append(gguf.TokenType.USER_DEFINED)
            elif reverse_vocab[i] in added_vocab:
                tokens.append(reverse_vocab[i])
                toktypes.append(gguf.TokenType.CONTROL)
            else:
                tokens.append(reverse_vocab[i])
                toktypes.append(gguf.TokenType.NORMAL)

        self.gguf_writer.add_tokenizer_model("gpt2")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
        special_vocab.merges = merges
        # only add special tokens when they were not already loaded from config.json
        if len(special_vocab.special_token_ids) == 0:
            special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"])
            special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"])
        # this one is usually not in config.json anyway
        special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"])
        special_vocab.add_to_gguf(self.gguf_writer)

    def _set_vocab_sentencepiece(self):
        from sentencepiece import SentencePieceProcessor

        tokenizer_path = self.dir_model / 'tokenizer.model'

        tokens: list[bytes] = []
        scores: list[float] = []
        toktypes: list[int] = []

        if not tokenizer_path.is_file():
            print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
            sys.exit(1)

        tokenizer = SentencePieceProcessor(str(tokenizer_path))
        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

        for token_id in range(vocab_size):
            piece = tokenizer.id_to_piece(token_id)
            text = piece.encode("utf-8")
            score = tokenizer.get_score(token_id)

            toktype = SentencePieceTokenTypes.NORMAL
            if tokenizer.is_unknown(token_id):
                toktype = SentencePieceTokenTypes.UNKNOWN
            elif tokenizer.is_control(token_id):
                toktype = SentencePieceTokenTypes.CONTROL
            elif tokenizer.is_unused(token_id):
                toktype = SentencePieceTokenTypes.UNUSED
            elif tokenizer.is_byte(token_id):
                toktype = SentencePieceTokenTypes.BYTE

            tokens.append(text)
            scores.append(score)
            toktypes.append(toktype)

        added_tokens_file = self.dir_model / 'added_tokens.json'
        if added_tokens_file.is_file():
            with open(added_tokens_file, "r", encoding="utf-8") as f:
                added_tokens_json = json.load(f)

                for key in added_tokens_json:
                    tokens.append(key.encode("utf-8"))
                    scores.append(-1000.0)
                    toktypes.append(SentencePieceTokenTypes.USER_DEFINED)

        self.gguf_writer.add_tokenizer_model("llama")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        special_vocab.add_to_gguf(self.gguf_writer)

    def _set_vocab_hf(self):
        path = self.dir_model
        added_tokens_path = self.dir_model
        vocab = HfVocab(
            path, added_tokens_path if added_tokens_path.exists() else None
        )
        tokens = []
        scores = []
        toktypes = []

        for text, score, toktype in vocab.all_tokens():
            tokens.append(text)
            scores.append(score)
            toktypes.append(toktype)

        assert len(tokens) == vocab.vocab_size

        self.gguf_writer.add_tokenizer_model("llama")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        special_vocab.add_to_gguf(self.gguf_writer)
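
# A minimal sketch of how a Model subclass is typically driven (assumed usage; the actual
# command-line entry point is further down in this script and is not shown in this excerpt):
#
#     hparams = Model.load_hparams(Path("path/to/hf-model"))
#     model_class = Model.from_model_architecture(hparams["architectures"][0])
#     model = model_class(Path("path/to/hf-model"), ftype=1, fname_out=Path("model.gguf"), is_big_endian=False)
#     model.set_gguf_parameters()
#     model.set_vocab()
#     model.write()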

class GPTNeoXModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["num_hidden_layers"]

        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(
            int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
        )
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])

class BloomModel(Model):
    def set_gguf_parameters(self):
        self.gguf_writer.add_name("Bloom")
        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
        self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
        self.gguf_writer.add_embedding_length(n_embed)
        self.gguf_writer.add_feed_forward_length(4 * n_embed)
        self.gguf_writer.add_block_count(self.hparams["n_layer"])
        self.gguf_writer.add_head_count(n_head)
        self.gguf_writer.add_head_count_kv(n_head)
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        block_count = self.hparams["n_layer"]
        tensors = dict(self.get_tensors())
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        has_lm_head = True
        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))

        for name, data_torch in tensors.items():
            if "lm_head.weight" not in tensors.keys() and "output.weight" not in tensors.keys():
                has_lm_head = False

            name = re.sub(r'transformer\.', '', name)

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name):
                # Map bloom-style qkv_linear to gpt-style qkv_linear
                # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252  # noqa
                # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312  # noqa
                qkv_weights = data.reshape((n_head, 3, n_embed // n_head, n_embed))
                data = np.concatenate(
                    (
                        qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
                        qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
                        qkv_weights[:, 2, :, :].reshape((-1, n_embed)),
                    ),
                    axis=0,
                )
                print("re-format attention.linear_qkv.weight")
            elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name):
                qkv_bias = data.reshape((n_head, 3, n_embed // n_head))
                data = np.concatenate(
                    (
                        qkv_bias[:, 0, :].reshape((n_embed,)),
                        qkv_bias[:, 1, :].reshape((n_embed,)),
                        qkv_bias[:, 2, :].reshape((n_embed,)),
                    ),
                    axis=0,
                )
                print("re-format attention.linear_qkv.bias")

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"=> {new_name}, shape = {data.shape}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

            if not has_lm_head and name == "word_embeddings.weight":
                self.gguf_writer.add_tensor("output.weight", data)
                print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")

class MPTModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["n_layers"]
        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
        self.gguf_writer.add_embedding_length(self.hparams["d_model"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"])
        self.gguf_writer.add_head_count(self.hparams["n_heads"])
        if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"):
            self.gguf_writer.add_head_count_kv(kv_n_heads)
        self.gguf_writer.add_layer_norm_eps(1e-5)
        if self.hparams["attn_config"]["clip_qkv"] is not None:
            self.gguf_writer.add_clamp_kqv(self.hparams["attn_config"]["clip_qkv"])
        self.gguf_writer.add_max_alibi_bias(self.hparams["attn_config"]["alibi_bias_max"])

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers"))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            if "scales" in name:
                new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias", ".scales"))
                if new_name is not None:
                    new_name = new_name.replace("scales", "act.scales")
            else:
                new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

class OrionModel(Model):
    def set_vocab(self):
        self._set_vocab_sentencepiece()

    def set_gguf_parameters(self):
        block_count = self.hparams["num_hidden_layers"]
        head_count = self.hparams["num_attention_heads"]
        head_count_kv = self.hparams.get("num_key_value_heads", head_count)
        hf_repo = self.hparams.get("_name_or_path", "")

        ctx_length = 0
        if "max_sequence_length" in self.hparams:
            ctx_length = self.hparams["max_sequence_length"]
        elif "max_position_embeddings" in self.hparams:
            ctx_length = self.hparams["max_position_embeddings"]
        elif "model_max_length" in self.hparams:
            ctx_length = self.hparams["model_max_length"]
        else:
            print("gguf: can not find ctx length parameter.")
            sys.exit()

        self.gguf_writer.add_file_type(self.ftype)
        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_source_hf_repo(hf_repo)
        self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
        self.gguf_writer.add_context_length(ctx_length)
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_head_count(head_count)
        self.gguf_writer.add_head_count_kv(head_count_kv)
        # note: config provides rms norm but it is actually layer norm
        # ref: https://huggingface.co/OrionStarAI/Orion-14B-Chat/blob/276a17221ce42beb45f66fac657a41540e71f4f5/modeling_orion.py#L570-L571
        self.gguf_writer.add_layer_norm_eps(self.hparams["rms_norm_eps"])

    def write_tensors(self):
        # Collect tensors from generator object
        model_kv = dict(self.get_tensors())
        block_count = self.hparams["num_hidden_layers"]
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in model_kv.items():
            # we don't need these
            if name.endswith(".rotary_emb.inv_freq"):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
            self.gguf_writer.add_tensor(new_name, data)

class BaichuanModel(Model):
    def set_vocab(self):
        self._set_vocab_sentencepiece()

    def set_gguf_parameters(self):
        block_count = self.hparams["num_hidden_layers"]
        head_count = self.hparams["num_attention_heads"]
        head_count_kv = self.hparams.get("num_key_value_heads", head_count)
        hf_repo = self.hparams.get("_name_or_path", "")

        ctx_length = 0
        if "max_sequence_length" in self.hparams:
            ctx_length = self.hparams["max_sequence_length"]
        elif "max_position_embeddings" in self.hparams:
            ctx_length = self.hparams["max_position_embeddings"]
        elif "model_max_length" in self.hparams:
            ctx_length = self.hparams["model_max_length"]
        else:
            print("gguf: can not find ctx length parameter.")
            sys.exit()

        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_source_hf_repo(hf_repo)
        self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
        self.gguf_writer.add_context_length(ctx_length)
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count(head_count)
        self.gguf_writer.add_head_count_kv(head_count_kv)
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])

        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
            if self.hparams["rope_scaling"].get("type") == "linear":
                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])

    def write_tensors(self):
        # Collect tensors from generator object
        model_kv = dict(self.get_tensors())
        block_count = self.hparams["num_hidden_layers"]
        head_count = self.hparams["num_attention_heads"]
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        head_count_kv = self.hparams.get("num_key_value_heads", head_count)

        for i in range(block_count):
            if (w := model_kv.get(f"model.layers.{i}.self_attn.W_pack.weight")) is not None:
                print(f"Unpacking and permuting layer {i}")
                model_kv[f"model.layers.{i}.self_attn.q_proj.weight"] = \
                    self._reverse_hf_permute_part(w, 0, head_count, head_count)
                model_kv[f"model.layers.{i}.self_attn.k_proj.weight"] = \
                    self._reverse_hf_permute_part(w, 1, head_count, head_count_kv)
                model_kv[f"model.layers.{i}.self_attn.v_proj.weight"] = \
                    self._reverse_hf_part(w, 2)
                del model_kv[f"model.layers.{i}.self_attn.W_pack.weight"]

        for name, data_torch in model_kv.items():
            # we don't need these
            if name.endswith(".rotary_emb.inv_freq"):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
            self.gguf_writer.add_tensor(new_name, data)

    def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
        if n_kv_head is not None and n_head != n_kv_head:
            n_head //= n_kv_head

        return (
            weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape)
        )

    def _reverse_hf_permute_part(
        self, weights: Tensor, n_part: int, n_head: int, n_head_kv: int | None = None,
    ) -> Tensor:
        r = weights.shape[0] // 3
        return self._reverse_hf_permute(weights[r * n_part:r * n_part + r, ...], n_head, n_head_kv)

    def _reverse_hf_part(self, weights: Tensor, n_part: int) -> Tensor:
        r = weights.shape[0] // 3
        return weights[r * n_part:r * n_part + r, ...]
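
    # How the helpers above are used (illustrative shapes, not taken from a real config):
    # W_pack stacks [Q; K; V] along dim 0, so with hidden_size = 8 and n_head = 2 it has
    # shape (24, 8). _reverse_hf_permute_part(w, 0/1, ...) slices out the 8-row Q/K block
    # and undoes the HF rotary interleaving (reshape to (n_head, 2, head_dim // 2, ...),
    # swap the middle axes, reshape back); _reverse_hf_part(w, 2) just slices out V.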

class FalconModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams.get("num_hidden_layers")
        if block_count is None:
            block_count = self.hparams["n_layer"]  # old name

        n_head = self.hparams.get("num_attention_heads")
        if n_head is None:
            n_head = self.hparams["n_head"]  # old name

        n_head_kv = self.hparams.get("num_kv_heads")
        if n_head_kv is None:
            n_head_kv = self.hparams.get("n_head_kv", 1)  # old name

        self.gguf_writer.add_name("Falcon")
        self.gguf_writer.add_context_length(2048)  # not in config.json
        self.gguf_writer.add_tensor_data_layout("jploski")  # qkv tensor transform
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(n_head)
        self.gguf_writer.add_head_count_kv(n_head_kv)
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        block_count = self.hparams.get("num_hidden_layers")
        if block_count is None:
            block_count = self.hparams["n_layer"]  # old name

        n_head = self.hparams.get("num_attention_heads")
        if n_head is None:
            n_head = self.hparams["n_head"]  # old name

        n_head_kv = self.hparams.get("num_kv_heads")
        if n_head_kv is None:
            n_head_kv = self.hparams.get("n_head_kv", 1)  # old name

        head_dim = self.hparams["hidden_size"] // n_head
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in self.get_tensors():
            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            # QKV tensor transform
            # The original query_key_value tensor contains n_head_kv "kv groups",
            # each consisting of n_head/n_head_kv query weights followed by one key
            # and one value weight (shared by all query heads in the kv group).
            # This layout makes it a big pain to work with in GGML.
            # So we rearrange them here, so that we have n_head query weights
            # followed by n_head_kv key weights followed by n_head_kv value weights,
            # in contiguous fashion.
            # ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py
            if "query_key_value" in name:
                qkv = data_torch.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
                q = qkv[:, :-2].reshape(n_head * head_dim, head_dim * n_head)
                k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
                v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
                data_torch = torch.cat((q, k, v)).reshape_as(data_torch)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)
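
    # Concrete example of the QKV regrouping above (illustrative numbers, not from a real
    # config): with n_head = 8 and n_head_kv = 2, each of the 2 kv groups holds 4 query
    # heads plus one key and one value, i.e. rows laid out as [q q q q k v | q q q q k v].
    # After the transform the rows become [q x 8 | k x 2 | v x 2], the contiguous layout
    # the comment above says GGML expects.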

class StarCoderModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["n_layer"]

        self.gguf_writer.add_name("StarCoder")
        self.gguf_writer.add_context_length(self.hparams["n_positions"])
        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(self.hparams["n_head"])
        self.gguf_writer.add_head_count_kv(1)
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)

class RefactModel(Model):
    def set_gguf_parameters(self):
        hidden_dim = self.hparams["n_embd"]
        inner_dim = 4 * hidden_dim
        hidden_dim = int(2 * inner_dim / 3)
        multiple_of = 256
        ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
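        # The lines above follow the LLaMA-style SwiGLU sizing: take 4 * n_embd, scale it
        # by 2/3, then round up to a multiple of 256. Worked example (illustrative,
        # n_embd = 4096): inner_dim = 16384, hidden_dim = int(2 * 16384 / 3) = 10922,
        # ff_dim = 256 * ceil(10922 / 256) = 11008.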
        block_count = self.hparams["n_layer"]

        self.gguf_writer.add_name("Refact")
        # refact uses Alibi. So this is from config.json which might be used by training.
        self.gguf_writer.add_context_length(self.hparams["n_positions"])
        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
        self.gguf_writer.add_feed_forward_length(ff_dim)
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(self.hparams["n_head"])
        self.gguf_writer.add_head_count_kv(1)
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        hidden_dim = self.hparams["n_embd"]
        inner_dim = 4 * hidden_dim
        hidden_dim = int(2 * inner_dim / 3)
        multiple_of = 256
        ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        n_head = self.hparams["n_head"]
        n_head_kv = 1
        head_dim = self.hparams["n_embd"] // n_head
        block_count = self.hparams["n_layer"]

        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        tensors = dict(self.get_tensors())
        for i in range(block_count):
            if (w := tensors.get(f"transformer.h.{i}.attn.kv.weight")) is not None:
                tensors[f"model.layers.{i}.self_attn.k_proj.weight"] = w[:n_head_kv * head_dim]
                tensors[f"model.layers.{i}.self_attn.v_proj.weight"] = w[n_head_kv * head_dim:]
                del tensors[f"transformer.h.{i}.attn.kv.weight"]
            if (w := tensors.get(f"transformer.h.{i}.attn.q.weight")) is not None:
                tensors[f"model.layers.{i}.self_attn.q_proj.weight"] = w
                del tensors[f"transformer.h.{i}.attn.q.weight"]
            if (w := tensors.get(f"transformer.h.{i}.mlp.gate_up_proj.weight")) is not None:
                tensors[f"model.layers.{i}.mlp.gate_proj.weight"] = w[:ff_dim]
                tensors[f"model.layers.{i}.mlp.up_proj.weight"] = w[ff_dim:]
                del tensors[f"transformer.h.{i}.mlp.gate_up_proj.weight"]

        for name, data_torch in tensors.items():
            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight",))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

class PersimmonModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
        head_count = self.hparams["num_attention_heads"]
        head_count_kv = head_count
        hidden_size = self.hparams["hidden_size"]

        self.gguf_writer.add_name('persimmon-8b-chat')
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(hidden_size)
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])

        # NOTE: not sure about this change - why does the model not have a rope dimension count when it is smaller
        #       than the head size?
        #       ref: https://github.com/ggerganov/llama.cpp/pull/4889
        # self.gguf_writer.add_rope_dimension_count(hidden_size // head_count)
        self.gguf_writer.add_rope_dimension_count(hidden_size // head_count // 2)

        self.gguf_writer.add_head_count(head_count)
        self.gguf_writer.add_head_count_kv(head_count_kv)
        self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])

    def set_vocab(self):
        self._set_vocab_sentencepiece()
        # self.gguf_writer.add_bos_token_id(71013)
        # self.gguf_writer.add_eos_token_id(71013)

    def write_tensors(self):
        block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in self.get_tensors():
            if name.endswith(".self_attention.rotary_emb.inv_freq"):
                continue
            old_dtype = data_torch.dtype
            # TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?)
            data = data_torch.to(torch.float32).squeeze().numpy()
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()
            n_dims = len(data.shape)
            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
            self.gguf_writer.add_tensor(new_name, data)

class StableLMModel(Model):
    def set_vocab(self):
        if (self.dir_model / "tokenizer.json").is_file():
            self._set_vocab_gpt2()
        else:
            # StableLM 2 1.6B uses a vocab in a similar format to Qwen's vocab
            self._set_vocab_qwen()

    def set_gguf_parameters(self):
        hparams = self.hparams
        block_count = hparams["num_hidden_layers"]

        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
        rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"])
        self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
        self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
        self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"]))


class MixtralModel(Model):
    def set_vocab(self):
        self._set_vocab_sentencepiece()

class MiniCPMModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["num_hidden_layers"]
        self.gguf_writer.add_name("MiniCPM")
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
        self.gguf_writer.add_file_type(self.ftype)

    def set_vocab(self):
        self._set_vocab_hf()

    def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
        if n_kv_head is not None and n_head != n_kv_head:
            n_head //= n_kv_head

        return (
            weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape)
        )

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        n_head = self.hparams.get("num_attention_heads")
        n_kv_head = self.hparams.get("num_key_value_heads")
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            # HF models permute some of the tensors, so we need to undo that
            if name.endswith(("q_proj.weight")):
                data_torch = self._reverse_hf_permute(data_torch, n_head, n_head)
            if name.endswith(("k_proj.weight")):
                data_torch = self._reverse_hf_permute(data_torch, n_head, n_kv_head)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

class QwenModel(Model):
    @staticmethod
    def token_bytes_to_string(b):
        from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
        byte_encoder = bytes_to_unicode()
        return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])

    @staticmethod
    def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
        parts = [bytes([b]) for b in token]
        while True:
            min_idx = None
            min_rank = None
            for i, pair in enumerate(zip(parts[:-1], parts[1:])):
                rank = mergeable_ranks.get(pair[0] + pair[1])
                if rank is not None and (min_rank is None or rank < min_rank):
                    min_idx = i
                    min_rank = rank
            if min_rank is None or (max_rank is not None and min_rank >= max_rank):
                break
            assert min_idx is not None
            parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
        return parts
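
    # Illustrative trace of bpe() (hypothetical ranks, not from a real tokenizer): with
    # mergeable_ranks = {b"ab": 0, b"abc": 1} and token = b"abc", the byte parts
    # [b"a", b"b", b"c"] first merge into [b"ab", b"c"] (rank 0); calling with max_rank=1
    # then stops before the rank-1 merge, which is how _set_vocab_qwen() recovers the
    # two pieces behind each multi-byte token.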

    def set_vocab(self):
        self._set_vocab_qwen()

    def set_gguf_parameters(self):
        self.gguf_writer.add_name("Qwen")
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])

    def write_tensors(self):
        block_count = self.hparams["num_hidden_layers"]
        model_kv = dict(self.get_tensors())
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        for name, data_torch in model_kv.items():
            # we don't need these
            if name.endswith(".rotary_emb.inv_freq"):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
            self.gguf_writer.add_tensor(new_name, data)


class GPT2Model(Model):
    def set_gguf_parameters(self):
        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_block_count(self.hparams["n_layer"])
        self.gguf_writer.add_context_length(self.hparams["n_ctx"])
        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
        self.gguf_writer.add_head_count(self.hparams["n_head"])
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq", ".attn.bias", ".attn.masked_bias")):
                continue

            # GPT-2's Conv1D layers store weights transposed relative to nn.Linear
            if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight")):
                data_torch = data_torch.transpose(1, 0)

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Cannot map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 tensors as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

            # note: GPT2 output is tied to (same as) wte in original model
            if new_name == "token_embd.weight":
                print(f"output.weight, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
                self.gguf_writer.add_tensor("output.weight", data)


class Phi2Model(Model):
    def set_gguf_parameters(self):
        block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
        rot_pct = self.find_hparam(["partial_rotary_factor"])
        n_embd = self.find_hparam(["hidden_size", "n_embd"])
        n_head = self.find_hparam(["num_attention_heads", "n_head"])

        self.gguf_writer.add_name("Phi2")
        self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"]))
        self.gguf_writer.add_embedding_length(n_embd)
        self.gguf_writer.add_feed_forward_length(4 * n_embd)
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(n_head)
        self.gguf_writer.add_head_count_kv(n_head)
        self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_epsilon", "layer_norm_eps"]))
        self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
        self.gguf_writer.add_file_type(self.ftype)
        self.gguf_writer.add_add_bos_token(False)
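
    # Note (editorial, illustrative): with hypothetical Phi-2-style values n_embd = 2560, n_head = 32
    # and partial_rotary_factor = 0.4, the rotary dimension count above is int(0.4 * 2560) // 32 = 32,
    # i.e. RoPE is applied to only the first 32 of the 80 dimensions of each attention head.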


class PlamoModel(Model):
    def set_vocab(self):
        self._set_vocab_sentencepiece()

    def set_gguf_parameters(self):
        hparams = self.hparams
        block_count = hparams["num_hidden_layers"]

        self.gguf_writer.add_name("PLaMo")
        self.gguf_writer.add_context_length(4096)  # not in config.json
        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
        self.gguf_writer.add_head_count_kv(5)  # hparams["num_key_value_heads"] is wrong
        self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])

    def shuffle_attn_q_weight(self, data_torch):
        assert data_torch.size() == (5120, 5120)
        data_torch = data_torch.reshape(8, 5, 128, 5120)
        data_torch = torch.permute(data_torch, (1, 0, 2, 3))
        data_torch = torch.reshape(data_torch, (5120, 5120))
        return data_torch

    def shuffle_attn_output_weight(self, data_torch):
        assert data_torch.size() == (5120, 5120)
        data_torch = data_torch.reshape(5120, 8, 5, 128)
        data_torch = torch.permute(data_torch, (0, 2, 1, 3))
        data_torch = torch.reshape(data_torch, (5120, 5120))
        return data_torch
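
    # Note (editorial): the two shuffles above appear to assume PLaMo-13B's fixed attention geometry,
    # since 8 x 5 x 128 = 5120 matches 40 heads of dimension 128 with the 5 KV heads written above
    # (hence the hard-coded shapes); they reorder the head axis so that grouped-query attention
    # broadcasts correctly in ggml_mul_mat, as the comment in write_tensors below states.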

    def write_tensors(self):
        block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in self.get_tensors():
            if "self_attn.rotary_emb.inv_freq" in name:
                continue

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Cannot map tensor {name!r}")
                sys.exit()

            # shuffle for broadcasting of gqa in ggml_mul_mat
            if new_name.endswith("attn_q.weight"):
                data_torch = self.shuffle_attn_q_weight(data_torch)
            elif new_name.endswith("attn_output.weight"):
                data_torch = self.shuffle_attn_output_weight(data_torch)

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 tensors as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


class CodeShellModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["n_layer"]

        self.gguf_writer.add_name("CodeShell")
        self.gguf_writer.add_context_length(self.hparams["n_positions"])
        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(self.hparams["n_head"])
        self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"])
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)
        self.gguf_writer.add_rope_freq_base(10000.0)
        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
        self.gguf_writer.add_rope_scaling_factor(1.0)

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        tensors = dict(self.get_tensors())
        has_lm_head = "lm_head.weight" in tensors.keys() or "output.weight" in tensors.keys()
        for name, data_torch in tensors.items():
            # we don't need these
            if name.endswith(".attn.rotary_emb.inv_freq"):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Cannot map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 tensors as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

            # if the checkpoint has no lm_head, the output weight is tied to the token embedding
            if not has_lm_head and name == "transformer.wte.weight":
                self.gguf_writer.add_tensor("output.weight", data)
                print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")


class InternLM2Model(Model):
    def set_vocab(self):
        # (TODO): Is there a better way?
        # Copied from _set_vocab_sentencepiece; the only difference is that we treat the character
        # \x00 specially and convert it into an emoji character to prevent it from being mistakenly
        # recognized as an empty string in C++.
        from sentencepiece import SentencePieceProcessor
        from sentencepiece import sentencepiece_model_pb2 as model

        tokenizer_path = self.dir_model / 'tokenizer.model'

        tokens: list[bytes] = []
        scores: list[float] = []
        toktypes: list[int] = []

        if not tokenizer_path.is_file():
            print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
            sys.exit(1)

        sentencepiece_model = model.ModelProto()
        sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
        add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix

        tokenizer = SentencePieceProcessor(str(tokenizer_path))
        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

        for token_id in range(vocab_size):
            piece = tokenizer.id_to_piece(token_id)
            text = piece.encode("utf-8")
            score = tokenizer.get_score(token_id)
            if text == b"\x00":
                # (TODO): fixme
                # Hack: replace the \x00 character with an emoji so it survives the round trip.
                print(f"InternLM2: converting token '{text}' to '🐉'!")
                text = "🐉".encode("utf-8")

            toktype = SentencePieceTokenTypes.NORMAL
            if tokenizer.is_unknown(token_id):
                toktype = SentencePieceTokenTypes.UNKNOWN
            elif tokenizer.is_control(token_id):
                toktype = SentencePieceTokenTypes.CONTROL
            elif tokenizer.is_unused(token_id):
                toktype = SentencePieceTokenTypes.UNUSED
            elif tokenizer.is_byte(token_id):
                toktype = SentencePieceTokenTypes.BYTE

            tokens.append(text)
            scores.append(score)
            toktypes.append(toktype)

        added_tokens_file = self.dir_model / 'added_tokens.json'
        if added_tokens_file.is_file():
            with open(added_tokens_file, "r", encoding="utf-8") as f:
                added_tokens_json = json.load(f)

            for key in added_tokens_json:
                tokens.append(key.encode("utf-8"))
                scores.append(-1000.0)
                toktypes.append(SentencePieceTokenTypes.USER_DEFINED)

        self.gguf_writer.add_tokenizer_model("llama")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)
        self.gguf_writer.add_add_space_prefix(add_prefix)

        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        old_eos = special_vocab.special_token_ids["eos"]
        if "chat" in os.path.basename(self.dir_model.absolute()):
            # For the chat model, we replace the eos with '<|im_end|>'.
            special_vocab.special_token_ids["eos"] = self._try_get_sft_eos(tokenizer)
            print(f"Replace eos:{old_eos} with a special token:{special_vocab.special_token_ids['eos']}"
                  " in chat mode so that the conversation can end normally.")

        special_vocab.add_to_gguf(self.gguf_writer)

    def _try_get_sft_eos(self, tokenizer):
        unused_145_list = tokenizer.encode('[UNUSED_TOKEN_145]')
        im_end_list = tokenizer.encode('<|im_end|>')
        assert (len(unused_145_list) == 1) ^ (len(im_end_list) == 1)
        if len(unused_145_list) == 1:
            eos_token = unused_145_list[0]
        if len(im_end_list) == 1:
            eos_token = im_end_list[0]
        return eos_token

    def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int):
        if n_head_kv is not None and n_head != n_head_kv:
            n_head = n_head_kv
        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
                .swapaxes(1, 2)
                .reshape(weights.shape))
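
    # Illustrative sketch (editorial): for a hypothetical weight of shape (8, 4) with n_head = 2,
    # _hf_permute_qk views it as (2, 2, 2, 4), swaps axes 1 and 2, and flattens back to (8, 4),
    # reordering the rows within each head -- the same reshape/swapaxes trick that
    # _reverse_hf_permute uses elsewhere in this script to match the Q/K layout llama.cpp expects.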

    def set_gguf_parameters(self):
        self.gguf_writer.add_name("InternLM2")
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])

    def post_write_tensors(self, tensor_map, name, data_torch):
        old_dtype = data_torch.dtype

        # convert any unsupported data types to float32
        if data_torch.dtype not in (torch.float16, torch.float32):
            data_torch = data_torch.to(torch.float32)

        data = data_torch.squeeze().numpy()

        # map tensor names
        new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
        if new_name is None:
            print(f"Cannot map tensor {name!r}")
            sys.exit()

        n_dims = len(data.shape)
        data_dtype = data.dtype

        # if f32 desired, convert any float16 to float32
        if self.ftype == 0 and data_dtype == np.float16:
            data = data.astype(np.float32)

        # TODO: Why can't we use these float16 tensors as-is? There should be no reason to store float16 as float32
        if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
            data = data.astype(np.float32)

        # if f16 desired, convert any float32 2-dim weight tensors to float16
        if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
            data = data.astype(np.float16)

        print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
        self.gguf_writer.add_tensor(new_name, data)

    def write_tensors(self):
        from einops import rearrange

        num_heads = self.hparams.get("num_attention_heads")
        num_kv_heads = self.hparams.get("num_key_value_heads")
        hidden_size = self.hparams.get("hidden_size")
        q_per_kv = num_heads // num_kv_heads
        head_dim = hidden_size // num_heads
        num_groups = num_heads // q_per_kv

        block_count = self.hparams["num_hidden_layers"]
        model_kv = dict(self.get_tensors())
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"

        for name, data_torch in model_kv.items():
            # we don't need these
            if name.endswith(".rotary_emb.inv_freq"):
                continue

            if re.match(qkv_pattern, name):
                bid = re.findall(qkv_pattern, name)[0]
                qkv = data_torch
                qkv = rearrange(qkv.T, "o (g n i) -> o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
                q, k, v = qkv[..., :q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
                # the model weights of q and k require an additional reshape
                q = self._hf_permute_qk(rearrange(q, "o g n i -> o (g n i)").T, num_heads, num_heads)
                k = self._hf_permute_qk(rearrange(k, "o g n i -> o (g n i)").T, num_heads, num_kv_heads)
                v = rearrange(v, "o g n i -> o (g n i)").T
                self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wq.weight", q)
                self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wk.weight", k)
                self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wv.weight", v)
            else:
                self.post_write_tensors(tensor_map, name, data_torch)
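
    # Note (editorial, illustrative): with hypothetical values num_heads = 32, num_kv_heads = 8 and
    # head_dim = 128, q_per_kv = 4 and num_groups = 8, so the fused wqkv weight is viewed above as
    # (out, 8 groups, 4 + 2 slots, 128): the first q_per_kv slots of each group are query heads,
    # followed by one shared key slot and one shared value slot.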


class BertModel(Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.vocab_size = None

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        self.gguf_writer.add_causal_attention(False)

        # get pooling path
        with open(self.dir_model / "modules.json", encoding="utf-8") as f:
            modules = json.load(f)
        pooling_path = None
        for mod in modules:
            if mod["type"] == "sentence_transformers.models.Pooling":
                pooling_path = mod["path"]
                break

        # get pooling type
        pooling_type = gguf.PoolingType.NONE
        if pooling_path is not None:
            with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
                pooling = json.load(f)
            if pooling["pooling_mode_mean_tokens"]:
                pooling_type = gguf.PoolingType.MEAN
            elif pooling["pooling_mode_cls_token"]:
                pooling_type = gguf.PoolingType.CLS
            else:
                raise NotImplementedError("Only MEAN and CLS pooling types supported")

        self.gguf_writer.add_pooling_type(pooling_type.value)

    def set_vocab(self):
        path = self.dir_model
        added_tokens_path = self.dir_model if self.dir_model.exists() else None

        # use huggingface vocab to get all tokens
        vocab = HfVocab(path, added_tokens_path)
        tokens, scores, toktypes = zip(*vocab.all_tokens())
        assert len(tokens) == vocab.vocab_size
        self.vocab_size = vocab.vocab_size

        # we need this to validate the size of the token_type embeddings
        # though currently we are passing all zeros to the token_type embeddings
        n_token_types = len(set(toktypes))
        self.gguf_writer.add_token_type_count(n_token_types)

        # convert to phantom space vocab
        def phantom(tok, typ):
            if tok.startswith(b"[") and tok.endswith(b"]"):
                return tok
            if tok.startswith(b"##"):
                return tok[2:]
            return b"\xe2\x96\x81" + tok
        tokens = tuple(phantom(t, y) for t, y in zip(tokens, toktypes))

        # set up bos and eos tokens (cls and sep)
        self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id)
        self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id)

        # add vocab to gguf
        self.gguf_writer.add_tokenizer_model("bert")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)

        # handle special tokens
        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        special_vocab.add_to_gguf(self.gguf_writer)

    def write_tensors(self):
        tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
        tensors = dict(self.get_tensors())
        for name, data_torch in tensors.items():
            # we are only using BERT for embeddings so we don't need the pooling layer
            if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
                continue  # we don't need these

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Cannot map tensor {name!r}")
                sys.exit()

            data = data_torch.squeeze().numpy()
            n_dims = len(data.shape)
            new_dtype: type[np.floating[Any]]

            if (
                self.ftype == 1 and name.endswith(".weight") and n_dims == 2
                and name != "embeddings.token_type_embeddings.weight"  # not used with get_rows, must be F32
            ):
                # if f16 desired, convert any float32 2-dim weight tensors to float16
                new_dtype = np.float16
            else:
                # if f32 desired, convert any float16 to float32
                new_dtype = np.float32

            print(f"{new_name}, n_dims = {n_dims}, {data_torch.dtype} --> {new_dtype}")

            if data.dtype != new_dtype:
                data = data.astype(new_dtype)

            self.gguf_writer.add_tensor(new_name, data)


class NomicBertModel(BertModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # the HF config claims n_ctx=8192, but it uses RoPE scaling
        self.hparams["n_ctx"] = 2048

        # SwiGLU activation
        assert self.hparams["activation_function"] == "swiglu"
        # this doesn't do anything in the HF version
        assert self.hparams["causal"] is False
        # no bias tensors
        assert self.hparams["qkv_proj_bias"] is False
        assert self.hparams["mlp_fc1_bias"] is False
        assert self.hparams["mlp_fc2_bias"] is False
        # norm at end of layer
        assert self.hparams["prenorm"] is False
        # standard RoPE
        assert self.hparams["rotary_emb_fraction"] == 1.0
        assert self.hparams["rotary_emb_interleaved"] is False
        assert self.hparams["rotary_emb_scale_base"] is None

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])

    def get_tensors(self):
        assert self.vocab_size is not None
        for name, data in super().get_tensors():
            # Nomic Embed's token embeddings tensor is padded, but llama.cpp wants tensor sizes to match exactly.
            if name == 'embeddings.word_embeddings.weight' and data.shape[1] != self.vocab_size:
                rounded_vocab_size = (self.vocab_size + 63) // 64 * 64
                assert data.shape == (rounded_vocab_size, self.hparams["n_embd"])
                data = data[:self.vocab_size, :]
            yield name, data
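
    # Note (editorial): the rounding above pads the vocab up to the next multiple of 64; e.g. a
    # hypothetical vocab_size of 30522 gives (30522 + 63) // 64 * 64 = 30528, and the slice then
    # drops the padding rows so the stored tensor matches the real vocab size exactly.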


class GemmaModel(Model):
    def set_vocab(self):
        self._set_vocab_sentencepiece()

    def set_gguf_parameters(self):
        hparams = self.hparams
        block_count = hparams["num_hidden_layers"]

        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
        self.gguf_writer.add_key_length(hparams["head_dim"])
        self.gguf_writer.add_value_length(hparams["head_dim"])
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in self.get_tensors():
            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            # Gemma's RMSNorm applies (1 + weight), so fold the +1 into the exported norm weight
            # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
            if name.endswith("norm.weight"):
                data_torch = data_torch + 1

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Cannot map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


###### CONVERSION LOGIC ######


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Convert a huggingface model to a GGML compatible file")
    parser.add_argument(
        "--vocab-only", action="store_true",
        help="extract only the vocab",
    )
    parser.add_argument(
        "--awq-path", type=Path, default=None,
        help="Path to scale awq cache file")
    parser.add_argument(
        "--outfile", type=Path,
        help="path to write to; default: based on input",
    )
    parser.add_argument(
        "--outtype", type=str, choices=["f32", "f16"], default="f16",
        help="output format - use f32 for float32, f16 for float16",
    )
    parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
    parser.add_argument(
        "model", type=Path,
        help="directory containing model file",
    )

    return parser.parse_args()
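

# Example invocations (editorial, illustrative; the paths are hypothetical):
#   python convert-hf-to-gguf.py /path/to/hf-model --outtype f16 --outfile ./model-f16.gguf
#   python convert-hf-to-gguf.py /path/to/hf-model --vocab-only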


def main() -> None:
    args = parse_args()

    dir_model = args.model

    if args.awq_path:
        sys.path.insert(1, str(Path(__file__).parent / 'awq-py'))
        from awq.apply_awq import add_scale_weights  # type: ignore[import-not-found]
        tmp_model_path = args.model / "weighted_model"
        dir_model = tmp_model_path
        if tmp_model_path.is_dir():
            print(f"{tmp_model_path} exists as a weighted model.")
        else:
            tmp_model_path.mkdir(parents=True, exist_ok=True)
            print("Saving new weighted model ...")
            add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path))
            print(f"Saved weighted model at {tmp_model_path}.")

    if not dir_model.is_dir():
        print(f'Error: {args.model} is not a directory', file=sys.stderr)
        sys.exit(1)

    ftype_map = {
        "f32": gguf.GGMLQuantizationType.F32,
        "f16": gguf.GGMLQuantizationType.F16,
    }

    if args.outfile is not None:
        fname_out = args.outfile
    else:
        # output in the same directory as the model by default
        fname_out = dir_model / f'ggml-model-{args.outtype}.gguf'

    print(f"Loading model: {dir_model.name}")

    hparams = Model.load_hparams(dir_model)

    with torch.inference_mode():
        model_class = Model.from_model_architecture(hparams["architectures"][0])
        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian)

        print("Set model parameters")
        model_instance.set_gguf_parameters()

        print("Set model tokenizer")
        model_instance.set_vocab()

        if args.vocab_only:
            print(f"Exporting model vocab to '{fname_out}'")
            model_instance.write_vocab()
        else:
            print(f"Exporting model to '{fname_out}'")
            model_instance.write()

        print(f"Model successfully exported to '{fname_out}'")


if __name__ == '__main__':
    main()