convert-hf-to-gguf.py 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268
  1. #!/usr/bin/env python3
  2. from __future__ import annotations
  3. import argparse
  4. import contextlib
  5. import json
  6. import os
  7. import re
  8. import sys
  9. from enum import IntEnum
  10. from pathlib import Path
  11. from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast, Optional
  12. import numpy as np
  13. import torch
  14. if TYPE_CHECKING:
  15. from torch import Tensor
  16. if 'NO_LOCAL_GGUF' not in os.environ:
  17. sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
  18. import gguf
  # check for any of the given keys in the dictionary and return the value of the first key found
def get_key_opts(d, keys):
    """Return ``d[k]`` for the first ``k`` in ``keys`` that is present in ``d``.

    Keys are tried in the order given.  If none is present, an error message
    is printed to stderr and the process exits with status 1.

    Args:
        d: mapping to search (e.g. a parsed ``config.json``).
        keys: candidate keys, in priority order.

    Returns:
        The value stored under the first matching key (which may be ``None``).
    """
    for k in keys:
        if k in d:
            return d[k]
    # A missing required key is a failure: a bare sys.exit() would report
    # success (status 0) to shell scripts driving the conversion.
    print(f"Could not find any of {keys}", file=sys.stderr)
    sys.exit(1)
  ###### MODEL DEFINITIONS ######
class SentencePieceTokenTypes(IntEnum):
    """Per-token type codes used when writing a SentencePiece vocabulary.

    The integer values are passed through to the GGUF writer unchanged
    (``add_token_types``), so presumably they must match the numbering the
    GGUF consumer expects — do not renumber.
    """

    NORMAL = 1        # ordinary vocabulary piece
    UNKNOWN = 2       # piece the tokenizer reports as unknown
    CONTROL = 3       # piece the tokenizer reports as control
    USER_DEFINED = 4  # e.g. extra tokens appended from added_tokens.json
    UNUSED = 5        # piece the tokenizer reports as unused
    BYTE = 6          # piece the tokenizer reports as a byte piece
  # Generic converter; the architecture-specific subclasses override parts of it.
