convert-hf-to-gguf.py 80 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876
  1. #!/usr/bin/env python3
  2. from __future__ import annotations
  3. import argparse
  4. import contextlib
  5. import json
  6. import os
  7. import re
  8. import sys
  9. from enum import IntEnum
  10. from pathlib import Path
  11. from typing import TYPE_CHECKING, Any, ContextManager, Iterator, Sequence, cast
  12. import numpy as np
  13. import torch
  14. if TYPE_CHECKING:
  15. from torch import Tensor
  16. if 'NO_LOCAL_GGUF' not in os.environ:
  17. sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
  18. import gguf
  19. from convert import HfVocab
  20. ###### MODEL DEFINITIONS ######
  21. class SentencePieceTokenTypes(IntEnum):
  22. NORMAL = 1
  23. UNKNOWN = 2
  24. CONTROL = 3
  25. USER_DEFINED = 4
  26. UNUSED = 5
  27. BYTE = 6
  28. class Model:
  29. def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool):
  30. self.dir_model = dir_model
  31. self.ftype = ftype
  32. self.fname_out = fname_out
  33. self.is_big_endian = is_big_endian
  34. self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
  35. self.is_safetensors = self._is_model_safetensors()
  36. self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin")
  37. self.part_names = self._get_part_names()
  38. self.hparams = Model.load_hparams(self.dir_model)
  39. self.model_arch = self._get_model_architecture()
  40. self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=False)
  41. self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
  42. def find_hparam(self, keys: Sequence[str], optional: bool = False) -> Any:
  43. key = next((k for k in keys if k in self.hparams), None)
  44. if key is not None:
  45. return self.hparams[key]
  46. if optional:
  47. return None
  48. raise KeyError(f"could not find any of: {keys}")
  49. def set_vocab(self):
  50. self._set_vocab_gpt2()
  51. def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
  52. for part_name in self.part_names:
  53. print(f"gguf: loading model part '{part_name}'")
  54. ctx: ContextManager[Any]
  55. if self.is_safetensors:
  56. from safetensors import safe_open
  57. ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
  58. else:
  59. ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
  60. with ctx as model_part:
  61. for name in model_part.keys():
  62. data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
  63. yield name, data
  64. def set_gguf_parameters(self):
  65. self.gguf_writer.add_name(self.dir_model.name)
  66. self.gguf_writer.add_block_count(self.block_count)
  67. if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
  68. self.gguf_writer.add_context_length(n_ctx)
  69. n_embd = self.find_hparam(["hidden_size", "n_embd"])
  70. self.gguf_writer.add_embedding_length(n_embd)
  71. if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
  72. self.gguf_writer.add_feed_forward_length(n_ff)
  73. n_head = self.find_hparam(["num_attention_heads", "n_head"])
  74. self.gguf_writer.add_head_count(n_head)
  75. if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
  76. self.gguf_writer.add_head_count_kv(n_head_kv)
  77. if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
  78. self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
  79. if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon"], optional=True)) is not None:
  80. self.gguf_writer.add_layer_norm_eps(f_norm_eps)
  81. if (n_experts := self.hparams.get("num_local_experts")) is not None:
  82. self.gguf_writer.add_expert_count(n_experts)
  83. if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
  84. self.gguf_writer.add_expert_used_count(n_experts_used)
  85. self.gguf_writer.add_file_type(self.ftype)
  86. def write_tensors(self):
  87. block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
  88. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  89. for name, data_torch in self.get_tensors():
  90. # we don't need these
  91. if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
  92. continue
  93. old_dtype = data_torch.dtype
  94. # convert any unsupported data types to float32
  95. if data_torch.dtype not in (torch.float16, torch.float32):
  96. data_torch = data_torch.to(torch.float32)
  97. data = data_torch.squeeze().numpy()
  98. # map tensor names
  99. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  100. if new_name is None:
  101. print(f"Can not map tensor {name!r}")
  102. sys.exit()
  103. n_dims = len(data.shape)
  104. data_dtype = data.dtype
  105. # if f32 desired, convert any float16 to float32
  106. if self.ftype == 0 and data_dtype == np.float16:
  107. data = data.astype(np.float32)
  108. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  109. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  110. data = data.astype(np.float32)
  111. # if f16 desired, convert any float32 2-dim weight tensors to float16
  112. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  113. data = data.astype(np.float16)
  114. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  115. self.gguf_writer.add_tensor(new_name, data)
  116. def write(self):
  117. self.write_tensors()
  118. self.gguf_writer.write_header_to_file()
  119. self.gguf_writer.write_kv_data_to_file()
  120. self.gguf_writer.write_tensors_to_file()
  121. self.gguf_writer.close()
  122. def write_vocab(self):
  123. self.gguf_writer.write_header_to_file()
  124. self.gguf_writer.write_kv_data_to_file()
  125. self.gguf_writer.close()
  126. @staticmethod
  127. def count_model_parts(dir_model: Path, prefix: str) -> int:
  128. num_parts = 0
  129. for filename in os.listdir(dir_model):
  130. if filename.endswith(prefix):
  131. num_parts += 1
  132. return num_parts
  133. @staticmethod
  134. def load_hparams(dir_model):
  135. with open(dir_model / "config.json", "r", encoding="utf-8") as f:
  136. return json.load(f)
  137. @staticmethod
  138. def from_model_architecture(model_architecture):
  139. if model_architecture == "GPTNeoXForCausalLM":
  140. return GPTNeoXModel
  141. if model_architecture == "BloomForCausalLM":
  142. return BloomModel
  143. if model_architecture == "MPTForCausalLM":
  144. return MPTModel
  145. if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
  146. return BaichuanModel
  147. if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
  148. return FalconModel
  149. if model_architecture == "GPTBigCodeForCausalLM":
  150. return StarCoderModel
  151. if model_architecture == "GPTRefactForCausalLM":
  152. return RefactModel
  153. if model_architecture == "PersimmonForCausalLM":
  154. return PersimmonModel
  155. if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
  156. return StableLMModel
  157. if model_architecture == "QWenLMHeadModel":
  158. return QwenModel
  159. if model_architecture == "Qwen2ForCausalLM":
  160. return Model
  161. if model_architecture == "MixtralForCausalLM":
  162. return MixtralModel
  163. if model_architecture == "GPT2LMHeadModel":
  164. return GPT2Model
  165. if model_architecture == "PhiForCausalLM":
  166. return Phi2Model
  167. if model_architecture == "PlamoForCausalLM":
  168. return PlamoModel
  169. if model_architecture == "CodeShellForCausalLM":
  170. return CodeShellModel
  171. if model_architecture == "OrionForCausalLM":
  172. return OrionModel
  173. if model_architecture == "InternLM2ForCausalLM":
  174. return InternLM2Model
  175. if model_architecture == "MiniCPMForCausalLM":
  176. return MiniCPMModel
  177. if model_architecture == "BertModel":
  178. return BertModel
  179. if model_architecture == "NomicBertModel":
  180. return NomicBertModel
  181. return Model
  182. def _is_model_safetensors(self) -> bool:
  183. return Model.count_model_parts(self.dir_model, ".safetensors") > 0
  184. def _get_part_names(self):
  185. if self.is_safetensors:
  186. if self.num_parts == 1: # there's only one .safetensors file
  187. return ("model.safetensors",)
  188. return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1))
  189. if self.num_parts == 1: # there's only one .bin file
  190. return ("pytorch_model.bin",)
  191. return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))
  192. def _get_model_architecture(self) -> gguf.MODEL_ARCH:
  193. arch = self.hparams["architectures"][0]
  194. if arch == "GPTNeoXForCausalLM":
  195. return gguf.MODEL_ARCH.GPTNEOX
  196. if arch == "BloomForCausalLM":
  197. return gguf.MODEL_ARCH.BLOOM
  198. if arch == "MPTForCausalLM":
  199. return gguf.MODEL_ARCH.MPT
  200. if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
  201. return gguf.MODEL_ARCH.BAICHUAN
  202. if arch in ("FalconForCausalLM", "RWForCausalLM"):
  203. return gguf.MODEL_ARCH.FALCON
  204. if arch == "GPTBigCodeForCausalLM":
  205. return gguf.MODEL_ARCH.STARCODER
  206. if arch == "GPTRefactForCausalLM":
  207. return gguf.MODEL_ARCH.REFACT
  208. if arch == "PersimmonForCausalLM":
  209. return gguf.MODEL_ARCH.PERSIMMON
  210. if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
  211. return gguf.MODEL_ARCH.STABLELM
  212. if arch == "QWenLMHeadModel":
  213. return gguf.MODEL_ARCH.QWEN
  214. if arch == "Qwen2ForCausalLM":
  215. return gguf.MODEL_ARCH.QWEN2
  216. if arch == "MixtralForCausalLM":
  217. return gguf.MODEL_ARCH.LLAMA
  218. if arch == "GPT2LMHeadModel":
  219. return gguf.MODEL_ARCH.GPT2
  220. if arch == "PhiForCausalLM":
  221. return gguf.MODEL_ARCH.PHI2
  222. if arch == "PlamoForCausalLM":
  223. return gguf.MODEL_ARCH.PLAMO
  224. if arch == "CodeShellForCausalLM":
  225. return gguf.MODEL_ARCH.CODESHELL
  226. if arch == "OrionForCausalLM":
  227. return gguf.MODEL_ARCH.ORION
  228. if arch == "InternLM2ForCausalLM":
  229. return gguf.MODEL_ARCH.INTERNLM2
  230. if arch == "MiniCPMForCausalLM":
  231. return gguf.MODEL_ARCH.MINICPM
  232. if arch == "BertModel":
  233. return gguf.MODEL_ARCH.BERT
  234. if arch == "NomicBertModel":
  235. return gguf.MODEL_ARCH.NOMIC_BERT
  236. raise NotImplementedError(f'Architecture "{arch}" not supported!')
  237. def _set_vocab_gpt2(self):
  238. dir_model = self.dir_model
  239. hparams = self.hparams
  240. tokens: list[bytearray] = []
  241. toktypes: list[int] = []
  242. from transformers import AutoTokenizer
  243. tokenizer = AutoTokenizer.from_pretrained(dir_model)
  244. vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
  245. assert max(tokenizer.vocab.values()) < vocab_size
  246. reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
  247. added_vocab = tokenizer.get_added_vocab()
  248. for i in range(vocab_size):
  249. if i not in reverse_vocab:
  250. pad_token = f"[PAD{i}]".encode('utf-8')
  251. tokens.append(bytearray(pad_token))
  252. toktypes.append(gguf.TokenType.USER_DEFINED)
  253. elif reverse_vocab[i] in added_vocab:
  254. tokens.append(reverse_vocab[i])
  255. if tokenizer.added_tokens_decoder[i].special:
  256. toktypes.append(gguf.TokenType.CONTROL)
  257. else:
  258. toktypes.append(gguf.TokenType.USER_DEFINED)
  259. else:
  260. tokens.append(reverse_vocab[i])
  261. toktypes.append(gguf.TokenType.NORMAL)
  262. self.gguf_writer.add_tokenizer_model("gpt2")
  263. self.gguf_writer.add_token_list(tokens)
  264. self.gguf_writer.add_token_types(toktypes)
  265. special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
  266. special_vocab.add_to_gguf(self.gguf_writer)
  267. def _set_vocab_qwen(self):
  268. dir_model = self.dir_model
  269. hparams = self.hparams
  270. tokens: list[bytearray] = []
  271. toktypes: list[int] = []
  272. from transformers import AutoTokenizer
  273. tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
  274. vocab_size = hparams["vocab_size"]
  275. assert max(tokenizer.get_vocab().values()) < vocab_size
  276. merges = []
  277. vocab = {}
  278. mergeable_ranks = tokenizer.mergeable_ranks
  279. for token, rank in mergeable_ranks.items():
  280. vocab[QwenModel.token_bytes_to_string(token)] = rank
  281. if len(token) == 1:
  282. continue
  283. merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
  284. assert len(merged) == 2
  285. merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
  286. # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
  287. added_vocab = tokenizer.special_tokens
  288. reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in (vocab | added_vocab).items()}
  289. for i in range(vocab_size):
  290. if i not in reverse_vocab:
  291. pad_token = f"[PAD{i}]".encode("utf-8")
  292. tokens.append(bytearray(pad_token))
  293. toktypes.append(gguf.TokenType.USER_DEFINED)
  294. elif reverse_vocab[i] in added_vocab:
  295. tokens.append(reverse_vocab[i])
  296. toktypes.append(gguf.TokenType.CONTROL)
  297. else:
  298. tokens.append(reverse_vocab[i])
  299. toktypes.append(gguf.TokenType.NORMAL)
  300. self.gguf_writer.add_tokenizer_model("gpt2")
  301. self.gguf_writer.add_token_list(tokens)
  302. self.gguf_writer.add_token_types(toktypes)
  303. special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
  304. special_vocab.merges = merges
  305. # only add special tokens when they were not already loaded from config.json
  306. if len(special_vocab.special_token_ids) == 0:
  307. special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"])
  308. special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"])
  309. # this one is usually not in config.json anyway
  310. special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"])
  311. special_vocab.add_to_gguf(self.gguf_writer)
  312. def _set_vocab_sentencepiece(self):
  313. from sentencepiece import SentencePieceProcessor
  314. tokenizer_path = self.dir_model / 'tokenizer.model'
  315. tokens: list[bytes] = []
  316. scores: list[float] = []
  317. toktypes: list[int] = []
  318. if not tokenizer_path.is_file():
  319. print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
  320. sys.exit(1)
  321. tokenizer = SentencePieceProcessor(str(tokenizer_path))
  322. vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
  323. for token_id in range(vocab_size):
  324. piece = tokenizer.id_to_piece(token_id)
  325. text = piece.encode("utf-8")
  326. score = tokenizer.get_score(token_id)
  327. toktype = SentencePieceTokenTypes.NORMAL
  328. if tokenizer.is_unknown(token_id):
  329. toktype = SentencePieceTokenTypes.UNKNOWN
  330. elif tokenizer.is_control(token_id):
  331. toktype = SentencePieceTokenTypes.CONTROL
  332. elif tokenizer.is_unused(token_id):
  333. toktype = SentencePieceTokenTypes.UNUSED
  334. elif tokenizer.is_byte(token_id):
  335. toktype = SentencePieceTokenTypes.BYTE
  336. tokens.append(text)
  337. scores.append(score)
  338. toktypes.append(toktype)
  339. added_tokens_file = self.dir_model / 'added_tokens.json'
  340. if added_tokens_file.is_file():
  341. with open(added_tokens_file, "r", encoding="utf-8") as f:
  342. added_tokens_json = json.load(f)
  343. for key in added_tokens_json:
  344. tokens.append(key.encode("utf-8"))
  345. scores.append(-1000.0)
  346. toktypes.append(SentencePieceTokenTypes.USER_DEFINED)
  347. self.gguf_writer.add_tokenizer_model("llama")
  348. self.gguf_writer.add_token_list(tokens)
  349. self.gguf_writer.add_token_scores(scores)
  350. self.gguf_writer.add_token_types(toktypes)
  351. special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
  352. special_vocab.add_to_gguf(self.gguf_writer)
  353. def _set_vocab_hf(self):
  354. path = self.dir_model
  355. added_tokens_path = self.dir_model
  356. vocab = HfVocab(
  357. path, added_tokens_path if added_tokens_path.exists() else None
  358. )
  359. tokens = []
  360. scores = []
  361. toktypes = []
  362. for text, score, toktype in vocab.all_tokens():
  363. tokens.append(text)
  364. scores.append(score)
  365. toktypes.append(toktype)
  366. assert len(tokens) == vocab.vocab_size
  367. self.gguf_writer.add_tokenizer_model("llama")
  368. self.gguf_writer.add_token_list(tokens)
  369. self.gguf_writer.add_token_scores(scores)
  370. self.gguf_writer.add_token_types(toktypes)
  371. special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
  372. special_vocab.add_to_gguf(self.gguf_writer)
  373. class GPTNeoXModel(Model):
  374. def set_gguf_parameters(self):
  375. block_count = self.hparams["num_hidden_layers"]
  376. self.gguf_writer.add_name(self.dir_model.name)
  377. self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
  378. self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
  379. self.gguf_writer.add_block_count(block_count)
  380. self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
  381. self.gguf_writer.add_rope_dimension_count(
  382. int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
  383. )
  384. self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
  385. self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
  386. self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
  387. class BloomModel(Model):
  388. def set_gguf_parameters(self):
  389. self.gguf_writer.add_name("Bloom")
  390. n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
  391. n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
  392. self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
  393. self.gguf_writer.add_embedding_length(n_embed)
  394. self.gguf_writer.add_feed_forward_length(4 * n_embed)
  395. self.gguf_writer.add_block_count(self.hparams["n_layer"])
  396. self.gguf_writer.add_head_count(n_head)
  397. self.gguf_writer.add_head_count_kv(n_head)
  398. self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
  399. self.gguf_writer.add_file_type(self.ftype)
  400. def write_tensors(self):
  401. block_count = self.hparams["n_layer"]
  402. tensors = dict(self.get_tensors())
  403. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  404. has_lm_head = True
  405. n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
  406. n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
  407. for name, data_torch in tensors.items():
  408. if "lm_head.weight" not in tensors.keys() and "output.weight" not in tensors.keys():
  409. has_lm_head = False
  410. name = re.sub(r'transformer\.', '', name)
  411. old_dtype = data_torch.dtype
  412. # convert any unsupported data types to float32
  413. if data_torch.dtype not in (torch.float16, torch.float32):
  414. data_torch = data_torch.to(torch.float32)
  415. data = data_torch.squeeze().numpy()
  416. if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name):
  417. # Map bloom-style qkv_linear to gpt-style qkv_linear
  418. # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252 # noqa
  419. # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312 # noqa
  420. qkv_weights = data.reshape((n_head, 3, n_embed // n_head, n_embed))
  421. data = np.concatenate(
  422. (
  423. qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
  424. qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
  425. qkv_weights[:, 2, :, :].reshape((-1, n_embed)),
  426. ),
  427. axis=0,
  428. )
  429. print("re-format attention.linear_qkv.weight")
  430. elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name):
  431. qkv_bias = data.reshape((n_head, 3, n_embed // n_head))
  432. data = np.concatenate(
  433. (
  434. qkv_bias[:, 0, :].reshape((n_embed,)),
  435. qkv_bias[:, 1, :].reshape((n_embed,)),
  436. qkv_bias[:, 2, :].reshape((n_embed,)),
  437. ),
  438. axis=0,
  439. )
  440. print("re-format attention.linear_qkv.bias")
  441. # map tensor names
  442. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  443. if new_name is None:
  444. print(f"Can not map tensor {name!r}")
  445. sys.exit()
  446. n_dims = len(data.shape)
  447. data_dtype = data.dtype
  448. # if f32 desired, convert any float16 to float32
  449. if self.ftype == 0 and data_dtype == np.float16:
  450. data = data.astype(np.float32)
  451. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  452. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  453. data = data.astype(np.float32)
  454. # if f16 desired, convert any float32 2-dim weight tensors to float16
  455. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  456. data = data.astype(np.float16)
  457. print(f"=> {new_name}, shape = {data.shape}, {old_dtype} --> {data.dtype}")
  458. self.gguf_writer.add_tensor(new_name, data)
  459. if not has_lm_head and name == "word_embeddings.weight":
  460. self.gguf_writer.add_tensor("output.weight", data)
  461. print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
  462. class MPTModel(Model):
  463. def set_gguf_parameters(self):
  464. block_count = self.hparams["n_layers"]
  465. self.gguf_writer.add_name(self.dir_model.name)
  466. self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
  467. self.gguf_writer.add_embedding_length(self.hparams["d_model"])
  468. self.gguf_writer.add_block_count(block_count)
  469. self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"])
  470. self.gguf_writer.add_head_count(self.hparams["n_heads"])
  471. if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"):
  472. self.gguf_writer.add_head_count_kv(kv_n_heads)
  473. self.gguf_writer.add_layer_norm_eps(1e-5)
  474. if self.hparams["attn_config"]["clip_qkv"] is not None:
  475. self.gguf_writer.add_clamp_kqv(self.hparams["attn_config"]["clip_qkv"])
  476. self.gguf_writer.add_max_alibi_bias(self.hparams["attn_config"]["alibi_bias_max"])
  477. def write_tensors(self):
  478. block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers"))
  479. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  480. for name, data_torch in self.get_tensors():
  481. # we don't need these
  482. if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
  483. continue
  484. old_dtype = data_torch.dtype
  485. # convert any unsupported data types to float32
  486. if data_torch.dtype not in (torch.float16, torch.float32):
  487. data_torch = data_torch.to(torch.float32)
  488. data = data_torch.squeeze().numpy()
  489. # map tensor names
  490. if "scales" in name:
  491. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias", ".scales"))
  492. if new_name is not None:
  493. new_name = new_name.replace("scales", "act.scales")
  494. else:
  495. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  496. if new_name is None:
  497. print(f"Can not map tensor {name!r}")
  498. sys.exit()
  499. n_dims = len(data.shape)
  500. data_dtype = data.dtype
  501. # if f32 desired, convert any float16 to float32
  502. if self.ftype == 0 and data_dtype == np.float16:
  503. data = data.astype(np.float32)
  504. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  505. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  506. data = data.astype(np.float32)
  507. # if f16 desired, convert any float32 2-dim weight tensors to float16
  508. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  509. data = data.astype(np.float16)
  510. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  511. self.gguf_writer.add_tensor(new_name, data)
  512. # note: MPT output is tied to (same as) wte in original model;
  513. # for easier implementation in llama.cpp it's duplicated in GGUF, though :/
  514. if new_name == "token_embd.weight":
  515. self.gguf_writer.add_tensor("output.weight", data)
  516. class OrionModel(Model):
  517. def set_vocab(self):
  518. self._set_vocab_sentencepiece()
  519. def set_gguf_parameters(self):
  520. block_count = self.hparams["num_hidden_layers"]
  521. head_count = self.hparams["num_attention_heads"]
  522. head_count_kv = self.hparams.get("num_key_value_heads", head_count)
  523. hf_repo = self.hparams.get("_name_or_path", "")
  524. ctx_length = 0
  525. if "max_sequence_length" in self.hparams:
  526. ctx_length = self.hparams["max_sequence_length"]
  527. elif "max_position_embeddings" in self.hparams:
  528. ctx_length = self.hparams["max_position_embeddings"]
  529. elif "model_max_length" in self.hparams:
  530. ctx_length = self.hparams["model_max_length"]
  531. else:
  532. print("gguf: can not find ctx length parameter.")
  533. sys.exit()
  534. self.gguf_writer.add_file_type(self.ftype)
  535. self.gguf_writer.add_name(self.dir_model.name)
  536. self.gguf_writer.add_source_hf_repo(hf_repo)
  537. self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
  538. self.gguf_writer.add_context_length(ctx_length)
  539. self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
  540. self.gguf_writer.add_block_count(block_count)
  541. self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
  542. self.gguf_writer.add_head_count(head_count)
  543. self.gguf_writer.add_head_count_kv(head_count_kv)
  544. self.gguf_writer.add_layer_norm_eps(self.hparams["rms_norm_eps"])
  545. def write_tensors(self):
  546. # Collect tensors from generator object
  547. model_kv = dict(self.get_tensors())
  548. block_count = self.hparams["num_hidden_layers"]
  549. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  550. for name, data_torch in model_kv.items():
  551. # we don't need these
  552. if name.endswith(".rotary_emb.inv_freq"):
  553. continue
  554. old_dtype = data_torch.dtype
  555. # convert any unsupported data types to float32
  556. if data_torch.dtype not in (torch.float16, torch.float32):
  557. data_torch = data_torch.to(torch.float32)
  558. data = data_torch.squeeze().numpy()
  559. # map tensor names
  560. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  561. if new_name is None:
  562. print(f"Can not map tensor {name!r}")
  563. sys.exit()
  564. n_dims = len(data.shape)
  565. data_dtype = data.dtype
  566. # if f32 desired, convert any float16 to float32
  567. if self.ftype == 0 and data_dtype == np.float16:
  568. data = data.astype(np.float32)
  569. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  570. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  571. data = data.astype(np.float32)
  572. # if f16 desired, convert any float32 2-dim weight tensors to float16
  573. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  574. data = data.astype(np.float16)
  575. print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  576. self.gguf_writer.add_tensor(new_name, data)
  577. class BaichuanModel(Model):
  578. def set_vocab(self):
  579. self._set_vocab_sentencepiece()
  580. def set_gguf_parameters(self):
  581. block_count = self.hparams["num_hidden_layers"]
  582. head_count = self.hparams["num_attention_heads"]
  583. head_count_kv = self.hparams.get("num_key_value_heads", head_count)
  584. hf_repo = self.hparams.get("_name_or_path", "")
  585. ctx_length = 0
  586. if "max_sequence_length" in self.hparams:
  587. ctx_length = self.hparams["max_sequence_length"]
  588. elif "max_position_embeddings" in self.hparams:
  589. ctx_length = self.hparams["max_position_embeddings"]
  590. elif "model_max_length" in self.hparams:
  591. ctx_length = self.hparams["model_max_length"]
  592. else:
  593. print("gguf: can not find ctx length parameter.")
  594. sys.exit()
  595. self.gguf_writer.add_name(self.dir_model.name)
  596. self.gguf_writer.add_source_hf_repo(hf_repo)
  597. self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
  598. self.gguf_writer.add_context_length(ctx_length)
  599. self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
  600. self.gguf_writer.add_block_count(block_count)
  601. self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
  602. self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
  603. self.gguf_writer.add_head_count(head_count)
  604. self.gguf_writer.add_head_count_kv(head_count_kv)
  605. self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
  606. if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
  607. if self.hparams["rope_scaling"].get("type") == "linear":
  608. self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
  609. self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
  610. def write_tensors(self):
  611. # Collect tensors from generator object
  612. model_kv = dict(self.get_tensors())
  613. block_count = self.hparams["num_hidden_layers"]
  614. head_count = self.hparams["num_attention_heads"]
  615. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  616. head_count_kv = self.hparams.get("num_key_value_heads", head_count)
  617. for i in range(block_count):
  618. if (w := model_kv.get(f"model.layers.{i}.self_attn.W_pack.weight")) is not None:
  619. print(f"Unpacking and permuting layer {i}")
  620. model_kv[f"model.layers.{i}.self_attn.q_proj.weight"] = \
  621. self._reverse_hf_permute_part(w, 0, head_count, head_count)
  622. model_kv[f"model.layers.{i}.self_attn.k_proj.weight"] = \
  623. self._reverse_hf_permute_part(w, 1, head_count, head_count_kv)
  624. model_kv[f"model.layers.{i}.self_attn.v_proj.weight"] = \
  625. self._reverse_hf_part(w, 2)
  626. del model_kv[f"model.layers.{i}.self_attn.W_pack.weight"]
  627. for name, data_torch in model_kv.items():
  628. # we don't need these
  629. if name.endswith(".rotary_emb.inv_freq"):
  630. continue
  631. old_dtype = data_torch.dtype
  632. # convert any unsupported data types to float32
  633. if data_torch.dtype not in (torch.float16, torch.float32):
  634. data_torch = data_torch.to(torch.float32)
  635. data = data_torch.squeeze().numpy()
  636. # map tensor names
  637. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  638. if new_name is None:
  639. print(f"Can not map tensor {name!r}")
  640. sys.exit()
  641. n_dims = len(data.shape)
  642. data_dtype = data.dtype
  643. # if f32 desired, convert any float16 to float32
  644. if self.ftype == 0 and data_dtype == np.float16:
  645. data = data.astype(np.float32)
  646. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  647. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  648. data = data.astype(np.float32)
  649. # if f16 desired, convert any float32 2-dim weight tensors to float16
  650. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  651. data = data.astype(np.float16)
  652. print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  653. self.gguf_writer.add_tensor(new_name, data)
  654. def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
  655. if n_kv_head is not None and n_head != n_kv_head:
  656. n_head //= n_kv_head
  657. return (
  658. weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
  659. .swapaxes(1, 2)
  660. .reshape(weights.shape)
  661. )
  662. def _reverse_hf_permute_part(
  663. self, weights: Tensor, n_part: int, n_head: int, n_head_kv: int | None = None,
  664. ) -> Tensor:
  665. r = weights.shape[0] // 3
  666. return self._reverse_hf_permute(weights[r * n_part:r * n_part + r, ...], n_head, n_head_kv)
  667. def _reverse_hf_part(self, weights: Tensor, n_part: int) -> Tensor:
  668. r = weights.shape[0] // 3
  669. return weights[r * n_part:r * n_part + r, ...]
  670. class FalconModel(Model):
  671. def set_gguf_parameters(self):
  672. block_count = self.hparams.get("num_hidden_layers")
  673. if block_count is None:
  674. block_count = self.hparams["n_layer"] # old name
  675. n_head = self.hparams.get("num_attention_heads")
  676. if n_head is None:
  677. n_head = self.hparams["n_head"] # old name
  678. n_head_kv = self.hparams.get("num_kv_heads")
  679. if n_head_kv is None:
  680. n_head_kv = self.hparams.get("n_head_kv", 1) # old name
  681. self.gguf_writer.add_name("Falcon")
  682. self.gguf_writer.add_context_length(2048) # not in config.json
  683. self.gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform
  684. self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
  685. self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"])
  686. self.gguf_writer.add_block_count(block_count)
  687. self.gguf_writer.add_head_count(n_head)
  688. self.gguf_writer.add_head_count_kv(n_head_kv)
  689. self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
  690. self.gguf_writer.add_file_type(self.ftype)
  691. def write_tensors(self):
  692. block_count = self.hparams.get("num_hidden_layers")
  693. if block_count is None:
  694. block_count = self.hparams["n_layer"] # old name
  695. n_head = self.hparams.get("num_attention_heads")
  696. if n_head is None:
  697. n_head = self.hparams["n_head"] # old name
  698. n_head_kv = self.hparams.get("num_kv_heads")
  699. if n_head_kv is None:
  700. n_head_kv = self.hparams.get("n_head_kv", 1) # old name
  701. head_dim = self.hparams["hidden_size"] // n_head
  702. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  703. for name, data_torch in self.get_tensors():
  704. old_dtype = data_torch.dtype
  705. # convert any unsupported data types to float32
  706. if data_torch.dtype not in (torch.float16, torch.float32):
  707. data_torch = data_torch.to(torch.float32)
  708. # QKV tensor transform
  709. # The original query_key_value tensor contains n_head_kv "kv groups",
  710. # each consisting of n_head/n_head_kv query weights followed by one key
  711. # and one value weight (shared by all query heads in the kv group).
  712. # This layout makes it a big pain to work with in GGML.
  713. # So we rearrange them here,, so that we have n_head query weights
  714. # followed by n_head_kv key weights followed by n_head_kv value weights,
  715. # in contiguous fashion.
  716. # ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py
  717. if "query_key_value" in name:
  718. qkv = data_torch.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
  719. q = qkv[:, :-2].reshape(n_head * head_dim, head_dim * n_head)
  720. k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
  721. v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
  722. data_torch = torch.cat((q, k, v)).reshape_as(data_torch)
  723. data = data_torch.squeeze().numpy()
  724. # map tensor names
  725. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  726. if new_name is None:
  727. print(f"Can not map tensor {name!r}")
  728. sys.exit()
  729. n_dims = len(data.shape)
  730. data_dtype = data.dtype
  731. # if f32 desired, convert any float16 to float32
  732. if self.ftype == 0 and data_dtype == np.float16:
  733. data = data.astype(np.float32)
  734. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  735. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  736. data = data.astype(np.float32)
  737. # if f16 desired, convert any float32 2-dim weight tensors to float16
  738. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  739. data = data.astype(np.float16)
  740. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  741. self.gguf_writer.add_tensor(new_name, data)
  742. class StarCoderModel(Model):
  743. def set_gguf_parameters(self):
  744. block_count = self.hparams["n_layer"]
  745. self.gguf_writer.add_name("StarCoder")
  746. self.gguf_writer.add_context_length(self.hparams["n_positions"])
  747. self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
  748. self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
  749. self.gguf_writer.add_block_count(block_count)
  750. self.gguf_writer.add_head_count(self.hparams["n_head"])
  751. self.gguf_writer.add_head_count_kv(1)
  752. self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
  753. self.gguf_writer.add_file_type(self.ftype)
  754. class RefactModel(Model):
  755. def set_gguf_parameters(self):
  756. hidden_dim = self.hparams["n_embd"]
  757. inner_dim = 4 * hidden_dim
  758. hidden_dim = int(2 * inner_dim / 3)
  759. multiple_of = 256
  760. ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
  761. block_count = self.hparams["n_layer"]
  762. self.gguf_writer.add_name("Refact")
  763. # refact uses Alibi. So this is from config.json which might be used by training.
  764. self.gguf_writer.add_context_length(self.hparams["n_positions"])
  765. self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
  766. self.gguf_writer.add_feed_forward_length(ff_dim)
  767. self.gguf_writer.add_block_count(block_count)
  768. self.gguf_writer.add_head_count(self.hparams["n_head"])
  769. self.gguf_writer.add_head_count_kv(1)
  770. self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
  771. self.gguf_writer.add_file_type(self.ftype)
  772. def write_tensors(self):
  773. hidden_dim = self.hparams["n_embd"]
  774. inner_dim = 4 * hidden_dim
  775. hidden_dim = int(2 * inner_dim / 3)
  776. multiple_of = 256
  777. ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
  778. n_head = self.hparams["n_head"]
  779. n_head_kv = 1
  780. head_dim = self.hparams["n_embd"] // n_head
  781. block_count = self.hparams["n_layer"]
  782. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  783. tensors = dict(self.get_tensors())
  784. for i in range(block_count):
  785. if (w := tensors.get(f"transformer.h.{i}.attn.kv.weight")) is not None:
  786. tensors[f"model.layers.{i}.self_attn.k_proj.weight"] = w[:n_head_kv * head_dim]
  787. tensors[f"model.layers.{i}.self_attn.v_proj.weight"] = w[n_head_kv * head_dim:]
  788. del tensors[f"transformer.h.{i}.attn.kv.weight"]
  789. if (w := tensors.get(f"transformer.h.{i}.attn.q.weight")) is not None:
  790. tensors[f"model.layers.{i}.self_attn.q_proj.weight"] = w
  791. del tensors[f"transformer.h.{i}.attn.q.weight"]
  792. if (w := tensors.get(f"transformer.h.{i}.mlp.gate_up_proj.weight")) is not None:
  793. tensors[f"model.layers.{i}.mlp.gate_proj.weight"] = w[:ff_dim]
  794. tensors[f"model.layers.{i}.mlp.up_proj.weight"] = w[ff_dim:]
  795. del tensors[f"transformer.h.{i}.mlp.gate_up_proj.weight"]
  796. for name, data_torch in tensors.items():
  797. old_dtype = data_torch.dtype
  798. # convert any unsupported data types to float32
  799. if data_torch.dtype not in (torch.float16, torch.float32):
  800. data_torch = data_torch.to(torch.float32)
  801. data = data_torch.squeeze().numpy()
  802. # map tensor names
  803. new_name = tensor_map.get_name(name, try_suffixes=(".weight",))
  804. if new_name is None:
  805. print(f"Can not map tensor {name!r}")
  806. sys.exit()
  807. n_dims = len(data.shape)
  808. data_dtype = data.dtype
  809. # if f32 desired, convert any float16 to float32
  810. if self.ftype == 0 and data_dtype == np.float16:
  811. data = data.astype(np.float32)
  812. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  813. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  814. data = data.astype(np.float32)
  815. # if f16 desired, convert any float32 2-dim weight tensors to float16
  816. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  817. data = data.astype(np.float16)
  818. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  819. self.gguf_writer.add_tensor(new_name, data)
  820. class PersimmonModel(Model):
  821. def set_gguf_parameters(self):
  822. block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
  823. head_count = self.hparams["num_attention_heads"]
  824. head_count_kv = head_count
  825. hidden_size = self.hparams["hidden_size"]
  826. self.gguf_writer.add_name('persimmon-8b-chat')
  827. self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
  828. self.gguf_writer.add_embedding_length(hidden_size)
  829. self.gguf_writer.add_block_count(block_count)
  830. self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
  831. # NOTE: not sure about this change - why does the model not have a rope dimension count when it is smaller
  832. # than the head size?
  833. # ref: https://github.com/ggerganov/llama.cpp/pull/4889
  834. # self.gguf_writer.add_rope_dimension_count(hidden_size // head_count)
  835. self.gguf_writer.add_rope_dimension_count(hidden_size // head_count // 2)
  836. self.gguf_writer.add_head_count(head_count)
  837. self.gguf_writer.add_head_count_kv(head_count_kv)
  838. self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
  839. self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
  840. self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
  841. def set_vocab(self):
  842. self._set_vocab_sentencepiece()
  843. # self.gguf_writer.add_bos_token_id(71013)
  844. # self.gguf_writer.add_eos_token_id(71013)
  845. def write_tensors(self):
  846. block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
  847. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  848. for name, data_torch in self.get_tensors():
  849. if name.endswith(".self_attention.rotary_emb.inv_freq"):
  850. continue
  851. old_dtype = data_torch.dtype
  852. # TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?)
  853. data = data_torch.to(torch.float32).squeeze().numpy()
  854. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  855. if new_name is None:
  856. print(f"Can not map tensor {name!r}")
  857. sys.exit()
  858. n_dims = len(data.shape)
  859. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  860. self.gguf_writer.add_tensor(new_name, data)
  861. class StableLMModel(Model):
  862. def set_vocab(self):
  863. if (self.dir_model / "tokenizer.json").is_file():
  864. self._set_vocab_gpt2()
  865. else:
  866. # StableLM 2 1.6B uses a vocab in a similar format to Qwen's vocab
  867. self._set_vocab_qwen()
  868. def set_gguf_parameters(self):
  869. hparams = self.hparams
  870. block_count = hparams["num_hidden_layers"]
  871. self.gguf_writer.add_name(self.dir_model.name)
  872. self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
  873. self.gguf_writer.add_embedding_length(hparams["hidden_size"])
  874. self.gguf_writer.add_block_count(block_count)
  875. self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
  876. self.gguf_writer.add_rope_dimension_count(int(hparams["rope_pct"] * (hparams["hidden_size"] // hparams["num_attention_heads"])))
  877. self.gguf_writer.add_head_count(hparams["num_attention_heads"])
  878. self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
  879. self.gguf_writer.add_layer_norm_eps(1e-5)
  880. class MixtralModel(Model):
  881. def set_vocab(self):
  882. self._set_vocab_sentencepiece()
  883. class MiniCPMModel(Model):
  884. def set_gguf_parameters(self):
  885. block_count = self.hparams["num_hidden_layers"]
  886. self.gguf_writer.add_name("MiniCPM")
  887. self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
  888. self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
  889. self.gguf_writer.add_block_count(block_count)
  890. self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
  891. self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
  892. self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
  893. self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
  894. self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
  895. self.gguf_writer.add_file_type(self.ftype)
  896. def set_vocab(self):
  897. self._set_vocab_hf()
  898. def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
  899. if n_kv_head is not None and n_head != n_kv_head:
  900. n_head //= n_kv_head
  901. return (
  902. weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
  903. .swapaxes(1, 2)
  904. .reshape(weights.shape)
  905. )
  906. def write_tensors(self):
  907. block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
  908. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  909. n_head = self.hparams.get("num_attention_heads")
  910. n_kv_head = self.hparams.get("num_key_value_heads")
  911. for name, data_torch in self.get_tensors():
  912. # we don't need these
  913. if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
  914. continue
  915. old_dtype = data_torch.dtype
  916. # convert any unsupported data types to float32
  917. if data_torch.dtype not in (torch.float16, torch.float32):
  918. data_torch = data_torch.to(torch.float32)
  919. # HF models permute some of the tensors, so we need to undo that
  920. if name.endswith(("q_proj.weight")):
  921. data_torch = self._reverse_hf_permute(data_torch, n_head, n_head)
  922. if name.endswith(("k_proj.weight")):
  923. data_torch = self._reverse_hf_permute(data_torch, n_head, n_kv_head)
  924. data = data_torch.squeeze().numpy()
  925. # map tensor names
  926. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  927. if new_name is None:
  928. print(f"Can not map tensor {name!r}")
  929. sys.exit()
  930. n_dims = len(data.shape)
  931. data_dtype = data.dtype
  932. # if f32 desired, convert any float16 to float32
  933. if self.ftype == 0 and data_dtype == np.float16:
  934. data = data.astype(np.float32)
  935. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  936. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  937. data = data.astype(np.float32)
  938. # if f16 desired, convert any float32 2-dim weight tensors to float16
  939. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  940. data = data.astype(np.float16)
  941. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  942. self.gguf_writer.add_tensor(new_name, data)
  943. class QwenModel(Model):
  944. @staticmethod
  945. def token_bytes_to_string(b):
  946. from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
  947. byte_encoder = bytes_to_unicode()
  948. return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
  949. @staticmethod
  950. def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
  951. parts = [bytes([b]) for b in token]
  952. while True:
  953. min_idx = None
  954. min_rank = None
  955. for i, pair in enumerate(zip(parts[:-1], parts[1:])):
  956. rank = mergeable_ranks.get(pair[0] + pair[1])
  957. if rank is not None and (min_rank is None or rank < min_rank):
  958. min_idx = i
  959. min_rank = rank
  960. if min_rank is None or (max_rank is not None and min_rank >= max_rank):
  961. break
  962. assert min_idx is not None
  963. parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
  964. return parts
  965. def set_vocab(self):
  966. self._set_vocab_qwen()
  967. def set_gguf_parameters(self):
  968. self.gguf_writer.add_name("Qwen")
  969. self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
  970. self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
  971. self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
  972. self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
  973. self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
  974. self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
  975. self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
  976. self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
  977. def write_tensors(self):
  978. block_count = self.hparams["num_hidden_layers"]
  979. model_kv = dict(self.get_tensors())
  980. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  981. for name, data_torch in model_kv.items():
  982. # we don't need these
  983. if name.endswith(".rotary_emb.inv_freq"):
  984. continue
  985. old_dtype = data_torch.dtype
  986. # convert any unsupported data types to float32
  987. if data_torch.dtype not in (torch.float16, torch.float32):
  988. data_torch = data_torch.to(torch.float32)
  989. data = data_torch.squeeze().numpy()
  990. # map tensor names
  991. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  992. if new_name is None:
  993. print(f"Can not map tensor {name!r}")
  994. sys.exit()
  995. n_dims = len(data.shape)
  996. data_dtype = data.dtype
  997. # if f32 desired, convert any float16 to float32
  998. if self.ftype == 0 and data_dtype == np.float16:
  999. data = data.astype(np.float32)
  1000. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  1001. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  1002. data = data.astype(np.float32)
  1003. # if f16 desired, convert any float32 2-dim weight tensors to float16
  1004. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  1005. data = data.astype(np.float16)
  1006. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  1007. self.gguf_writer.add_tensor(new_name, data)
  1008. class GPT2Model(Model):
  1009. def set_gguf_parameters(self):
  1010. self.gguf_writer.add_name(self.dir_model.name)
  1011. self.gguf_writer.add_block_count(self.hparams["n_layer"])
  1012. self.gguf_writer.add_context_length(self.hparams["n_ctx"])
  1013. self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
  1014. self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
  1015. self.gguf_writer.add_head_count(self.hparams["n_head"])
  1016. self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
  1017. self.gguf_writer.add_file_type(self.ftype)
  1018. def write_tensors(self):
  1019. block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
  1020. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  1021. for name, data_torch in self.get_tensors():
  1022. # we don't need these
  1023. if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq", ".attn.bias", ".attn.masked_bias")):
  1024. continue
  1025. if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_proj.weight")):
  1026. data_torch = data_torch.transpose(1, 0)
  1027. old_dtype = data_torch.dtype
  1028. # convert any unsupported data types to float32
  1029. if data_torch.dtype not in (torch.float16, torch.float32):
  1030. data_torch = data_torch.to(torch.float32)
  1031. data = data_torch.squeeze().numpy()
  1032. # map tensor names
  1033. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  1034. if new_name is None:
  1035. print(f"Can not map tensor {name!r}")
  1036. sys.exit()
  1037. n_dims = len(data.shape)
  1038. data_dtype = data.dtype
  1039. # if f32 desired, convert any float16 to float32
  1040. if self.ftype == 0 and data_dtype == np.float16:
  1041. data = data.astype(np.float32)
  1042. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  1043. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  1044. data = data.astype(np.float32)
  1045. # if f16 desired, convert any float32 2-dim weight tensors to float16
  1046. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  1047. data = data.astype(np.float16)
  1048. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  1049. self.gguf_writer.add_tensor(new_name, data)
  1050. # note: GPT2 output is tied to (same as) wte in original model
  1051. if new_name == "token_embd.weight":
  1052. print(f"output.weight, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  1053. self.gguf_writer.add_tensor("output.weight", data)
  1054. class Phi2Model(Model):
  1055. def set_gguf_parameters(self):
  1056. block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
  1057. rot_pct = self.find_hparam(["partial_rotary_factor"])
  1058. n_embd = self.find_hparam(["hidden_size", "n_embd"])
  1059. n_head = self.find_hparam(["num_attention_heads", "n_head"])
  1060. self.gguf_writer.add_name("Phi2")
  1061. self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"]))
  1062. self.gguf_writer.add_embedding_length(n_embd)
  1063. self.gguf_writer.add_feed_forward_length(4 * n_embd)
  1064. self.gguf_writer.add_block_count(block_count)
  1065. self.gguf_writer.add_head_count(n_head)
  1066. self.gguf_writer.add_head_count_kv(n_head)
  1067. self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_epsilon", "layer_norm_eps"]))
  1068. self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
  1069. self.gguf_writer.add_file_type(self.ftype)
  1070. self.gguf_writer.add_add_bos_token(False)
  1071. class PlamoModel(Model):
  1072. def set_vocab(self):
  1073. self._set_vocab_sentencepiece()
  1074. def set_gguf_parameters(self):
  1075. hparams = self.hparams
  1076. block_count = hparams["num_hidden_layers"]
  1077. self.gguf_writer.add_name("PLaMo")
  1078. self.gguf_writer.add_context_length(4096) # not in config.json
  1079. self.gguf_writer.add_embedding_length(hparams["hidden_size"])
  1080. self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
  1081. self.gguf_writer.add_block_count(block_count)
  1082. self.gguf_writer.add_head_count(hparams["num_attention_heads"])
  1083. self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
  1084. self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
  1085. def shuffle_attn_q_weight(self, data_torch):
  1086. assert data_torch.size() == (5120, 5120)
  1087. data_torch = data_torch.reshape(8, 5, 128, 5120)
  1088. data_torch = torch.permute(data_torch, (1, 0, 2, 3))
  1089. data_torch = torch.reshape(data_torch, (5120, 5120))
  1090. return data_torch
  1091. def shuffle_attn_output_weight(self, data_torch):
  1092. assert data_torch.size() == (5120, 5120)
  1093. data_torch = data_torch.reshape(5120, 8, 5, 128)
  1094. data_torch = torch.permute(data_torch, (0, 2, 1, 3))
  1095. data_torch = torch.reshape(data_torch, (5120, 5120))
  1096. return data_torch
  1097. def write_tensors(self):
  1098. block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
  1099. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  1100. for name, data_torch in self.get_tensors():
  1101. if "self_attn.rotary_emb.inv_freq" in name:
  1102. continue
  1103. # map tensor names
  1104. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  1105. if new_name is None:
  1106. print(f"Can not map tensor {name!r}")
  1107. sys.exit()
  1108. # shuffle for broadcasting of gqa in ggml_mul_mat
  1109. if new_name.endswith("attn_q.weight"):
  1110. data_torch = self.shuffle_attn_q_weight(data_torch)
  1111. elif new_name.endswith("attn_output.weight"):
  1112. data_torch = self.shuffle_attn_output_weight(data_torch)
  1113. old_dtype = data_torch.dtype
  1114. # convert any unsupported data types to float32
  1115. if data_torch.dtype not in (torch.float16, torch.float32):
  1116. data_torch = data_torch.to(torch.float32)
  1117. data = data_torch.squeeze().numpy()
  1118. n_dims = len(data.shape)
  1119. data_dtype = data.dtype
  1120. # if f32 desired, convert any float16 to float32
  1121. if self.ftype == 0 and data_dtype == np.float16:
  1122. data = data.astype(np.float32)
  1123. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  1124. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  1125. data = data.astype(np.float32)
  1126. # if f16 desired, convert any float32 2-dim weight tensors to float16
  1127. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  1128. data = data.astype(np.float16)
  1129. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  1130. self.gguf_writer.add_tensor(new_name, data)
  1131. class CodeShellModel(Model):
  1132. def set_gguf_parameters(self):
  1133. block_count = self.hparams["n_layer"]
  1134. self.gguf_writer.add_name("CodeShell")
  1135. self.gguf_writer.add_context_length(self.hparams["n_positions"])
  1136. self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
  1137. self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
  1138. self.gguf_writer.add_block_count(block_count)
  1139. self.gguf_writer.add_head_count(self.hparams["n_head"])
  1140. self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"])
  1141. self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
  1142. self.gguf_writer.add_file_type(self.ftype)
  1143. self.gguf_writer.add_rope_freq_base(10000.0)
  1144. self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
  1145. self.gguf_writer.add_rope_scaling_factor(1.0)
  1146. def write_tensors(self):
  1147. block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
  1148. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  1149. tensors = dict(self.get_tensors())
  1150. has_lm_head = "lm_head.weight" in tensors.keys() or "output.weight" in tensors.keys()
  1151. for name, data_torch in tensors.items():
  1152. # we don't need these
  1153. if name.endswith((".attn.rotary_emb.inv_freq")):
  1154. continue
  1155. old_dtype = data_torch.dtype
  1156. # convert any unsupported data types to float32
  1157. if data_torch.dtype not in (torch.float16, torch.float32):
  1158. data_torch = data_torch.to(torch.float32)
  1159. data = data_torch.squeeze().numpy()
  1160. # map tensor names
  1161. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  1162. if new_name is None:
  1163. print(f"Can not map tensor {name!r}")
  1164. sys.exit()
  1165. n_dims = len(data.shape)
  1166. data_dtype = data.dtype
  1167. # if f32 desired, convert any float16 to float32
  1168. if self.ftype == 0 and data_dtype == np.float16:
  1169. data = data.astype(np.float32)
  1170. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  1171. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  1172. data = data.astype(np.float32)
  1173. # if f16 desired, convert any float32 2-dim weight tensors to float16
  1174. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  1175. data = data.astype(np.float16)
  1176. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  1177. self.gguf_writer.add_tensor(new_name, data)
  1178. if not has_lm_head and name == "transformer.wte.weight":
  1179. self.gguf_writer.add_tensor("output.weight", data)
  1180. print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
  1181. class InternLM2Model(Model):
  1182. def set_vocab(self):
  1183. # (TODO): Is there a better way?
  1184. # Copy from _set_vocab_sentencepiece, The only difference is that we will treat the character
  1185. # \x00 specially and convert it into an emoji character to prevent it from being mistakenly
  1186. # recognized as an empty string in C++.
  1187. from sentencepiece import SentencePieceProcessor
  1188. from sentencepiece import sentencepiece_model_pb2 as model
  1189. tokenizer_path = self.dir_model / 'tokenizer.model'
  1190. tokens: list[bytes] = []
  1191. scores: list[float] = []
  1192. toktypes: list[int] = []
  1193. if not tokenizer_path.is_file():
  1194. print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
  1195. sys.exit(1)
  1196. sentencepiece_model = model.ModelProto()
  1197. sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
  1198. add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
  1199. tokenizer = SentencePieceProcessor(str(tokenizer_path))
  1200. vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
  1201. for token_id in range(vocab_size):
  1202. piece = tokenizer.id_to_piece(token_id)
  1203. text = piece.encode("utf-8")
  1204. score = tokenizer.get_score(token_id)
  1205. if text == b"\x00":
  1206. # (TODO): fixme
  1207. # Hack here and replace the \x00 characters.
  1208. print(f"InternLM2 convert token '{text}' to '🐉'!")
  1209. text = "🐉"
  1210. toktype = SentencePieceTokenTypes.NORMAL
  1211. if tokenizer.is_unknown(token_id):
  1212. toktype = SentencePieceTokenTypes.UNKNOWN
  1213. elif tokenizer.is_control(token_id):
  1214. toktype = SentencePieceTokenTypes.CONTROL
  1215. elif tokenizer.is_unused(token_id):
  1216. toktype = SentencePieceTokenTypes.UNUSED
  1217. elif tokenizer.is_byte(token_id):
  1218. toktype = SentencePieceTokenTypes.BYTE
  1219. tokens.append(text)
  1220. scores.append(score)
  1221. toktypes.append(toktype)
  1222. added_tokens_file = self.dir_model / 'added_tokens.json'
  1223. if added_tokens_file.is_file():
  1224. with open(added_tokens_file, "r", encoding="utf-8") as f:
  1225. added_tokens_json = json.load(f)
  1226. for key in added_tokens_json:
  1227. tokens.append(key.encode("utf-8"))
  1228. scores.append(-1000.0)
  1229. toktypes.append(SentencePieceTokenTypes.USER_DEFINED)
  1230. self.gguf_writer.add_tokenizer_model("llama")
  1231. self.gguf_writer.add_token_list(tokens)
  1232. self.gguf_writer.add_token_scores(scores)
  1233. self.gguf_writer.add_token_types(toktypes)
  1234. self.gguf_writer.add_add_space_prefix(add_prefix)
  1235. special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
  1236. old_eos = special_vocab.special_token_ids["eos"]
  1237. if "chat" in os.path.basename(self.dir_model.absolute()):
  1238. # For the chat model, we replace the eos with '<|im_end|>'.
  1239. special_vocab.special_token_ids["eos"] = self._try_get_sft_eos(tokenizer)
  1240. print(f"Replace eos:{old_eos} with a special token:{special_vocab.special_token_ids['eos']} \
  1241. in chat mode so that the conversation can end normally.")
  1242. special_vocab.add_to_gguf(self.gguf_writer)
  1243. def _try_get_sft_eos(self, tokenizer):
  1244. unused_145_list = tokenizer.encode('[UNUSED_TOKEN_145]')
  1245. im_end_list = tokenizer.encode('<|im_end|>')
  1246. assert (len(unused_145_list) == 1) ^ (len(im_end_list) == 1)
  1247. if len(unused_145_list) == 1:
  1248. eos_token = unused_145_list[0]
  1249. if len(im_end_list) == 1:
  1250. eos_token = im_end_list[0]
  1251. return eos_token
  1252. def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int):
  1253. if n_head_kv is not None and n_head != n_head_kv:
  1254. n_head = n_head_kv
  1255. return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
  1256. .swapaxes(1, 2)
  1257. .reshape(weights.shape))
  1258. def set_gguf_parameters(self):
  1259. self.gguf_writer.add_name("InternLM2")
  1260. self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
  1261. self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
  1262. self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
  1263. self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
  1264. self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
  1265. self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
  1266. self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
  1267. self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
  1268. def post_write_tensors(self, tensor_map, name, data_torch):
  1269. old_dtype = data_torch.dtype
  1270. # convert any unsupported data types to float32
  1271. if data_torch.dtype not in (torch.float16, torch.float32):
  1272. data_torch = data_torch.to(torch.float32)
  1273. data = data_torch.squeeze().numpy()
  1274. # map tensor names
  1275. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  1276. if new_name is None:
  1277. print(f"Can not map tensor {name!r}")
  1278. sys.exit()
  1279. n_dims = len(data.shape)
  1280. data_dtype = data.dtype
  1281. # if f32 desired, convert any float16 to float32
  1282. if self.ftype == 0 and data_dtype == np.float16:
  1283. data = data.astype(np.float32)
  1284. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  1285. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  1286. data = data.astype(np.float32)
  1287. # if f16 desired, convert any float32 2-dim weight tensors to float16
  1288. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  1289. data = data.astype(np.float16)
  1290. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  1291. self.gguf_writer.add_tensor(new_name, data)
  1292. def write_tensors(self):
  1293. from einops import rearrange
  1294. num_heads = self.hparams.get("num_attention_heads")
  1295. num_kv_heads = self.hparams.get("num_key_value_heads")
  1296. hidden_size = self.hparams.get("hidden_size")
  1297. q_per_kv = num_heads // num_kv_heads
  1298. head_dim = hidden_size // num_heads
  1299. num_groups = num_heads // q_per_kv
  1300. block_count = self.hparams["num_hidden_layers"]
  1301. model_kv = dict(self.get_tensors())
  1302. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  1303. qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
  1304. for name, data_torch in model_kv.items():
  1305. # we don't need these
  1306. if name.endswith(".rotary_emb.inv_freq"):
  1307. continue
  1308. if re.match(qkv_pattern, name):
  1309. bid = re.findall(qkv_pattern, name)[0]
  1310. qkv = data_torch
  1311. qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
  1312. q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
  1313. # The model weights of q and k equire additional reshape.
  1314. q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
  1315. k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
  1316. v = rearrange(v, " o g n i -> o (g n i)").T
  1317. self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wq.weight", q)
  1318. self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wk.weight", k)
  1319. self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wv.weight", v)
  1320. else:
  1321. self.post_write_tensors(tensor_map, name, data_torch)
  1322. class BertModel(Model):
  1323. def __init__(self, *args, **kwargs):
  1324. super().__init__(*args, **kwargs)
  1325. self.vocab_size = None
  1326. def set_gguf_parameters(self):
  1327. super().set_gguf_parameters()
  1328. self.gguf_writer.add_causal_attention(False)
  1329. # get pooling path
  1330. with open(self.dir_model / "modules.json", encoding="utf-8") as f:
  1331. modules = json.load(f)
  1332. pooling_path = None
  1333. for mod in modules:
  1334. if mod["type"] == "sentence_transformers.models.Pooling":
  1335. pooling_path = mod["path"]
  1336. break
  1337. # get pooling type
  1338. pooling_type = gguf.PoolingType.NONE
  1339. if pooling_path is not None:
  1340. with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
  1341. pooling = json.load(f)
  1342. if pooling["pooling_mode_mean_tokens"]:
  1343. pooling_type = gguf.PoolingType.MEAN
  1344. elif pooling["pooling_mode_cls_token"]:
  1345. pooling_type = gguf.PoolingType.CLS
  1346. else:
  1347. raise NotImplementedError("Only MEAN and CLS pooling types supported")
  1348. self.gguf_writer.add_pooling_type(pooling_type.value)
  1349. def set_vocab(self):
  1350. path = self.dir_model
  1351. added_tokens_path = self.dir_model if self.dir_model.exists() else None
  1352. # use huggingface vocab to get all tokens
  1353. vocab = HfVocab(path, added_tokens_path)
  1354. tokens, scores, toktypes = zip(*vocab.all_tokens())
  1355. assert len(tokens) == vocab.vocab_size
  1356. self.vocab_size = vocab.vocab_size
  1357. # we need this to validate the size of the token_type embeddings
  1358. # though currently we are passing all zeros to the token_type embeddings
  1359. n_token_types = len(set(toktypes))
  1360. self.gguf_writer.add_token_type_count(n_token_types)
  1361. # convert to phantom space vocab
  1362. def phantom(tok, typ):
  1363. if tok.startswith(b"[") and tok.endswith(b"]"):
  1364. return tok
  1365. if tok.startswith(b"##"):
  1366. return tok[2:]
  1367. return b"\xe2\x96\x81" + tok
  1368. tokens = tuple(phantom(t, y) for t, y in zip(tokens, toktypes))
  1369. # set up bos and eos tokens (cls and sep)
  1370. self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id)
  1371. self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id)
  1372. # add vocab to gguf
  1373. self.gguf_writer.add_tokenizer_model("bert")
  1374. self.gguf_writer.add_token_list(tokens)
  1375. self.gguf_writer.add_token_scores(scores)
  1376. self.gguf_writer.add_token_types(toktypes)
  1377. # handle special tokens
  1378. special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
  1379. special_vocab.add_to_gguf(self.gguf_writer)
  1380. def write_tensors(self):
  1381. tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
  1382. tensors = dict(self.get_tensors())
  1383. for name, data_torch in tensors.items():
  1384. # we are only using BERT for embeddings so we don't need the pooling layer
  1385. if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
  1386. continue # we don't need these
  1387. # map tensor names
  1388. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  1389. if new_name is None:
  1390. print(f"Can not map tensor {name!r}")
  1391. sys.exit()
  1392. data = data_torch.squeeze().numpy()
  1393. n_dims = len(data.shape)
  1394. new_dtype: type[np.floating[Any]]
  1395. if (
  1396. self.ftype == 1 and name.endswith(".weight") and n_dims == 2
  1397. and name != "embeddings.token_type_embeddings.weight" # not used with get_rows, must be F32
  1398. ):
  1399. # if f16 desired, convert any float32 2-dim weight tensors to float16
  1400. new_dtype = np.float16
  1401. else:
  1402. # if f32 desired, convert any float16 to float32
  1403. new_dtype = np.float32
  1404. print(f"{new_name}, n_dims = {n_dims}, {data_torch.dtype} --> {new_dtype}")
  1405. if data.dtype != new_dtype:
  1406. data = data.astype(new_dtype)
  1407. self.gguf_writer.add_tensor(new_name, data)
  1408. class NomicBertModel(BertModel):
  1409. def __init__(self, *args, **kwargs):
  1410. super().__init__(*args, **kwargs)
  1411. # the HF config claims n_ctx=8192, but it uses RoPE scaling
  1412. self.hparams["n_ctx"] = 2048
  1413. # SwigLU activation
  1414. assert self.hparams["activation_function"] == "swiglu"
  1415. # this doesn't do anything in the HF version
  1416. assert self.hparams["causal"] is False
  1417. # no bias tensors
  1418. assert self.hparams["qkv_proj_bias"] is False
  1419. assert self.hparams["mlp_fc1_bias"] is False
  1420. assert self.hparams["mlp_fc2_bias"] is False
  1421. # norm at end of layer
  1422. assert self.hparams["prenorm"] is False
  1423. # standard RoPE
  1424. assert self.hparams["rotary_emb_fraction"] == 1.0
  1425. assert self.hparams["rotary_emb_interleaved"] is False
  1426. assert self.hparams["rotary_emb_scale_base"] is None
  1427. def set_gguf_parameters(self):
  1428. super().set_gguf_parameters()
  1429. self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
  1430. def get_tensors(self):
  1431. assert self.vocab_size is not None
  1432. for name, data in super().get_tensors():
  1433. # Nomic Embed's token embeddings tensor is padded, but llama.cpp wants tensor sizes to match exactly.
  1434. if name == 'embeddings.word_embeddings.weight' and data.shape[1] != self.vocab_size:
  1435. rounded_vocab_size = (self.vocab_size + 63) // 64 * 64
  1436. assert data.shape == (rounded_vocab_size, self.hparams["n_embd"])
  1437. data = data[:self.vocab_size, :]
  1438. yield name, data
  1439. ###### CONVERSION LOGIC ######
  1440. def parse_args() -> argparse.Namespace:
  1441. parser = argparse.ArgumentParser(
  1442. description="Convert a huggingface model to a GGML compatible file")
  1443. parser.add_argument(
  1444. "--vocab-only", action="store_true",
  1445. help="extract only the vocab",
  1446. )
  1447. parser.add_argument(
  1448. "--awq-path", type=Path, default=None,
  1449. help="Path to scale awq cache file")
  1450. parser.add_argument(
  1451. "--outfile", type=Path,
  1452. help="path to write to; default: based on input",
  1453. )
  1454. parser.add_argument(
  1455. "--outtype", type=str, choices=["f32", "f16"], default="f16",
  1456. help="output format - use f32 for float32, f16 for float16",
  1457. )
  1458. parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
  1459. parser.add_argument(
  1460. "model", type=Path,
  1461. help="directory containing model file",
  1462. )
  1463. return parser.parse_args()
  1464. def main() -> None:
  1465. args = parse_args()
  1466. dir_model = args.model
  1467. if args.awq_path:
  1468. sys.path.insert(1, str(Path(__file__).parent / 'awq-py'))
  1469. from awq.apply_awq import add_scale_weights # type: ignore[import-not-found]
  1470. tmp_model_path = args.model / "weighted_model"
  1471. dir_model = tmp_model_path
  1472. if tmp_model_path.is_dir():
  1473. print(f"{tmp_model_path} exists as a weighted model.")
  1474. else:
  1475. tmp_model_path.mkdir(parents=True, exist_ok=True)
  1476. print("Saving new weighted model ...")
  1477. add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path))
  1478. print(f"Saved weighted model at {tmp_model_path}.")
  1479. if not dir_model.is_dir():
  1480. print(f'Error: {args.model} is not a directory', file=sys.stderr)
  1481. sys.exit(1)
  1482. ftype_map = {
  1483. "f32": gguf.GGMLQuantizationType.F32,
  1484. "f16": gguf.GGMLQuantizationType.F16,
  1485. }
  1486. if args.outfile is not None:
  1487. fname_out = args.outfile
  1488. else:
  1489. # output in the same directory as the model by default
  1490. fname_out = dir_model / f'ggml-model-{args.outtype}.gguf'
  1491. print(f"Loading model: {dir_model.name}")
  1492. hparams = Model.load_hparams(dir_model)
  1493. with torch.inference_mode():
  1494. model_class = Model.from_model_architecture(hparams["architectures"][0])
  1495. model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian)
  1496. print("Set model parameters")
  1497. model_instance.set_gguf_parameters()
  1498. print("Set model tokenizer")
  1499. model_instance.set_vocab()
  1500. if args.vocab_only:
  1501. print(f"Exporting model vocab to '{fname_out}'")
  1502. model_instance.write_vocab()
  1503. else:
  1504. print(f"Exporting model to '{fname_out}'")
  1505. model_instance.write()
  1506. print(f"Model successfully exported to '{fname_out}'")
# Run the conversion only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()