convert-hf-to-gguf.py

#!/usr/bin/env python3
from __future__ import annotations

import argparse
import contextlib
import json
import os
import re
import sys
from enum import IntEnum
from pathlib import Path
from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast

import numpy as np
import torch

if TYPE_CHECKING:
    from torch import Tensor

if 'NO_LOCAL_GGUF' not in os.environ:
    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf

from convert import HfVocab


# check for any of the given keys in the dictionary and return the value of the first key found
def get_key_opts(d, keys):
    for k in keys:
        if k in d:
            return d[k]
    print(f"Could not find any of {keys}")
    sys.exit()
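
# Illustrative usage note (not part of the original script): get_key_opts() lets a model
# class accept either the current or the legacy Hugging Face config key for the same value:
#
#   hparams = {"n_layer": 24}
#   block_count = get_key_opts(hparams, ["num_hidden_layers", "n_layer"])  # -> 24
#
# If none of the keys are present it prints the key list and exits.
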
###### MODEL DEFINITIONS ######

class SentencePieceTokenTypes(IntEnum):
    NORMAL = 1
    UNKNOWN = 2
    CONTROL = 3
    USER_DEFINED = 4
    UNUSED = 5
    BYTE = 6


class Model:
    def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool):
        self.dir_model = dir_model
        self.ftype = ftype
        self.fname_out = fname_out
        self.is_big_endian = is_big_endian
        self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
        self.is_safetensors = self._is_model_safetensors()
        self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin")
        self.part_names = self._get_part_names()
        self.hparams = Model.load_hparams(self.dir_model)
        self.model_arch = self._get_model_architecture()
        self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=False)

    def set_vocab(self):
        self._set_vocab_gpt2()

    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
        for part_name in self.part_names:
            print(f"gguf: loading model part '{part_name}'")
            ctx: ContextManager[Any]
            if self.is_safetensors:
                from safetensors import safe_open
                ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
            else:
                ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))

            with ctx as model_part:
                for name in model_part.keys():
                    data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
                    yield name, data

    def set_gguf_parameters(self):
        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_block_count(self.hparams.get(
            "n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")),
        ))
        if (n_ctx := self.hparams.get("max_position_embeddings")) is not None:
            self.gguf_writer.add_context_length(n_ctx)
        if (n_embd := self.hparams.get("hidden_size")) is not None:
            self.gguf_writer.add_embedding_length(n_embd)
        if (n_ff := self.hparams.get("intermediate_size")) is not None:
            self.gguf_writer.add_feed_forward_length(n_ff)
        if (n_head := self.hparams.get("num_attention_heads")) is not None:
            self.gguf_writer.add_head_count(n_head)
        if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
            self.gguf_writer.add_head_count_kv(n_head_kv)
        if (n_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
            self.gguf_writer.add_layer_norm_rms_eps(n_rms_eps)
        if (n_experts := self.hparams.get("num_local_experts")) is not None:
            self.gguf_writer.add_expert_count(n_experts)
        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
            self.gguf_writer.add_expert_used_count(n_experts_used)
        self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

    def write(self):
        self.write_tensors()
        self.gguf_writer.write_header_to_file()
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.write_tensors_to_file()
        self.gguf_writer.close()

    def write_vocab(self):
        self.gguf_writer.write_header_to_file()
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.close()

    @staticmethod
    def count_model_parts(dir_model: Path, prefix: str) -> int:
        num_parts = 0
        for filename in os.listdir(dir_model):
            if filename.endswith(prefix):
                num_parts += 1
        return num_parts

    @staticmethod
    def load_hparams(dir_model):
        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def from_model_architecture(model_architecture):
        if model_architecture == "GPTNeoXForCausalLM":
            return GPTNeoXModel
        if model_architecture == "BloomForCausalLM":
            return BloomModel
        if model_architecture == "MPTForCausalLM":
            return MPTModel
        if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
            return BaichuanModel
        if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
            return FalconModel
        if model_architecture == "GPTBigCodeForCausalLM":
            return StarCoderModel
        if model_architecture == "GPTRefactForCausalLM":
            return RefactModel
        if model_architecture == "PersimmonForCausalLM":
            return PersimmonModel
        if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
            return StableLMModel
        if model_architecture == "QWenLMHeadModel":
            return QwenModel
        if model_architecture == "Qwen2ForCausalLM":
            return Model
        if model_architecture == "MixtralForCausalLM":
            return MixtralModel
        if model_architecture == "GPT2LMHeadModel":
            return GPT2Model
        if model_architecture == "PhiForCausalLM":
            return Phi2Model
        if model_architecture == "PlamoForCausalLM":
            return PlamoModel
        if model_architecture == "CodeShellForCausalLM":
            return CodeShellModel
        if model_architecture == "OrionForCausalLM":
            return OrionModel
        if model_architecture == "InternLM2ForCausalLM":
            return InternLM2Model
        if model_architecture == "MiniCPMForCausalLM":
            return MiniCPMModel
        if model_architecture == "BertModel":
            return BertModel
        return Model

    def _is_model_safetensors(self) -> bool:
        return Model.count_model_parts(self.dir_model, ".safetensors") > 0

    def _get_part_names(self):
        if self.is_safetensors:
            if self.num_parts == 1:  # there's only one .safetensors file
                return ("model.safetensors",)
            return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1))

        if self.num_parts == 1:  # there's only one .bin file
            return ("pytorch_model.bin",)
        return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))

    def _get_model_architecture(self) -> gguf.MODEL_ARCH:
        arch = self.hparams["architectures"][0]
        if arch == "GPTNeoXForCausalLM":
            return gguf.MODEL_ARCH.GPTNEOX
        if arch == "BloomForCausalLM":
            return gguf.MODEL_ARCH.BLOOM
        if arch == "MPTForCausalLM":
            return gguf.MODEL_ARCH.MPT
        if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
            return gguf.MODEL_ARCH.BAICHUAN
        if arch in ("FalconForCausalLM", "RWForCausalLM"):
            return gguf.MODEL_ARCH.FALCON
        if arch == "GPTBigCodeForCausalLM":
            return gguf.MODEL_ARCH.STARCODER
        if arch == "GPTRefactForCausalLM":
            return gguf.MODEL_ARCH.REFACT
        if arch == "PersimmonForCausalLM":
            return gguf.MODEL_ARCH.PERSIMMON
        if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
            return gguf.MODEL_ARCH.STABLELM
        if arch == "QWenLMHeadModel":
            return gguf.MODEL_ARCH.QWEN
        if arch == "Qwen2ForCausalLM":
            return gguf.MODEL_ARCH.QWEN2
        if arch == "MixtralForCausalLM":
            return gguf.MODEL_ARCH.LLAMA
        if arch == "GPT2LMHeadModel":
            return gguf.MODEL_ARCH.GPT2
        if arch == "PhiForCausalLM":
            return gguf.MODEL_ARCH.PHI2
        if arch == "PlamoForCausalLM":
            return gguf.MODEL_ARCH.PLAMO
        if arch == "CodeShellForCausalLM":
            return gguf.MODEL_ARCH.CODESHELL
        if arch == "OrionForCausalLM":
            return gguf.MODEL_ARCH.ORION
        if arch == "InternLM2ForCausalLM":
            return gguf.MODEL_ARCH.INTERNLM2
        if arch == "MiniCPMForCausalLM":
            return gguf.MODEL_ARCH.MINICPM
        if arch == "BertModel":
            return gguf.MODEL_ARCH.BERT

        raise NotImplementedError(f'Architecture "{arch}" not supported!')

    def _set_vocab_gpt2(self):
        dir_model = self.dir_model
        hparams = self.hparams
        tokens: list[bytearray] = []
        toktypes: list[int] = []

        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(dir_model)
        vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
        assert max(tokenizer.vocab.values()) < vocab_size

        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
        added_vocab = tokenizer.get_added_vocab()

        for i in range(vocab_size):
            if i not in reverse_vocab:
                pad_token = f"[PAD{i}]".encode('utf-8')
                tokens.append(bytearray(pad_token))
                toktypes.append(gguf.TokenType.USER_DEFINED)
            elif reverse_vocab[i] in added_vocab:
                tokens.append(reverse_vocab[i])
                if tokenizer.added_tokens_decoder[i].special:
                    toktypes.append(gguf.TokenType.CONTROL)
                else:
                    toktypes.append(gguf.TokenType.USER_DEFINED)
            else:
                tokens.append(reverse_vocab[i])
                toktypes.append(gguf.TokenType.NORMAL)

        self.gguf_writer.add_tokenizer_model("gpt2")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
        special_vocab.add_to_gguf(self.gguf_writer)

    def _set_vocab_qwen(self):
        dir_model = self.dir_model
        hparams = self.hparams
        tokens: list[bytearray] = []
        toktypes: list[int] = []

        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
        vocab_size = hparams["vocab_size"]
        assert max(tokenizer.get_vocab().values()) < vocab_size

        merges = []
        vocab = {}
        mergeable_ranks = tokenizer.mergeable_ranks
        for token, rank in mergeable_ranks.items():
            vocab[QwenModel.token_bytes_to_string(token)] = rank
            if len(token) == 1:
                continue
            merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
            assert len(merged) == 2
            merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))

        # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
        added_vocab = tokenizer.special_tokens
        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in (vocab | added_vocab).items()}

        for i in range(vocab_size):
            if i not in reverse_vocab:
                pad_token = f"[PAD{i}]".encode("utf-8")
                tokens.append(bytearray(pad_token))
                toktypes.append(gguf.TokenType.USER_DEFINED)
            elif reverse_vocab[i] in added_vocab:
                tokens.append(reverse_vocab[i])
                toktypes.append(gguf.TokenType.CONTROL)
            else:
                tokens.append(reverse_vocab[i])
                toktypes.append(gguf.TokenType.NORMAL)

        self.gguf_writer.add_tokenizer_model("gpt2")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
        special_vocab.merges = merges
        # only add special tokens when they were not already loaded from config.json
        if len(special_vocab.special_token_ids) == 0:
            special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"])
            special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"])
        # this one is usually not in config.json anyway
        special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"])
        special_vocab.add_to_gguf(self.gguf_writer)

    def _set_vocab_sentencepiece(self):
        from sentencepiece import SentencePieceProcessor

        tokenizer_path = self.dir_model / 'tokenizer.model'

        tokens: list[bytes] = []
        scores: list[float] = []
        toktypes: list[int] = []

        if not tokenizer_path.is_file():
            print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
            sys.exit(1)

        tokenizer = SentencePieceProcessor(str(tokenizer_path))
        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

        for token_id in range(vocab_size):
            piece = tokenizer.id_to_piece(token_id)
            text = piece.encode("utf-8")
            score = tokenizer.get_score(token_id)

            toktype = SentencePieceTokenTypes.NORMAL
            if tokenizer.is_unknown(token_id):
                toktype = SentencePieceTokenTypes.UNKNOWN
            elif tokenizer.is_control(token_id):
                toktype = SentencePieceTokenTypes.CONTROL
            elif tokenizer.is_unused(token_id):
                toktype = SentencePieceTokenTypes.UNUSED
            elif tokenizer.is_byte(token_id):
                toktype = SentencePieceTokenTypes.BYTE

            tokens.append(text)
            scores.append(score)
            toktypes.append(toktype)

        added_tokens_file = self.dir_model / 'added_tokens.json'
        if added_tokens_file.is_file():
            with open(added_tokens_file, "r", encoding="utf-8") as f:
                added_tokens_json = json.load(f)

                for key in added_tokens_json:
                    tokens.append(key.encode("utf-8"))
                    scores.append(-1000.0)
                    toktypes.append(SentencePieceTokenTypes.USER_DEFINED)

        self.gguf_writer.add_tokenizer_model("llama")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        special_vocab.add_to_gguf(self.gguf_writer)

    def _set_vocab_hf(self):
        path = self.dir_model
        added_tokens_path = self.dir_model
        vocab = HfVocab(
            path, added_tokens_path if added_tokens_path.exists() else None
        )
        tokens = []
        scores = []
        toktypes = []

        for text, score, toktype in vocab.all_tokens():
            tokens.append(text)
            scores.append(score)
            toktypes.append(toktype)

        assert len(tokens) == vocab.vocab_size

        self.gguf_writer.add_tokenizer_model("llama")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        special_vocab.add_to_gguf(self.gguf_writer)
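
# Illustrative sketch (not part of the original script) of how the Model API above is
# typically driven; the command-line wiring is outside this excerpt:
#
#   hparams = Model.load_hparams(dir_model)                                   # read config.json
#   model_class = Model.from_model_architecture(hparams["architectures"][0])
#   model = model_class(dir_model, ftype, fname_out, is_big_endian=False)
#   model.set_gguf_parameters()   # hyperparameters -> GGUF KV pairs
#   model.set_vocab()             # tokenizer -> GGUF (GPT-2 BPE by default)
#   model.write()                 # stream all tensors into the output .gguf
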
class GPTNeoXModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["num_hidden_layers"]

        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(
            int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
        )
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])


class BloomModel(Model):
    def set_gguf_parameters(self):
        self.gguf_writer.add_name("Bloom")
        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
        self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
        self.gguf_writer.add_embedding_length(n_embed)
        self.gguf_writer.add_feed_forward_length(4 * n_embed)
        self.gguf_writer.add_block_count(self.hparams["n_layer"])
        self.gguf_writer.add_head_count(n_head)
        self.gguf_writer.add_head_count_kv(n_head)
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        block_count = self.hparams["n_layer"]
        tensors = dict(self.get_tensors())
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        has_lm_head = True
        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))

        for name, data_torch in tensors.items():
            if "lm_head.weight" not in tensors.keys() and "output.weight" not in tensors.keys():
                has_lm_head = False

            name = re.sub(r'transformer\.', '', name)

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name):
                # Map bloom-style qkv_linear to gpt-style qkv_linear
                # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252  # noqa
                # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312  # noqa
                qkv_weights = data.reshape((n_head, 3, n_embed // n_head, n_embed))
                data = np.concatenate(
                    (
                        qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
                        qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
                        qkv_weights[:, 2, :, :].reshape((-1, n_embed)),
                    ),
                    axis=0,
                )
                print("re-format attention.linear_qkv.weight")
            elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name):
                qkv_bias = data.reshape((n_head, 3, n_embed // n_head))
                data = np.concatenate(
                    (
                        qkv_bias[:, 0, :].reshape((n_embed,)),
                        qkv_bias[:, 1, :].reshape((n_embed,)),
                        qkv_bias[:, 2, :].reshape((n_embed,)),
                    ),
                    axis=0,
                )
                print("re-format attention.linear_qkv.bias")

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"=> {new_name}, shape = {data.shape}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

            if not has_lm_head and name == "word_embeddings.weight":
                self.gguf_writer.add_tensor("output.weight", data)
                print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")


class MPTModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["n_layers"]
        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
        self.gguf_writer.add_embedding_length(self.hparams["d_model"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"])
        self.gguf_writer.add_head_count(self.hparams["n_heads"])
        if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"):
            self.gguf_writer.add_head_count_kv(kv_n_heads)
        self.gguf_writer.add_layer_norm_eps(1e-5)
        if self.hparams["attn_config"]["clip_qkv"] is not None:
            self.gguf_writer.add_clamp_kqv(self.hparams["attn_config"]["clip_qkv"])
        self.gguf_writer.add_max_alibi_bias(self.hparams["attn_config"]["alibi_bias_max"])

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers"))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            if "scales" in name:
                new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias", ".scales"))
                if new_name is not None:
                    new_name = new_name.replace("scales", "act.scales")
            else:
                new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

            # note: MPT output is tied to (same as) wte in original model;
            # for easier implementation in llama.cpp it's duplicated in GGUF, though :/
            if new_name == "token_embd.weight":
                self.gguf_writer.add_tensor("output.weight", data)


class OrionModel(Model):
    def set_vocab(self):
        self._set_vocab_sentencepiece()

    def set_gguf_parameters(self):
        block_count = self.hparams["num_hidden_layers"]
        head_count = self.hparams["num_attention_heads"]
        head_count_kv = self.hparams.get("num_key_value_heads", head_count)
        hf_repo = self.hparams.get("_name_or_path", "")

        ctx_length = 0
        if "max_sequence_length" in self.hparams:
            ctx_length = self.hparams["max_sequence_length"]
        elif "max_position_embeddings" in self.hparams:
            ctx_length = self.hparams["max_position_embeddings"]
        elif "model_max_length" in self.hparams:
            ctx_length = self.hparams["model_max_length"]
        else:
            print("gguf: can not find ctx length parameter.")
            sys.exit()

        self.gguf_writer.add_file_type(self.ftype)
        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_source_hf_repo(hf_repo)
        self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
        self.gguf_writer.add_context_length(ctx_length)
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_head_count(head_count)
        self.gguf_writer.add_head_count_kv(head_count_kv)
        self.gguf_writer.add_layer_norm_eps(self.hparams["rms_norm_eps"])

    def write_tensors(self):
        # Collect tensors from generator object
        model_kv = dict(self.get_tensors())
        block_count = self.hparams["num_hidden_layers"]
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in model_kv.items():
            # we don't need these
            if name.endswith(".rotary_emb.inv_freq"):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
            self.gguf_writer.add_tensor(new_name, data)


class BaichuanModel(Model):
    def set_vocab(self):
        self._set_vocab_sentencepiece()

    def set_gguf_parameters(self):
        block_count = self.hparams["num_hidden_layers"]
        head_count = self.hparams["num_attention_heads"]
        head_count_kv = self.hparams.get("num_key_value_heads", head_count)
        hf_repo = self.hparams.get("_name_or_path", "")

        ctx_length = 0
        if "max_sequence_length" in self.hparams:
            ctx_length = self.hparams["max_sequence_length"]
        elif "max_position_embeddings" in self.hparams:
            ctx_length = self.hparams["max_position_embeddings"]
        elif "model_max_length" in self.hparams:
            ctx_length = self.hparams["model_max_length"]
        else:
            print("gguf: can not find ctx length parameter.")
            sys.exit()

        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_source_hf_repo(hf_repo)
        self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
        self.gguf_writer.add_context_length(ctx_length)
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count(head_count)
        self.gguf_writer.add_head_count_kv(head_count_kv)
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])

        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
            if self.hparams["rope_scaling"].get("type") == "linear":
                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])

    def write_tensors(self):
        # Collect tensors from generator object
        model_kv = dict(self.get_tensors())
        block_count = self.hparams["num_hidden_layers"]
        head_count = self.hparams["num_attention_heads"]
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        head_count_kv = self.hparams.get("num_key_value_heads", head_count)

        for i in range(block_count):
            if (w := model_kv.get(f"model.layers.{i}.self_attn.W_pack.weight")) is not None:
                print(f"Unpacking and permuting layer {i}")
                model_kv[f"model.layers.{i}.self_attn.q_proj.weight"] = \
                    self._reverse_hf_permute_part(w, 0, head_count, head_count)
                model_kv[f"model.layers.{i}.self_attn.k_proj.weight"] = \
                    self._reverse_hf_permute_part(w, 1, head_count, head_count_kv)
                model_kv[f"model.layers.{i}.self_attn.v_proj.weight"] = \
                    self._reverse_hf_part(w, 2)
                del model_kv[f"model.layers.{i}.self_attn.W_pack.weight"]

        for name, data_torch in model_kv.items():
            # we don't need these
            if name.endswith(".rotary_emb.inv_freq"):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
            self.gguf_writer.add_tensor(new_name, data)

    def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
        if n_kv_head is not None and n_head != n_kv_head:
            n_head //= n_kv_head

        return (
            weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape)
        )

    def _reverse_hf_permute_part(
        self, weights: Tensor, n_part: int, n_head: int, n_head_kv: int | None = None,
    ) -> Tensor:
        r = weights.shape[0] // 3
        return self._reverse_hf_permute(weights[r * n_part:r * n_part + r, ...], n_head, n_head_kv)

    def _reverse_hf_part(self, weights: Tensor, n_part: int) -> Tensor:
        r = weights.shape[0] // 3
        return weights[r * n_part:r * n_part + r, ...]
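
# Illustrative note (not part of the original script): _reverse_hf_permute() undoes the
# rotary-embedding reordering used by HF LLaMA-style checkpoints. For a q/k weight whose
# first dimension is n_head * head_dim, it views the rows as (n_head, 2, head_dim // 2, ...),
# swaps the middle two axes, and flattens back, interleaving the two halves of each head so
# the row order matches what llama.cpp expects for RoPE.
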
class FalconModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams.get("num_hidden_layers")
        if block_count is None:
            block_count = self.hparams["n_layer"]  # old name

        n_head = self.hparams.get("num_attention_heads")
        if n_head is None:
            n_head = self.hparams["n_head"]  # old name

        n_head_kv = self.hparams.get("num_kv_heads")
        if n_head_kv is None:
            n_head_kv = self.hparams.get("n_head_kv", 1)  # old name

        self.gguf_writer.add_name("Falcon")
        self.gguf_writer.add_context_length(2048)  # not in config.json
        self.gguf_writer.add_tensor_data_layout("jploski")  # qkv tensor transform
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(n_head)
        self.gguf_writer.add_head_count_kv(n_head_kv)
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        block_count = self.hparams.get("num_hidden_layers")
        if block_count is None:
            block_count = self.hparams["n_layer"]  # old name

        n_head = self.hparams.get("num_attention_heads")
        if n_head is None:
            n_head = self.hparams["n_head"]  # old name

        n_head_kv = self.hparams.get("num_kv_heads")
        if n_head_kv is None:
            n_head_kv = self.hparams.get("n_head_kv", 1)  # old name

        head_dim = self.hparams["hidden_size"] // n_head
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in self.get_tensors():
            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            # QKV tensor transform
            # The original query_key_value tensor contains n_head_kv "kv groups",
            # each consisting of n_head/n_head_kv query weights followed by one key
            # and one value weight (shared by all query heads in the kv group).
            # This layout makes it a big pain to work with in GGML.
            # So we rearrange them here, so that we have n_head query weights
            # followed by n_head_kv key weights followed by n_head_kv value weights,
            # in contiguous fashion.
            # ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py
            if "query_key_value" in name:
                qkv = data_torch.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
                q = qkv[:, :-2].reshape(n_head * head_dim, head_dim * n_head)
                k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
                v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
                data_torch = torch.cat((q, k, v)).reshape_as(data_torch)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


class StarCoderModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["n_layer"]

        self.gguf_writer.add_name("StarCoder")
        self.gguf_writer.add_context_length(self.hparams["n_positions"])
        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(self.hparams["n_head"])
        self.gguf_writer.add_head_count_kv(1)
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)


class RefactModel(Model):
    def set_gguf_parameters(self):
        hidden_dim = self.hparams["n_embd"]
        inner_dim = 4 * hidden_dim
        hidden_dim = int(2 * inner_dim / 3)
        multiple_of = 256
        ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        block_count = self.hparams["n_layer"]

        self.gguf_writer.add_name("Refact")
        # refact uses Alibi. So this is from config.json which might be used by training.
        self.gguf_writer.add_context_length(self.hparams["n_positions"])
        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
        self.gguf_writer.add_feed_forward_length(ff_dim)
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(self.hparams["n_head"])
        self.gguf_writer.add_head_count_kv(1)
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        hidden_dim = self.hparams["n_embd"]
        inner_dim = 4 * hidden_dim
        hidden_dim = int(2 * inner_dim / 3)
        multiple_of = 256
        ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        n_head = self.hparams["n_head"]
        n_head_kv = 1
        head_dim = self.hparams["n_embd"] // n_head
        block_count = self.hparams["n_layer"]

        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        tensors = dict(self.get_tensors())
        for i in range(block_count):
            if (w := tensors.get(f"transformer.h.{i}.attn.kv.weight")) is not None:
                tensors[f"model.layers.{i}.self_attn.k_proj.weight"] = w[:n_head_kv * head_dim]
                tensors[f"model.layers.{i}.self_attn.v_proj.weight"] = w[n_head_kv * head_dim:]
                del tensors[f"transformer.h.{i}.attn.kv.weight"]
            if (w := tensors.get(f"transformer.h.{i}.attn.q.weight")) is not None:
                tensors[f"model.layers.{i}.self_attn.q_proj.weight"] = w
                del tensors[f"transformer.h.{i}.attn.q.weight"]
            if (w := tensors.get(f"transformer.h.{i}.mlp.gate_up_proj.weight")) is not None:
                tensors[f"model.layers.{i}.mlp.gate_proj.weight"] = w[:ff_dim]
                tensors[f"model.layers.{i}.mlp.up_proj.weight"] = w[ff_dim:]
                del tensors[f"transformer.h.{i}.mlp.gate_up_proj.weight"]

        for name, data_torch in tensors.items():
            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight",))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


class PersimmonModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
        head_count = self.hparams["num_attention_heads"]
        head_count_kv = head_count
        hidden_size = self.hparams["hidden_size"]

        self.gguf_writer.add_name('persimmon-8b-chat')
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(hidden_size)
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])

        # NOTE: not sure about this change - why does the model not have a rope dimension count when it is smaller
        # than the head size?
        # ref: https://github.com/ggerganov/llama.cpp/pull/4889
        # self.gguf_writer.add_rope_dimension_count(hidden_size // head_count)
        self.gguf_writer.add_rope_dimension_count(hidden_size // head_count // 2)

        self.gguf_writer.add_head_count(head_count)
        self.gguf_writer.add_head_count_kv(head_count_kv)
        self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])

    def set_vocab(self):
        self._set_vocab_sentencepiece()
        # self.gguf_writer.add_bos_token_id(71013)
        # self.gguf_writer.add_eos_token_id(71013)

    def write_tensors(self):
        block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in self.get_tensors():
            if name.endswith(".self_attention.rotary_emb.inv_freq"):
                continue
            old_dtype = data_torch.dtype
            # TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?)
            data = data_torch.to(torch.float32).squeeze().numpy()
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()
            n_dims = len(data.shape)
            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
            self.gguf_writer.add_tensor(new_name, data)


class StableLMModel(Model):
    def set_vocab(self):
        if (self.dir_model / "tokenizer.json").is_file():
            self._set_vocab_gpt2()
        else:
            # StableLM 2 1.6B uses a vocab in a similar format to Qwen's vocab
            self._set_vocab_qwen()

    def set_gguf_parameters(self):
        hparams = self.hparams
        block_count = hparams["num_hidden_layers"]

        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(int(hparams["rope_pct"] * (hparams["hidden_size"] // hparams["num_attention_heads"])))
        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
        self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
        self.gguf_writer.add_layer_norm_eps(1e-5)


class MixtralModel(Model):
    def set_vocab(self):
        self._set_vocab_sentencepiece()


class MiniCPMModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["num_hidden_layers"]
        self.gguf_writer.add_name("MiniCPM")
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
        self.gguf_writer.add_file_type(self.ftype)

    def set_vocab(self):
        self._set_vocab_hf()

    def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
        if n_kv_head is not None and n_head != n_kv_head:
            n_head //= n_kv_head

        return (
            weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape)
        )

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        n_head = self.hparams.get("num_attention_heads")
        n_kv_head = self.hparams.get("num_key_value_heads")
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            # HF models permute some of the tensors, so we need to undo that
            if name.endswith(("q_proj.weight")):
                data_torch = self._reverse_hf_permute(data_torch, n_head, n_head)
            if name.endswith(("k_proj.weight")):
                data_torch = self._reverse_hf_permute(data_torch, n_head, n_kv_head)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


class QwenModel(Model):
    @staticmethod
    def token_bytes_to_string(b):
        from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
        byte_encoder = bytes_to_unicode()
        return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])

    @staticmethod
    def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
        parts = [bytes([b]) for b in token]
        while True:
            min_idx = None
            min_rank = None
            for i, pair in enumerate(zip(parts[:-1], parts[1:])):
                rank = mergeable_ranks.get(pair[0] + pair[1])
                if rank is not None and (min_rank is None or rank < min_rank):
                    min_idx = i
                    min_rank = rank
            if min_rank is None or (max_rank is not None and min_rank >= max_rank):
                break
            assert min_idx is not None
            parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
        return parts
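
    # Illustrative note (not part of the original script): bpe() replays tiktoken-style byte
    # merges. _set_vocab_qwen() calls it with max_rank set to a token's own rank, so merging
    # stops one step early and returns the two pieces whose merge produces that token, e.g.
    #
    #   QwenModel.bpe(mergeable_ranks, b"abc", max_rank=mergeable_ranks[b"abc"])
    #   # -> two byte strings such as [b"ab", b"c"]
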

    def set_vocab(self):
        self._set_vocab_qwen()

    def set_gguf_parameters(self):
        self.gguf_writer.add_name("Qwen")
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])

    def write_tensors(self):
        block_count = self.hparams["num_hidden_layers"]
        model_kv = dict(self.get_tensors())
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        for name, data_torch in model_kv.items():
            # we don't need these
            if name.endswith(".rotary_emb.inv_freq"):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Cannot map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


class GPT2Model(Model):
    def set_gguf_parameters(self):
        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_block_count(self.hparams["n_layer"])
        self.gguf_writer.add_context_length(self.hparams["n_ctx"])
        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
        self.gguf_writer.add_head_count(self.hparams["n_head"])
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq", ".attn.bias", ".attn.masked_bias")):
                continue
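
            # GPT-2 checkpoints store these projections as transformers' Conv1D modules, whose
            # weight matrices are the transpose of an equivalent nn.Linear, so they are flipped
            # back here before being written.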
            if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight")):
                data_torch = data_torch.transpose(1, 0)

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Cannot map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

            # note: GPT2 output is tied to (same as) wte in original model
            if new_name == "token_embd.weight":
                print(f"output.weight, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
                self.gguf_writer.add_tensor("output.weight", data)


class Phi2Model(Model):
    def set_gguf_parameters(self):
        block_count = get_key_opts(self.hparams, ["num_hidden_layers", "n_layer"])

        rot_pct = get_key_opts(self.hparams, ["partial_rotary_factor"])
        n_embd = get_key_opts(self.hparams, ["hidden_size", "n_embd"])
        n_head = get_key_opts(self.hparams, ["num_attention_heads", "n_head"])

        self.gguf_writer.add_name("Phi2")
        self.gguf_writer.add_context_length(get_key_opts(self.hparams, ["n_positions", "max_position_embeddings"]))

        self.gguf_writer.add_embedding_length(n_embd)
        self.gguf_writer.add_feed_forward_length(4 * n_embd)
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(n_head)
        self.gguf_writer.add_head_count_kv(n_head)
        self.gguf_writer.add_layer_norm_eps(get_key_opts(self.hparams, ["layer_norm_epsilon", "layer_norm_eps"]))
        self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
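        # rot_pct is the fraction of each head's dimensions that rotary embeddings cover:
        # int(rot_pct * n_embd) // n_head is the per-head rotary dimension count. For example,
        # with hypothetical values rot_pct=0.5, n_embd=2048 and n_head=32, this yields 32 rotary
        # dimensions out of a 64-dimensional head.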
        self.gguf_writer.add_file_type(self.ftype)
        self.gguf_writer.add_add_bos_token(False)


class PlamoModel(Model):
    def set_vocab(self):
        self._set_vocab_sentencepiece()

    def set_gguf_parameters(self):
        hparams = self.hparams
        block_count = hparams["num_hidden_layers"]

        self.gguf_writer.add_name("PLaMo")
        self.gguf_writer.add_context_length(4096)  # not in config.json
        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
        self.gguf_writer.add_head_count_kv(5)  # hparams["num_key_value_heads"] is wrong
        self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
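
    # The two shuffle_* helpers below reorder the fused attention weights so that grouped-query
    # attention broadcasts correctly in ggml_mul_mat (see the "shuffle for broadcasting of gqa"
    # note in write_tensors). The reshape factors 8 * 5 * 128 = 5120 match the hard-coded
    # (5120, 5120) asserts; the exact grouping is specific to this PLaMo checkpoint.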
    def shuffle_attn_q_weight(self, data_torch):
        assert data_torch.size() == (5120, 5120)
        data_torch = data_torch.reshape(8, 5, 128, 5120)
        data_torch = torch.permute(data_torch, (1, 0, 2, 3))
        data_torch = torch.reshape(data_torch, (5120, 5120))
        return data_torch

    def shuffle_attn_output_weight(self, data_torch):
        assert data_torch.size() == (5120, 5120)
        data_torch = data_torch.reshape(5120, 8, 5, 128)
        data_torch = torch.permute(data_torch, (0, 2, 1, 3))
        data_torch = torch.reshape(data_torch, (5120, 5120))
        return data_torch

    def write_tensors(self):
        block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in self.get_tensors():
            if "self_attn.rotary_emb.inv_freq" in name:
                continue

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Cannot map tensor {name!r}")
                sys.exit()

            # shuffle for broadcasting of gqa in ggml_mul_mat
            if new_name.endswith("attn_q.weight"):
                data_torch = self.shuffle_attn_q_weight(data_torch)
            elif new_name.endswith("attn_output.weight"):
                data_torch = self.shuffle_attn_output_weight(data_torch)

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


class CodeShellModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["n_layer"]

        self.gguf_writer.add_name("CodeShell")
        self.gguf_writer.add_context_length(self.hparams["n_positions"])
        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(self.hparams["n_head"])
        self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"])
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)
        self.gguf_writer.add_rope_freq_base(10000.0)
        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
        self.gguf_writer.add_rope_scaling_factor(1.0)

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        tensors = dict(self.get_tensors())
        has_lm_head = "lm_head.weight" in tensors.keys() or "output.weight" in tensors.keys()
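        # If the checkpoint ships no separate LM head, the token embedding is re-used (tied
        # weights): the wte tensor written at the end of the loop below is duplicated as
        # "output.weight" so the GGUF file always has an output projection.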
        for name, data_torch in tensors.items():
            # we don't need these
            if name.endswith(".attn.rotary_emb.inv_freq"):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Cannot map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

            if not has_lm_head and name == "transformer.wte.weight":
                self.gguf_writer.add_tensor("output.weight", data)
                print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")


class InternLM2Model(Model):
    def set_vocab(self):
        # TODO: Is there a better way to do this?
        # Copied from _set_vocab_sentencepiece; the only difference is that the character \x00
        # is treated specially and converted into an emoji so that it is not mistakenly
        # recognized as an empty string in C++.
        from sentencepiece import SentencePieceProcessor
        from sentencepiece import sentencepiece_model_pb2 as model

        tokenizer_path = self.dir_model / 'tokenizer.model'

        tokens: list[bytes] = []
        scores: list[float] = []
        toktypes: list[int] = []

        if not tokenizer_path.is_file():
            print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
            sys.exit(1)

        sentencepiece_model = model.ModelProto()
        sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
        add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix

        tokenizer = SentencePieceProcessor(str(tokenizer_path))
        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

        for token_id in range(vocab_size):
            piece = tokenizer.id_to_piece(token_id)
            text = piece.encode("utf-8")
            score = tokenizer.get_score(token_id)
            if text == b"\x00":
                # TODO: fix this properly
                # Hack: replace the \x00 character so it survives the round trip.
                print(f"InternLM2: converting token '{text}' to '🐉'!")
                text = "🐉".encode("utf-8")

            toktype = SentencePieceTokenTypes.NORMAL
            if tokenizer.is_unknown(token_id):
                toktype = SentencePieceTokenTypes.UNKNOWN
            elif tokenizer.is_control(token_id):
                toktype = SentencePieceTokenTypes.CONTROL
            elif tokenizer.is_unused(token_id):
                toktype = SentencePieceTokenTypes.UNUSED
            elif tokenizer.is_byte(token_id):
                toktype = SentencePieceTokenTypes.BYTE

            tokens.append(text)
            scores.append(score)
            toktypes.append(toktype)

        added_tokens_file = self.dir_model / 'added_tokens.json'
        if added_tokens_file.is_file():
            with open(added_tokens_file, "r", encoding="utf-8") as f:
                added_tokens_json = json.load(f)

                for key in added_tokens_json:
                    tokens.append(key.encode("utf-8"))
                    scores.append(-1000.0)
                    toktypes.append(SentencePieceTokenTypes.USER_DEFINED)

        self.gguf_writer.add_tokenizer_model("llama")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)
        self.gguf_writer.add_add_space_prefix(add_prefix)

        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        old_eos = special_vocab.special_token_ids["eos"]
        if "chat" in os.path.basename(self.dir_model.absolute()):
            # For the chat model, we replace the eos with '<|im_end|>'.
            special_vocab.special_token_ids["eos"] = self._try_get_sft_eos(tokenizer)
            print(f"Replacing eos token {old_eos} with special token "
                  f"{special_vocab.special_token_ids['eos']} in chat mode so that the conversation can end normally.")

        special_vocab.add_to_gguf(self.gguf_writer)

    def _try_get_sft_eos(self, tokenizer):
        unused_145_list = tokenizer.encode('[UNUSED_TOKEN_145]')
        im_end_list = tokenizer.encode('<|im_end|>')
        assert (len(unused_145_list) == 1) ^ (len(im_end_list) == 1)
        if len(unused_145_list) == 1:
            eos_token = unused_145_list[0]
        if len(im_end_list) == 1:
            eos_token = im_end_list[0]
        return eos_token
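
    # _hf_permute_qk applies the usual LLaMA-style HF permutation to Q/K projection weights:
    # rows are regrouped per head so that the two rotary-embedding halves end up in the layout
    # llama.cpp expects. When n_head_kv differs from n_head, the permutation is done over the
    # KV heads instead (see the n_head = n_head_kv reassignment below).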
    def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int):
        if n_head_kv is not None and n_head != n_head_kv:
            n_head = n_head_kv
        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
                .swapaxes(1, 2)
                .reshape(weights.shape))

    def set_gguf_parameters(self):
        self.gguf_writer.add_name("InternLM2")
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])

    def post_write_tensors(self, tensor_map, name, data_torch):
        old_dtype = data_torch.dtype

        # convert any unsupported data types to float32
        if data_torch.dtype not in (torch.float16, torch.float32):
            data_torch = data_torch.to(torch.float32)

        data = data_torch.squeeze().numpy()

        # map tensor names
        new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
        if new_name is None:
            print(f"Cannot map tensor {name!r}")
            sys.exit()

        n_dims = len(data.shape)
        data_dtype = data.dtype

        # if f32 desired, convert any float16 to float32
        if self.ftype == 0 and data_dtype == np.float16:
            data = data.astype(np.float32)

        # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
        if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
            data = data.astype(np.float32)

        # if f16 desired, convert any float32 2-dim weight tensors to float16
        if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
            data = data.astype(np.float16)

        print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
        self.gguf_writer.add_tensor(new_name, data)

    def write_tensors(self):
        from einops import rearrange

        num_heads = self.hparams.get("num_attention_heads")
        num_kv_heads = self.hparams.get("num_key_value_heads")
        hidden_size = self.hparams.get("hidden_size")
        q_per_kv = num_heads // num_kv_heads
        head_dim = hidden_size // num_heads
        num_groups = num_heads // q_per_kv

        block_count = self.hparams["num_hidden_layers"]
        model_kv = dict(self.get_tensors())
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
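        # InternLM2 packs the attention projections into a single wqkv tensor: for each of the
        # num_groups KV groups it stores q_per_kv query heads followed by one key head and one
        # value head, each of width head_dim. The loop below unpacks that layout with einops,
        # re-permutes q/k for rotary embeddings, and writes the three pieces under the standard
        # wq/wk/wv names so the generic tensor-name mapping can handle them.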
        for name, data_torch in model_kv.items():
            # we don't need these
            if name.endswith(".rotary_emb.inv_freq"):
                continue

            if re.match(qkv_pattern, name):
                bid = re.findall(qkv_pattern, name)[0]
                qkv = data_torch
                qkv = rearrange(qkv.T, "o (g n i) -> o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
                q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
                # The q and k weights require an additional reshape.
                q = self._hf_permute_qk(rearrange(q, "o g n i -> o (g n i)").T, num_heads, num_heads)
                k = self._hf_permute_qk(rearrange(k, "o g n i -> o (g n i)").T, num_heads, num_kv_heads)
                v = rearrange(v, "o g n i -> o (g n i)").T
                self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wq.weight", q)
                self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wk.weight", k)
                self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wv.weight", v)
            else:
                self.post_write_tensors(tensor_map, name, data_torch)


class BertModel(Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.block_count = self.hparams["num_hidden_layers"]

    def set_gguf_parameters(self):
        # TODO(cebtenzzre): merge with parent class
        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_block_count(self.block_count)
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
        self.gguf_writer.add_causal_attention(False)
        self.gguf_writer.add_file_type(self.ftype)

    def set_vocab(self):
        path = self.dir_model
        added_tokens_path = self.dir_model if self.dir_model.exists() else None

        # use huggingface vocab to get all tokens
        vocab = HfVocab(path, added_tokens_path)
        tokens, scores, toktypes = zip(*vocab.all_tokens())
        assert len(tokens) == vocab.vocab_size

        # we need this to validate the size of the token_type embeddings
        # though currently we are passing all zeros to the token_type embeddings
        n_token_types = len(set(toktypes))
        self.gguf_writer.add_token_type_count(n_token_types)
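
        # The "phantom space" conversion below rewrites the WordPiece vocabulary into the
        # SentencePiece-style convention used by llama.cpp: word-initial pieces get a U+2581
        # ("▁") prefix and the "##" continuation marker is stripped, while special tokens such
        # as "[CLS]" are left untouched. For example, b"hello" becomes b"\xe2\x96\x81hello" and
        # b"##ing" becomes b"ing".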
        # convert to phantom space vocab
        def phantom(tok, typ):
            if tok.startswith(b"[") and tok.endswith(b"]"):
                return tok
            if tok.startswith(b"##"):
                return tok[2:]
            return b"\xe2\x96\x81" + tok
        tokens = [phantom(t, y) for t, y in zip(tokens, toktypes)]

        # set up bos and eos tokens (cls and sep)
        self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id)
        self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id)

        # add vocab to gguf
        self.gguf_writer.add_tokenizer_model("bert")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)

        # handle special tokens
        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        special_vocab.add_to_gguf(self.gguf_writer)

    def write_tensors(self):
        tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
        tensors = dict(self.get_tensors())
        for name, data_torch in tensors.items():
            # we are only using BERT for embeddings so we don't need the pooling layer
            if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
                continue  # we don't need these

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Cannot map tensor {name!r}")
                sys.exit()

            data = data_torch.squeeze().numpy()
            n_dims = len(data.shape)
            new_dtype: type[np.floating[Any]]

            if (
                self.ftype == 1 and name.endswith(".weight") and n_dims == 2
                and name != "embeddings.token_type_embeddings.weight"  # not used with get_rows, must be F32
            ):
                # if f16 desired, convert any float32 2-dim weight tensors to float16
                new_dtype = np.float16
            else:
                # if f32 desired, convert any float16 to float32
                new_dtype = np.float32

            print(f"{new_name}, n_dims = {n_dims}, {data_torch.dtype} --> {new_dtype}")

            if data.dtype != new_dtype:
                data = data.astype(new_dtype)

            self.gguf_writer.add_tensor(new_name, data)


###### CONVERSION LOGIC ######
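
# High-level flow implemented below: parse the CLI arguments, optionally apply AWQ scale
# weights to a temporary copy of the model, load the model's config, pick the Model subclass
# that matches its "architectures" entry, then write either the full GGUF file or only its
# vocab.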


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Convert a Hugging Face model to a GGML compatible file")
    parser.add_argument(
        "--vocab-only", action="store_true",
        help="extract only the vocab",
    )
    parser.add_argument(
        "--awq-path", type=Path, default=None,
        help="path to the AWQ scale cache file")
    parser.add_argument(
        "--outfile", type=Path,
        help="path to write to; default: based on input",
    )
    parser.add_argument(
        "--outtype", type=str, choices=["f32", "f16"], default="f16",
        help="output format - use f32 for float32, f16 for float16",
    )
    parser.add_argument("--bigendian", action="store_true", help="model is executed on a big-endian machine")
    parser.add_argument(
        "model", type=Path,
        help="directory containing the model file",
    )

    return parser.parse_args()
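
# Example invocations (the paths and file names below are hypothetical):
#   python convert-hf-to-gguf.py /path/to/Some-HF-Model --outtype f16 --outfile some-model-f16.gguf
#   python convert-hf-to-gguf.py /path/to/Some-HF-Model --vocab-only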


def main() -> None:
    args = parse_args()

    dir_model = args.model

    if args.awq_path:
        sys.path.insert(1, str(Path(__file__).parent / 'awq-py'))
        from awq.apply_awq import add_scale_weights  # type: ignore[import-not-found]
        tmp_model_path = args.model / "weighted_model"
        dir_model = tmp_model_path
        if tmp_model_path.is_dir():
            print(f"{tmp_model_path} exists as a weighted model.")
        else:
            tmp_model_path.mkdir(parents=True, exist_ok=True)
            print("Saving new weighted model ...")
            add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path))
            print(f"Saved weighted model at {tmp_model_path}.")

    if not dir_model.is_dir():
        print(f'Error: {args.model} is not a directory', file=sys.stderr)
        sys.exit(1)

    ftype_map = {
        "f32": gguf.GGMLQuantizationType.F32,
        "f16": gguf.GGMLQuantizationType.F16,
    }

    if args.outfile is not None:
        fname_out = args.outfile
    else:
        # output in the same directory as the model by default
        fname_out = dir_model / f'ggml-model-{args.outtype}.gguf'

    print(f"Loading model: {dir_model.name}")

    hparams = Model.load_hparams(dir_model)

    with torch.inference_mode():
        model_class = Model.from_model_architecture(hparams["architectures"][0])
        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian)

        print("Set model parameters")
        model_instance.set_gguf_parameters()

        print("Set model tokenizer")
        model_instance.set_vocab()

        if args.vocab_only:
            print(f"Exporting model vocab to '{fname_out}'")
            model_instance.write_vocab()
        else:
            print(f"Exporting model to '{fname_out}'")
            model_instance.write()

        print(f"Model successfully exported to '{fname_out}'")


if __name__ == '__main__':
    main()