class Model:
    """Convert a Hugging Face checkpoint directory into a GGUF file.

    The base class writes a GPT-2 (BPE) vocabulary, the common ``config.json``
    hyper-parameters and a name-mapped copy of every tensor.  Subclasses
    override ``set_vocab``, ``set_gguf_parameters`` and/or ``write_tensors``
    where their checkpoints differ.
    """

    def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool):
        """Inspect ``dir_model`` and create the GGUF writer for ``fname_out``.

        Args:
            dir_model: checkpoint directory holding ``config.json`` and the
                ``.safetensors`` or ``.bin`` weight files.
            ftype: requested output precision.  In ``write_tensors``, ``0``
                converts float16 data to float32 and ``1`` stores 2-D float32
                ``.weight`` tensors as float16 (1-D tensors stay float32).
            fname_out: path of the GGUF file to write.
            is_big_endian: write a big-endian GGUF file instead of little-endian.

        Raises:
            NotImplementedError: if ``config.json`` names an unsupported
                architecture.
        """
        self.dir_model = dir_model
        self.ftype = ftype
        self.fname_out = fname_out
        self.is_big_endian = is_big_endian
        self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
        self.is_safetensors = self._is_model_safetensors()
        self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin")
        self.part_names = self._get_part_names()
        self.hparams = Model.load_hparams(self.dir_model)
        self.model_arch = self._get_model_architecture()
        self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=False)

    def set_vocab(self):
        """Write the tokenizer vocabulary; the default is the GPT-2 (BPE) path."""
        self._set_vocab_gpt2()

    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
        """Yield ``(name, tensor)`` for every tensor in every weight file.

        ``.safetensors`` parts are opened with ``safetensors.safe_open``;
        ``.bin`` parts with ``torch.load(mmap=True, weights_only=True)``.
        Everything is loaded on the CPU.
        """
        for part_name in self.part_names:
            print(f"gguf: loading model part '{part_name}'")
            ctx: ContextManager[Any]
            if self.is_safetensors:
                from safetensors import safe_open
                ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
            else:
                # torch.load's result is used as a mapping below (.keys(), [name]),
                # not as a context manager; wrap it so both paths share the `with`.
                ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))

            with ctx as model_part:
                for name in model_part.keys():
                    data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
                    yield name, data

    def set_gguf_parameters(self):
        """Write the generic hyper-parameters found in ``config.json``.

        The layer count is read from ``n_layers``, ``num_hidden_layers`` or
        ``n_layer`` (first present wins); every other key is written only when
        it is present and not ``None``.
        """
        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_block_count(self.hparams.get(
            "n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")),
        ))
        if (n_ctx := self.hparams.get("max_position_embeddings")) is not None:
            self.gguf_writer.add_context_length(n_ctx)
        if (n_embd := self.hparams.get("hidden_size")) is not None:
            self.gguf_writer.add_embedding_length(n_embd)
        if (n_ff := self.hparams.get("intermediate_size")) is not None:
            self.gguf_writer.add_feed_forward_length(n_ff)
        if (n_head := self.hparams.get("num_attention_heads")) is not None:
            self.gguf_writer.add_head_count(n_head)
        if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
            self.gguf_writer.add_head_count_kv(n_head_kv)
        if (n_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
            self.gguf_writer.add_layer_norm_rms_eps(n_rms_eps)
        if (n_experts := self.hparams.get("num_local_experts")) is not None:
            self.gguf_writer.add_expert_count(n_experts)
        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
            self.gguf_writer.add_expert_used_count(n_experts_used)
        self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))

    def write_tensors(self):
        """Map, convert and queue every checkpoint tensor on the GGUF writer.

        Attention mask/bias buffers and rotary ``inv_freq`` tensors are skipped.
        Non-float16/float32 tensors are upcast to float32, then the precision
        is adjusted according to ``self.ftype``.

        Exits the process with status 1 if a tensor name has no GGUF mapping.
        """
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                # Fail loudly: a bare sys.exit() would report success (status 0).
                print(f"Can not map tensor {name!r}", file=sys.stderr)
                sys.exit(1)

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

    def write(self):
        """Queue all tensors, then write header, KV metadata and tensor data, and close."""
        self.write_tensors()
        self.gguf_writer.write_header_to_file()
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.write_tensors_to_file()
        self.gguf_writer.close()

    def write_vocab(self):
        """Write only the header and KV metadata (no tensor data), then close."""
        self.gguf_writer.write_header_to_file()
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.close()

    @staticmethod
    def count_model_parts(dir_model: Path, prefix: str) -> int:
        """Count the files directly in ``dir_model`` whose names end with ``prefix``.

        Despite its name, ``prefix`` is matched as a filename *suffix*
        (e.g. ``".safetensors"``); the name is kept for compatibility.
        """
        return sum(1 for filename in os.listdir(dir_model) if filename.endswith(prefix))

    @staticmethod
    def load_hparams(dir_model):
        """Return the parsed ``config.json`` from ``dir_model``."""
        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def from_model_architecture(model_architecture):
        """Return the converter class for an HF ``architectures`` entry.

        Unknown architectures fall back to the generic ``Model``.  The names
        are resolved at call time because the subclasses are defined after
        this class.
        """
        if model_architecture == "GPTNeoXForCausalLM":
            return GPTNeoXModel
        if model_architecture == "BloomForCausalLM":
            return BloomModel
        if model_architecture == "MPTForCausalLM":
            return MPTModel
        if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
            return BaichuanModel
        if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
            return FalconModel
        if model_architecture == "GPTBigCodeForCausalLM":
            return StarCoderModel
        if model_architecture == "GPTRefactForCausalLM":
            return RefactModel
        if model_architecture == "PersimmonForCausalLM":
            return PersimmonModel
        if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
            return StableLMModel
        if model_architecture == "QWenLMHeadModel":
            return QwenModel
        if model_architecture == "MixtralForCausalLM":
            return MixtralModel
        if model_architecture == "GPT2LMHeadModel":
            return GPT2Model
        if model_architecture == "PhiForCausalLM":
            return Phi2Model
        if model_architecture == "PlamoForCausalLM":
            return PlamoModel
        return Model

    def _is_model_safetensors(self) -> bool:
        """True if ``dir_model`` contains at least one ``.safetensors`` file."""
        return Model.count_model_parts(self.dir_model, ".safetensors") > 0

    def _get_part_names(self):
        """Return the expected weight-file names, in shard order.

        A single part is ``model.safetensors`` / ``pytorch_model.bin``; shards
        follow ``<stem>-NNNNN-of-NNNNN.<ext>``.  A tuple (not a generator) is
        returned so ``self.part_names`` can be iterated more than once, e.g.
        when ``get_tensors`` is called a second time.
        """
        if self.is_safetensors:
            stem, ext = "model", "safetensors"
        else:
            stem, ext = "pytorch_model", "bin"
        if self.num_parts == 1:  # there's only one weight file
            return (f"{stem}.{ext}",)
        return tuple(f"{stem}-{n:05}-of-{self.num_parts:05}.{ext}" for n in range(1, self.num_parts + 1))

    def _get_model_architecture(self) -> gguf.MODEL_ARCH:
        """Map ``hparams["architectures"][0]`` to a ``gguf.MODEL_ARCH``.

        Raises:
            NotImplementedError: for architectures not listed here.
        """
        arch = self.hparams["architectures"][0]
        if arch == "GPTNeoXForCausalLM":
            return gguf.MODEL_ARCH.GPTNEOX
        if arch == "BloomForCausalLM":
            return gguf.MODEL_ARCH.BLOOM
        if arch == "MPTForCausalLM":
            return gguf.MODEL_ARCH.MPT
        if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
            return gguf.MODEL_ARCH.BAICHUAN
        if arch in ("FalconForCausalLM", "RWForCausalLM"):
            return gguf.MODEL_ARCH.FALCON
        if arch == "GPTBigCodeForCausalLM":
            return gguf.MODEL_ARCH.STARCODER
        if arch == "GPTRefactForCausalLM":
            return gguf.MODEL_ARCH.REFACT
        if arch == "PersimmonForCausalLM":
            return gguf.MODEL_ARCH.PERSIMMON
        if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
            return gguf.MODEL_ARCH.STABLELM
        if arch == "QWenLMHeadModel":
            return gguf.MODEL_ARCH.QWEN
        if arch == "MixtralForCausalLM":
            return gguf.MODEL_ARCH.LLAMA
        if arch == "GPT2LMHeadModel":
            return gguf.MODEL_ARCH.GPT2
        if arch == "PhiForCausalLM":
            return gguf.MODEL_ARCH.PHI2
        if arch == "PlamoForCausalLM":
            return gguf.MODEL_ARCH.PLAMO

        raise NotImplementedError(f'Architecture "{arch}" not supported!')

    def _set_vocab_gpt2(self):
        """Write a GPT-2 (BPE) vocabulary from the HF tokenizer in ``dir_model``.

        Every id in ``[0, vocab_size)`` receives exactly one token and one
        type: ids missing from the tokenizer become ``[PAD<id>]`` USER_DEFINED
        placeholders; added tokens become CONTROL when flagged special,
        otherwise USER_DEFINED; everything else is NORMAL.  Merges and special
        token ids are added via ``gguf.SpecialVocab``.
        """
        dir_model = self.dir_model
        hparams = self.hparams
        tokens: list[str | bytearray] = []
        toktypes: list[int] = []

        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(dir_model)
        vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
        assert max(tokenizer.vocab.values()) < vocab_size

        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
        added_vocab = tokenizer.get_added_vocab()
        # Fetched once; None when the tokenizer has no way to flag special tokens.
        added_tokens_decoder = getattr(tokenizer, "added_tokens_decoder", None)

        for i in range(vocab_size):
            if i not in reverse_vocab:
                pad_token = f"[PAD{i}]".encode('utf-8')
                tokens.append(bytearray(pad_token))
                toktypes.append(gguf.TokenType.USER_DEFINED)
            elif reverse_vocab[i] in added_vocab:
                tokens.append(reverse_vocab[i])
                # Always record a type so tokens and toktypes stay the same
                # length, even if the tokenizer lacks added_tokens_decoder.
                if added_tokens_decoder is not None and added_tokens_decoder[i].special:
                    toktypes.append(gguf.TokenType.CONTROL)
                else:
                    toktypes.append(gguf.TokenType.USER_DEFINED)
            else:
                tokens.append(reverse_vocab[i])
                toktypes.append(gguf.TokenType.NORMAL)

        self.gguf_writer.add_tokenizer_model("gpt2")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
        special_vocab.add_to_gguf(self.gguf_writer)

    def _set_vocab_sentencepiece(self):
        """Write a SentencePiece ("llama") vocabulary from ``tokenizer.model``.

        Pieces ``0 .. vocab_size-1`` are written with their scores and types;
        keys of an optional ``added_tokens.json`` are appended as USER_DEFINED
        with score -1000.0.  Exits with status 1 if ``tokenizer.model`` is
        missing.
        """
        from sentencepiece import SentencePieceProcessor

        tokenizer_path = self.dir_model / 'tokenizer.model'

        tokens: list[bytes] = []
        scores: list[float] = []
        toktypes: list[int] = []

        if not tokenizer_path.is_file():
            print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
            sys.exit(1)

        tokenizer = SentencePieceProcessor(str(tokenizer_path))
        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

        for token_id in range(vocab_size):
            piece = tokenizer.id_to_piece(token_id)
            text = piece.encode("utf-8")
            score = tokenizer.get_score(token_id)

            toktype = SentencePieceTokenTypes.NORMAL
            if tokenizer.is_unknown(token_id):
                toktype = SentencePieceTokenTypes.UNKNOWN
            elif tokenizer.is_control(token_id):
                toktype = SentencePieceTokenTypes.CONTROL
            elif tokenizer.is_unused(token_id):
                toktype = SentencePieceTokenTypes.UNUSED
            elif tokenizer.is_byte(token_id):
                toktype = SentencePieceTokenTypes.BYTE

            tokens.append(text)
            scores.append(score)
            toktypes.append(toktype)

        added_tokens_file = self.dir_model / 'added_tokens.json'
        if added_tokens_file.is_file():
            with open(added_tokens_file, "r", encoding="utf-8") as f:
                added_tokens_json = json.load(f)

                for key in added_tokens_json:
                    tokens.append(key.encode("utf-8"))
                    scores.append(-1000.0)
                    toktypes.append(SentencePieceTokenTypes.USER_DEFINED)

        self.gguf_writer.add_tokenizer_model("llama")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        special_vocab.add_to_gguf(self.gguf_writer)
  # GPT-NeoX family (GPTNeoXForCausalLM).
class GPTNeoXModel(Model):
    """GPT-NeoX converter.

    Vocabulary and tensor handling are inherited from ``Model``; only the
    GGUF hyper-parameters are NeoX-specific.
    """

    def set_gguf_parameters(self):
        """Write GPT-NeoX hyper-parameters to the GGUF metadata.

        The RoPE dimension count is the ``rotary_pct`` fraction of the
        per-head size (``hidden_size // num_attention_heads``), truncated
        to an int.
        """
        hp = self.hparams
        writer = self.gguf_writer
        n_layer = hp["num_hidden_layers"]

        writer.add_name(self.dir_model.name)
        writer.add_context_length(hp["max_position_embeddings"])
        writer.add_embedding_length(hp["hidden_size"])
        writer.add_block_count(n_layer)
        writer.add_feed_forward_length(hp["intermediate_size"])

        head_dim = hp["hidden_size"] // hp["num_attention_heads"]
        writer.add_rope_dimension_count(int(hp["rotary_pct"] * head_dim))

        writer.add_head_count(hp["num_attention_heads"])
        writer.add_parallel_residual(hp.get("use_parallel_residual", True))
        writer.add_layer_norm_eps(hp["layer_norm_eps"])
  # BLOOM (BloomForCausalLM).
class BloomModel(Model):
    """BLOOM converter: fused-QKV re-layout and tied output embedding."""

    def set_gguf_parameters(self):
        """Write BLOOM hyper-parameters.

        Both ``hidden_size``/``n_embed`` and ``n_head``/``num_attention_heads``
        spellings are accepted.  The feed-forward size is written as
        ``4 * n_embed`` and the KV head count equals the head count.
        """
        self.gguf_writer.add_name("Bloom")
        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
        # NOTE(review): without "seq_length" this falls back to n_embed, which is
        # an embedding size rather than a context length — confirm intended default.
        self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
        self.gguf_writer.add_embedding_length(n_embed)
        self.gguf_writer.add_feed_forward_length(4 * n_embed)
        self.gguf_writer.add_block_count(self.hparams["n_layer"])
        self.gguf_writer.add_head_count(n_head)
        self.gguf_writer.add_head_count_kv(n_head)
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        """Map, convert and queue BLOOM tensors.

        Strips ``transformer.`` from names and regroups the fused QKV
        weight/bias from BLOOM's per-head ``[q, k, v]`` interleaving into
        contiguous Q, K and V blocks.  If the checkpoint has neither
        ``lm_head.weight`` nor ``output.weight``, the token embedding is also
        written as ``output.weight``.  Exits with status 1 if a tensor name
        cannot be mapped.
        """
        block_count = self.hparams["n_layer"]
        tensors = dict(self.get_tensors())
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        # Loop-invariant: decide once whether an explicit output head exists.
        has_lm_head = "lm_head.weight" in tensors or "output.weight" in tensors
        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))

        for name, data_torch in tensors.items():
            name = re.sub(r'transformer\.', '', name)

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name):
                # Map bloom-style qkv_linear to gpt-style qkv_linear
                # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252  # noqa
                # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312  # noqa
                qkv_weights = data.reshape((n_head, 3, n_embed // n_head, n_embed))
                data = np.concatenate(
                    (
                        qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
                        qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
                        qkv_weights[:, 2, :, :].reshape((-1, n_embed)),
                    ),
                    axis=0,
                )
                print("re-format attention.linear_qkv.weight")
            elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name):
                qkv_bias = data.reshape((n_head, 3, n_embed // n_head))
                data = np.concatenate(
                    (
                        qkv_bias[:, 0, :].reshape((n_embed,)),
                        qkv_bias[:, 1, :].reshape((n_embed,)),
                        qkv_bias[:, 2, :].reshape((n_embed,)),
                    ),
                    axis=0,
                )
                print("re-format attention.linear_qkv.bias")

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                # Fail loudly: a bare sys.exit() would report success (status 0).
                print(f"Can not map tensor {name!r}", file=sys.stderr)
                sys.exit(1)

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"=> {new_name}, shape = {data.shape}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

            # Tied output: reuse the token embedding as the output projection.
            if not has_lm_head and name == "word_embeddings.weight":
                self.gguf_writer.add_tensor("output.weight", data)
                print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
  # MPT (MPTForCausalLM).
class MPTModel(Model):
    """MPT converter: ALiBi/clamp hyper-parameters and tied output embedding."""

    def set_gguf_parameters(self):
        """Write MPT hyper-parameters from the top level and ``attn_config``.

        The feed-forward size is written as ``4 * d_model`` and the layer-norm
        epsilon is fixed at 1e-5.  The KV head count and QKV clamp are written
        only when ``attn_config`` provides them.
        """
        block_count = self.hparams["n_layers"]
        attn_config = self.hparams["attn_config"]
        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
        self.gguf_writer.add_embedding_length(self.hparams["d_model"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"])
        self.gguf_writer.add_head_count(self.hparams["n_heads"])
        if kv_n_heads := attn_config.get("kv_n_heads"):
            self.gguf_writer.add_head_count_kv(kv_n_heads)
        self.gguf_writer.add_layer_norm_eps(1e-5)
        # .get(): a config without "clip_qkv" simply means no clamping.
        if (clip_qkv := attn_config.get("clip_qkv")) is not None:
            self.gguf_writer.add_clamp_kqv(clip_qkv)
        self.gguf_writer.add_max_alibi_bias(attn_config["alibi_bias_max"])

    def write_tensors(self):
        """Map, convert and queue MPT tensors.

        ``*scales`` tensors are mapped with an extra ``.scales`` suffix and
        renamed to ``act.scales``.  ``token_embd.weight`` is also written as
        ``output.weight``.  Exits with status 1 if a tensor name cannot be
        mapped.
        """
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers"))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            if "scales" in name:
                new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias", ".scales"))
                # Only rename a successful mapping; calling .replace on None
                # would raise AttributeError before the check below.
                if new_name is not None:
                    new_name = new_name.replace("scales", "act.scales")
            else:
                new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                # Fail loudly: a bare sys.exit() would report success (status 0).
                print(f"Can not map tensor {name!r}", file=sys.stderr)
                sys.exit(1)

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

            # note: MPT output is tied to (same as) wte in original model;
            # for easier implementation in llama.cpp it's duplicated in GGUF, though :/
            if new_name == "token_embd.weight":
                self.gguf_writer.add_tensor("output.weight", data)
  # Baichuan (BaichuanForCausalLM / BaiChuanForCausalLM).
class BaichuanModel(Model):
    """Baichuan converter: SentencePiece vocabulary and W_pack QKV unpacking."""

    def set_vocab(self):
        """Baichuan ships a SentencePiece ``tokenizer.model``."""
        self._set_vocab_sentencepiece()

    def set_gguf_parameters(self):
        """Write Baichuan hyper-parameters.

        The context length is taken from the first present of
        ``max_sequence_length``, ``max_position_embeddings`` and
        ``model_max_length``; the process exits with status 1 if none exists.
        Linear RoPE scaling is recorded when ``rope_scaling`` specifies it.
        """
        block_count = self.hparams["num_hidden_layers"]
        head_count = self.hparams["num_attention_heads"]
        head_count_kv = self.hparams.get("num_key_value_heads", head_count)
        hf_repo = self.hparams.get("_name_or_path", "")

        for ctx_key in ("max_sequence_length", "max_position_embeddings", "model_max_length"):
            if ctx_key in self.hparams:
                ctx_length = self.hparams[ctx_key]
                break
        else:
            # Fail loudly: a bare sys.exit() would report success (status 0).
            print("gguf: can not find ctx length parameter.", file=sys.stderr)
            sys.exit(1)

        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_source_hf_repo(hf_repo)
        self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
        self.gguf_writer.add_context_length(ctx_length)
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count(head_count)
        self.gguf_writer.add_head_count_kv(head_count_kv)
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])

        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
            if self.hparams["rope_scaling"].get("type") == "linear":
                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])

    def write_tensors(self):
        """Unpack W_pack into q/k/v projections, then map, convert and queue tensors.

        Q and K are un-permuted from the HF rotary layout; V is copied as-is.
        Exits with status 1 if a tensor name cannot be mapped.
        """
        # Collect tensors from generator object
        model_kv = dict(self.get_tensors())
        block_count = self.hparams["num_hidden_layers"]
        head_count = self.hparams["num_attention_heads"]
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        head_count_kv = self.hparams.get("num_key_value_heads", head_count)

        for i in range(block_count):
            if (w := model_kv.get(f"model.layers.{i}.self_attn.W_pack.weight")) is not None:
                print(f"Unpacking and permuting layer {i}")
                model_kv[f"model.layers.{i}.self_attn.q_proj.weight"] = \
                    self._reverse_hf_permute_part(w, 0, head_count, head_count)
                model_kv[f"model.layers.{i}.self_attn.k_proj.weight"] = \
                    self._reverse_hf_permute_part(w, 1, head_count, head_count_kv)
                model_kv[f"model.layers.{i}.self_attn.v_proj.weight"] = \
                    self._reverse_hf_part(w, 2)
                del model_kv[f"model.layers.{i}.self_attn.W_pack.weight"]

        for name, data_torch in model_kv.items():
            # we don't need these
            if name.endswith(".rotary_emb.inv_freq"):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                # Fail loudly: a bare sys.exit() would report success (status 0).
                print(f"Can not map tensor {name!r}", file=sys.stderr)
                sys.exit(1)

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
            self.gguf_writer.add_tensor(new_name, data)

    def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
        """Undo the HF rotary permutation on a Q or K projection weight.

        Rows are viewed as ``(heads, 2, rows // heads // 2, ...)`` and the two
        middle axes swapped back.  For a K projection with grouped-query
        attention the weight holds ``n_kv_head`` heads, so that count must be
        used (matching the llama ``permute`` helper), not ``n_head // n_kv_head``.
        """
        if n_kv_head is not None and n_head != n_kv_head:
            n_head = n_kv_head
        return (
            weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape)
        )

    def _reverse_hf_permute_part(
        self, weights: Tensor, n_part: int, n_head: int, n_head_kv: int | None = None,
    ) -> Tensor:
        """Take the ``n_part``-th third of a packed QKV weight and un-permute it.

        NOTE(review): the equal three-way split assumes Q, K and V have the same
        row count (no grouped-query attention) — confirm for new checkpoints.
        """
        r = weights.shape[0] // 3
        return self._reverse_hf_permute(weights[r * n_part:r * n_part + r, ...], n_head, n_head_kv)

    def _reverse_hf_part(self, weights: Tensor, n_part: int) -> Tensor:
        """Return the ``n_part``-th third (by rows) of a packed QKV weight unchanged."""
        r = weights.shape[0] // 3
        return weights[r * n_part:r * n_part + r, ...]
  # Falcon (FalconForCausalLM / RWForCausalLM).
class FalconModel(Model):
    """Falcon converter: handles old/new config key names and the QKV re-layout."""

    def _layer_and_head_counts(self) -> tuple[int, int, int]:
        """Return ``(block_count, n_head, n_head_kv)`` from new- or old-style keys.

        ``n_head_kv`` defaults to 1 when neither ``num_kv_heads`` nor the old
        ``n_head_kv`` key is present.
        """
        block_count = self.hparams.get("num_hidden_layers")
        if block_count is None:
            block_count = self.hparams["n_layer"]  # old name

        n_head = self.hparams.get("num_attention_heads")
        if n_head is None:
            n_head = self.hparams["n_head"]  # old name

        n_head_kv = self.hparams.get("num_kv_heads")
        if n_head_kv is None:
            n_head_kv = self.hparams.get("n_head_kv", 1)  # old name

        return block_count, n_head, n_head_kv

    def set_gguf_parameters(self):
        """Write Falcon hyper-parameters.

        The context length is fixed at 2048 because config.json does not carry
        it; the feed-forward size is written as ``4 * hidden_size``.
        """
        block_count, n_head, n_head_kv = self._layer_and_head_counts()

        self.gguf_writer.add_name("Falcon")
        self.gguf_writer.add_context_length(2048)  # not in config.json
        self.gguf_writer.add_tensor_data_layout("jploski")  # qkv tensor transform
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(n_head)
        self.gguf_writer.add_head_count_kv(n_head_kv)
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        """Map, convert and queue Falcon tensors, re-laying out fused QKV weights.

        Exits with status 1 if a tensor name cannot be mapped.
        """
        block_count, n_head, n_head_kv = self._layer_and_head_counts()
        head_dim = self.hparams["hidden_size"] // n_head
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in self.get_tensors():
            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            # QKV tensor transform
            # The original query_key_value tensor contains n_head_kv "kv groups",
            # each consisting of n_head/n_head_kv query weights followed by one key
            # and one value weight (shared by all query heads in the kv group).
            # This layout makes it a big pain to work with in GGML.
            # So we rearrange them here, so that we have n_head query weights
            # followed by n_head_kv key weights followed by n_head_kv value weights,
            # in contiguous fashion.
            # ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py

            if "query_key_value" in name:
                qkv = data_torch.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
                q = qkv[:, :-2].reshape(n_head * head_dim, head_dim * n_head)
                k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
                v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
                data_torch = torch.cat((q, k, v)).reshape_as(data_torch)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                # Fail loudly: a bare sys.exit() would report success (status 0).
                print(f"Can not map tensor {name!r}", file=sys.stderr)
                sys.exit(1)

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)
  586. class StarCoderModel(Model):
  587. def set_gguf_parameters(self):
  588. block_count = self.hparams["n_layer"]
  589. self.gguf_writer.add_name("StarCoder")
  590. self.gguf_writer.add_context_length(self.hparams["n_positions"])
  591. self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
  592. self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
  593. self.gguf_writer.add_block_count(block_count)
  594. self.gguf_writer.add_head_count(self.hparams["n_head"])
  595. self.gguf_writer.add_head_count_kv(1)
  596. self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
  597. self.gguf_writer.add_file_type(self.ftype)
  598. class RefactModel(Model):
  599. def set_gguf_parameters(self):
  600. hidden_dim = self.hparams["n_embd"]
  601. inner_dim = 4 * hidden_dim
  602. hidden_dim = int(2 * inner_dim / 3)
  603. multiple_of = 256
  604. ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
  605. block_count = self.hparams["n_layer"]
  606. self.gguf_writer.add_name("Refact")
  607. # refact uses Alibi. So this is from config.json which might be used by training.
  608. self.gguf_writer.add_context_length(self.hparams["n_positions"])
  609. self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
  610. self.gguf_writer.add_feed_forward_length(ff_dim)
  611. self.gguf_writer.add_block_count(block_count)
  612. self.gguf_writer.add_head_count(self.hparams["n_head"])
  613. self.gguf_writer.add_head_count_kv(1)
  614. self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
  615. self.gguf_writer.add_file_type(self.ftype)
  616. def write_tensors(self):
  617. hidden_dim = self.hparams["n_embd"]
  618. inner_dim = 4 * hidden_dim
  619. hidden_dim = int(2 * inner_dim / 3)
  620. multiple_of = 256
  621. ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
  622. n_head = self.hparams["n_head"]
  623. n_head_kv = 1
  624. head_dim = self.hparams["n_embd"] // n_head
  625. block_count = self.hparams["n_layer"]
  626. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  627. tensors = dict(self.get_tensors())
  628. for i in range(block_count):
  629. if (w := tensors.get(f"transformer.h.{i}.attn.kv.weight")) is not None:
  630. tensors[f"model.layers.{i}.self_attn.k_proj.weight"] = w[:n_head_kv * head_dim]
  631. tensors[f"model.layers.{i}.self_attn.v_proj.weight"] = w[n_head_kv * head_dim:]
  632. del tensors[f"transformer.h.{i}.attn.kv.weight"]
  633. if (w := tensors.get(f"transformer.h.{i}.attn.q.weight")) is not None:
  634. tensors[f"model.layers.{i}.self_attn.q_proj.weight"] = w
  635. del tensors[f"transformer.h.{i}.attn.q.weight"]
  636. if (w := tensors.get(f"transformer.h.{i}.mlp.gate_up_proj.weight")) is not None:
  637. tensors[f"model.layers.{i}.mlp.gate_proj.weight"] = w[:ff_dim]
  638. tensors[f"model.layers.{i}.mlp.up_proj.weight"] = w[ff_dim:]
  639. del tensors[f"transformer.h.{i}.mlp.gate_up_proj.weight"]
  640. for name, data_torch in tensors.items():
  641. old_dtype = data_torch.dtype
  642. # convert any unsupported data types to float32
  643. if data_torch.dtype not in (torch.float16, torch.float32):
  644. data_torch = data_torch.to(torch.float32)
  645. data = data_torch.squeeze().numpy()
  646. # map tensor names
  647. new_name = tensor_map.get_name(name, try_suffixes=(".weight",))
  648. if new_name is None:
  649. print(f"Can not map tensor {name!r}")
  650. sys.exit()
  651. n_dims = len(data.shape)
  652. data_dtype = data.dtype
  653. # if f32 desired, convert any float16 to float32
  654. if self.ftype == 0 and data_dtype == np.float16:
  655. data = data.astype(np.float32)
  656. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  657. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  658. data = data.astype(np.float32)
  659. # if f16 desired, convert any float32 2-dim weight tensors to float16
  660. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  661. data = data.astype(np.float16)
  662. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  663. self.gguf_writer.add_tensor(new_name, data)
  664. class PersimmonModel(Model):
  665. def set_gguf_parameters(self):
  666. block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
  667. head_count = self.hparams["num_attention_heads"]
  668. head_count_kv = head_count
  669. hidden_size = self.hparams["hidden_size"]
  670. self.gguf_writer.add_name('persimmon-8b-chat')
  671. self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
  672. self.gguf_writer.add_embedding_length(hidden_size)
  673. self.gguf_writer.add_block_count(block_count)
  674. self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
  675. # NOTE: not sure about this change - why does the model not have a rope dimension count when it is smaller
  676. # than the head size?
  677. # ref: https://github.com/ggerganov/llama.cpp/pull/4889
  678. # self.gguf_writer.add_rope_dimension_count(hidden_size // head_count)
  679. self.gguf_writer.add_rope_dimension_count(hidden_size // head_count // 2)
  680. self.gguf_writer.add_head_count(head_count)
  681. self.gguf_writer.add_head_count_kv(head_count_kv)
  682. self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
  683. self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
  684. self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
  685. def set_vocab(self):
  686. self._set_vocab_sentencepiece()
  687. # self.gguf_writer.add_bos_token_id(71013)
  688. # self.gguf_writer.add_eos_token_id(71013)
  689. def write_tensors(self):
  690. block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
  691. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  692. for name, data_torch in self.get_tensors():
  693. if name.endswith(".self_attention.rotary_emb.inv_freq"):
  694. continue
  695. old_dtype = data_torch.dtype
  696. # TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?)
  697. data = data_torch.to(torch.float32).squeeze().numpy()
  698. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  699. if new_name is None:
  700. print(f"Can not map tensor {name!r}")
  701. sys.exit()
  702. n_dims = len(data.shape)
  703. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  704. self.gguf_writer.add_tensor(new_name, data)
  705. class StableLMModel(Model):
  706. def set_gguf_parameters(self):
  707. hparams = self.hparams
  708. block_count = hparams["num_hidden_layers"]
  709. self.gguf_writer.add_name(self.dir_model.name)
  710. self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
  711. self.gguf_writer.add_embedding_length(hparams["hidden_size"])
  712. self.gguf_writer.add_block_count(block_count)
  713. self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
  714. self.gguf_writer.add_rope_dimension_count(int(hparams["rope_pct"] * (hparams["hidden_size"] // hparams["num_attention_heads"])))
  715. self.gguf_writer.add_head_count(hparams["num_attention_heads"])
  716. self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
  717. self.gguf_writer.add_layer_norm_eps(1e-5)
  718. class MixtralModel(Model):
  719. def set_vocab(self):
  720. self._set_vocab_sentencepiece()
  721. class QwenModel(Model):
  722. @staticmethod
  723. def token_bytes_to_string(b):
  724. from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
  725. byte_encoder = bytes_to_unicode()
  726. return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
  727. @staticmethod
  728. def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: Optional[int] = None) -> list[bytes]:
  729. parts = [bytes([b]) for b in token]
  730. while True:
  731. min_idx = None
  732. min_rank = None
  733. for i, pair in enumerate(zip(parts[:-1], parts[1:])):
  734. rank = mergeable_ranks.get(pair[0] + pair[1])
  735. if rank is not None and (min_rank is None or rank < min_rank):
  736. min_idx = i
  737. min_rank = rank
  738. if min_rank is None or (max_rank is not None and min_rank >= max_rank):
  739. break
  740. assert min_idx is not None
  741. parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
  742. return parts
  743. def set_vocab(self):
  744. dir_model = self.dir_model
  745. hparams = self.hparams
  746. tokens: list[bytearray] = []
  747. toktypes: list[int] = []
  748. from transformers import AutoTokenizer
  749. tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
  750. vocab_size = hparams["vocab_size"]
  751. assert max(tokenizer.get_vocab().values()) < vocab_size
  752. merges = []
  753. vocab = {}
  754. mergeable_ranks = tokenizer.mergeable_ranks
  755. for token, rank in mergeable_ranks.items():
  756. vocab[self.token_bytes_to_string(token)] = rank
  757. if len(token) == 1:
  758. continue
  759. merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
  760. assert len(merged) == 2
  761. merges.append(' '.join(map(self.token_bytes_to_string, merged)))
  762. reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in vocab.items()}
  763. added_vocab = tokenizer.special_tokens
  764. for i in range(vocab_size):
  765. if i not in reverse_vocab:
  766. pad_token = f"[PAD{i}]".encode("utf-8")
  767. tokens.append(bytearray(pad_token))
  768. toktypes.append(gguf.TokenType.USER_DEFINED)
  769. elif reverse_vocab[i] in added_vocab:
  770. tokens.append(reverse_vocab[i])
  771. toktypes.append(gguf.TokenType.CONTROL)
  772. else:
  773. tokens.append(reverse_vocab[i])
  774. toktypes.append(gguf.TokenType.NORMAL)
  775. self.gguf_writer.add_tokenizer_model("gpt2")
  776. self.gguf_writer.add_token_list(tokens)
  777. self.gguf_writer.add_token_types(toktypes)
  778. special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
  779. special_vocab.merges = merges
  780. special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"])
  781. special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"])
  782. special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"])
  783. special_vocab.add_to_gguf(self.gguf_writer)
  784. def set_gguf_parameters(self):
  785. self.gguf_writer.add_name("Qwen")
  786. self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
  787. self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
  788. self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
  789. self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
  790. self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
  791. self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
  792. self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
  793. self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
  794. def write_tensors(self):
  795. block_count = self.hparams["num_hidden_layers"]
  796. model_kv = dict(self.get_tensors())
  797. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  798. for name, data_torch in model_kv.items():
  799. # we don't need these
  800. if name.endswith(".rotary_emb.inv_freq"):
  801. continue
  802. old_dtype = data_torch.dtype
  803. # convert any unsupported data types to float32
  804. if data_torch.dtype not in (torch.float16, torch.float32):
  805. data_torch = data_torch.to(torch.float32)
  806. data = data_torch.squeeze().numpy()
  807. # map tensor names
  808. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  809. if new_name is None:
  810. print(f"Can not map tensor {name!r}")
  811. sys.exit()
  812. n_dims = len(data.shape)
  813. data_dtype = data.dtype
  814. # if f32 desired, convert any float16 to float32
  815. if self.ftype == 0 and data_dtype == np.float16:
  816. data = data.astype(np.float32)
  817. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  818. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  819. data = data.astype(np.float32)
  820. # if f16 desired, convert any float32 2-dim weight tensors to float16
  821. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  822. data = data.astype(np.float16)
  823. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  824. self.gguf_writer.add_tensor(new_name, data)
  825. class GPT2Model(Model):
  826. def set_gguf_parameters(self):
  827. self.gguf_writer.add_name(self.dir_model.name)
  828. self.gguf_writer.add_block_count(self.hparams["n_layer"])
  829. self.gguf_writer.add_context_length(self.hparams["n_ctx"])
  830. self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
  831. self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
  832. self.gguf_writer.add_head_count(self.hparams["n_head"])
  833. self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
  834. self.gguf_writer.add_file_type(self.ftype)
  835. def write_tensors(self):
  836. block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
  837. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  838. for name, data_torch in self.get_tensors():
  839. # we don't need these
  840. if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq", ".attn.bias")):
  841. continue
  842. if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_proj.weight")):
  843. data_torch = data_torch.transpose(1, 0)
  844. old_dtype = data_torch.dtype
  845. # convert any unsupported data types to float32
  846. if data_torch.dtype not in (torch.float16, torch.float32):
  847. data_torch = data_torch.to(torch.float32)
  848. data = data_torch.squeeze().numpy()
  849. # map tensor names
  850. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  851. if new_name is None:
  852. print(f"Can not map tensor {name!r}")
  853. sys.exit()
  854. n_dims = len(data.shape)
  855. data_dtype = data.dtype
  856. # if f32 desired, convert any float16 to float32
  857. if self.ftype == 0 and data_dtype == np.float16:
  858. data = data.astype(np.float32)
  859. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  860. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  861. data = data.astype(np.float32)
  862. # if f16 desired, convert any float32 2-dim weight tensors to float16
  863. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  864. data = data.astype(np.float16)
  865. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  866. self.gguf_writer.add_tensor(new_name, data)
  867. # note: GPT2 output is tied to (same as) wte in original model
  868. if new_name == "token_embd.weight":
  869. print(f"output.weight, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  870. self.gguf_writer.add_tensor("output.weight", data)
  871. class Phi2Model(Model):
  872. def set_gguf_parameters(self):
  873. block_count = get_key_opts(self.hparams, ["num_hidden_layers", "n_layer"])
  874. rot_pct = get_key_opts(self.hparams, ["partial_rotary_factor"])
  875. n_embd = get_key_opts(self.hparams, ["hidden_size", "n_embd"])
  876. n_head = get_key_opts(self.hparams, ["num_attention_heads", "n_head"])
  877. self.gguf_writer.add_name("Phi2")
  878. self.gguf_writer.add_context_length(get_key_opts(self.hparams, ["n_positions", "max_position_embeddings"]))
  879. self.gguf_writer.add_embedding_length(n_embd)
  880. self.gguf_writer.add_feed_forward_length(4 * n_embd)
  881. self.gguf_writer.add_block_count(block_count)
  882. self.gguf_writer.add_head_count(n_head)
  883. self.gguf_writer.add_head_count_kv(n_head)
  884. self.gguf_writer.add_layer_norm_eps(get_key_opts(self.hparams, ["layer_norm_epsilon", "layer_norm_eps"]))
  885. self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
  886. self.gguf_writer.add_file_type(self.ftype)
  887. self.gguf_writer.add_add_bos_token(False)
  888. class PlamoModel(Model):
  889. def set_vocab(self):
  890. self._set_vocab_sentencepiece()
  891. def set_gguf_parameters(self):
  892. hparams = self.hparams
  893. block_count = hparams["num_hidden_layers"]
  894. self.gguf_writer.add_name("PLaMo")
  895. self.gguf_writer.add_context_length(4096) # not in config.json
  896. self.gguf_writer.add_embedding_length(hparams["hidden_size"])
  897. self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
  898. self.gguf_writer.add_block_count(block_count)
  899. self.gguf_writer.add_head_count(hparams["num_attention_heads"])
  900. self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
  901. self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
  902. def shuffle_attn_q_weight(self, data_torch):
  903. assert data_torch.size() == (5120, 5120)
  904. data_torch = data_torch.reshape(8, 5, 128, 5120)
  905. data_torch = torch.permute(data_torch, (1, 0, 2, 3))
  906. data_torch = torch.reshape(data_torch, (5120, 5120))
  907. return data_torch
  908. def shuffle_attn_output_weight(self, data_torch):
  909. assert data_torch.size() == (5120, 5120)
  910. data_torch = data_torch.reshape(5120, 8, 5, 128)
  911. data_torch = torch.permute(data_torch, (0, 2, 1, 3))
  912. data_torch = torch.reshape(data_torch, (5120, 5120))
  913. return data_torch
  914. def write_tensors(self):
  915. block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
  916. tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
  917. for name, data_torch in self.get_tensors():
  918. if "self_attn.rotary_emb.inv_freq" in name:
  919. continue
  920. # map tensor names
  921. new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
  922. if new_name is None:
  923. print(f"Can not map tensor {name!r}")
  924. sys.exit()
  925. # shuffle for broadcasting of gqa in ggml_mul_mat
  926. if new_name.endswith("attn_q.weight"):
  927. data_torch = self.shuffle_attn_q_weight(data_torch)
  928. elif new_name.endswith("attn_output.weight"):
  929. data_torch = self.shuffle_attn_output_weight(data_torch)
  930. old_dtype = data_torch.dtype
  931. # convert any unsupported data types to float32
  932. if data_torch.dtype not in (torch.float16, torch.float32):
  933. data_torch = data_torch.to(torch.float32)
  934. data = data_torch.squeeze().numpy()
  935. n_dims = len(data.shape)
  936. data_dtype = data.dtype
  937. # if f32 desired, convert any float16 to float32
  938. if self.ftype == 0 and data_dtype == np.float16:
  939. data = data.astype(np.float32)
  940. # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
  941. if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
  942. data = data.astype(np.float32)
  943. # if f16 desired, convert any float32 2-dim weight tensors to float16
  944. if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
  945. data = data.astype(np.float16)
  946. print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
  947. self.gguf_writer.add_tensor(new_name, data)
  948. ###### CONVERSION LOGIC ######
  949. def parse_args() -> argparse.Namespace:
  950. parser = argparse.ArgumentParser(
  951. description="Convert a huggingface model to a GGML compatible file")
  952. parser.add_argument(
  953. "--vocab-only", action="store_true",
  954. help="extract only the vocab",
  955. )
  956. parser.add_argument(
  957. "--awq-path", type=Path, default=None,
  958. help="Path to scale awq cache file")
  959. parser.add_argument(
  960. "--outfile", type=Path,
  961. help="path to write to; default: based on input",
  962. )
  963. parser.add_argument(
  964. "--outtype", type=str, choices=["f32", "f16"], default="f16",
  965. help="output format - use f32 for float32, f16 for float16",
  966. )
  967. parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
  968. parser.add_argument(
  969. "model", type=Path,
  970. help="directory containing model file",
  971. )
  972. return parser.parse_args()
  973. def main() -> None:
  974. args = parse_args()
  975. dir_model = args.model
  976. if args.awq_path:
  977. sys.path.insert(1, str(Path(__file__).parent / 'awq-py'))
  978. from awq.apply_awq import add_scale_weights
  979. tmp_model_path = args.model / "weighted_model"
  980. dir_model = tmp_model_path
  981. if tmp_model_path.is_dir():
  982. print(f"{tmp_model_path} exists as a weighted model.")
  983. else:
  984. tmp_model_path.mkdir(parents=True, exist_ok=True)
  985. print("Saving new weighted model ...")
  986. add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path))
  987. print(f"Saved weighted model at {tmp_model_path}.")
  988. if not dir_model.is_dir():
  989. print(f'Error: {args.model} is not a directory', file=sys.stderr)
  990. sys.exit(1)
  991. ftype_map = {
  992. "f32": gguf.GGMLQuantizationType.F32,
  993. "f16": gguf.GGMLQuantizationType.F16,
  994. }
  995. if args.outfile is not None:
  996. fname_out = args.outfile
  997. else:
  998. # output in the same directory as the model by default
  999. fname_out = dir_model / f'ggml-model-{args.outtype}.gguf'
  1000. print(f"Loading model: {dir_model.name}")
  1001. hparams = Model.load_hparams(dir_model)
  1002. with torch.inference_mode():
  1003. model_class = Model.from_model_architecture(hparams["architectures"][0])
  1004. model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian)
  1005. print("Set model parameters")
  1006. model_instance.set_gguf_parameters()
  1007. print("Set model tokenizer")
  1008. model_instance.set_vocab()
  1009. if args.vocab_only:
  1010. print(f"Exporting model vocab to '{fname_out}'")
  1011. model_instance.write_vocab()
  1012. else:
  1013. print(f"Exporting model to '{fname_out}'")
  1014. model_instance.write()
  1015. print(f"Model successfully exported to '{fname_out}'")
  1016. if __name__ == '__main__':
  1017. main()