convert-hf-to-gguf.py

#!/usr/bin/env python3

from __future__ import annotations

import argparse
import contextlib
import json
import os
import re
import sys
from enum import IntEnum
from pathlib import Path
from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast

import numpy as np
import torch

if TYPE_CHECKING:
    from torch import Tensor

if 'NO_LOCAL_GGUF' not in os.environ:
    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf


###### MODEL DEFINITIONS ######

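# Token type IDs written into the GGUF vocab metadata. The SentencePiece-based and
# GPT-2-based vocab writers below append these and gguf.TokenType values into the
# same token-type list, so the numeric values are interchangeable.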
class SentencePieceTokenTypes(IntEnum):
    NORMAL = 1
    UNKNOWN = 2
    CONTROL = 3
    USER_DEFINED = 4
    UNUSED = 5
    BYTE = 6


class Model:
    def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool):
        self.dir_model = dir_model
        self.ftype = ftype
        self.fname_out = fname_out
        self.is_big_endian = is_big_endian
        self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
        self.is_safetensors = self._is_model_safetensors()
        self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin")
        self.part_names = self._get_part_names()
        self.hparams = Model.load_hparams(self.dir_model)
        self.model_arch = self._get_model_architecture()
        self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess)

    def set_vocab(self):
        self._set_vocab_gpt2()

    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
        for part_name in self.part_names:
            print(f"gguf: loading model part '{part_name}'")
            ctx: ContextManager[Any]
            if self.is_safetensors:
                from safetensors import safe_open
                ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
            else:
                ctx = contextlib.nullcontext(torch.load(self.dir_model / part_name, map_location="cpu"))

            with ctx as model_part:
                for name in model_part.keys():
                    data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
                    yield name, data

    def set_gguf_parameters(self):
        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_block_count(self.hparams.get(
            "n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")),
        ))
        if (n_ctx := self.hparams.get("max_position_embeddings")) is not None:
            self.gguf_writer.add_context_length(n_ctx)
        if (n_embd := self.hparams.get("hidden_size")) is not None:
            self.gguf_writer.add_embedding_length(n_embd)
        if (n_ff := self.hparams.get("intermediate_size")) is not None:
            self.gguf_writer.add_feed_forward_length(n_ff)
        if (n_head := self.hparams.get("num_attention_heads")) is not None:
            self.gguf_writer.add_head_count(n_head)
        self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Cannot map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

    def write(self):
        self.write_tensors()
        self.gguf_writer.write_header_to_file()
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.write_tensors_to_file()
        self.gguf_writer.close()

    def write_vocab(self):
        self.gguf_writer.write_header_to_file()
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.close()

    @staticmethod
    def count_model_parts(dir_model: Path, prefix: str) -> int:
        num_parts = 0
        for filename in os.listdir(dir_model):
            if filename.endswith(prefix):
                num_parts += 1
        return num_parts

    @staticmethod
    def load_hparams(dir_model):
        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def from_model_architecture(model_architecture):
        if model_architecture == "GPTNeoXForCausalLM":
            return GPTNeoXModel
        if model_architecture == "BloomForCausalLM":
            return BloomModel
        if model_architecture == "MPTForCausalLM":
            return MPTModel
        if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
            return BaichuanModel
        if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
            return FalconModel
        if model_architecture == "GPTBigCodeForCausalLM":
            return StarCoderModel
        if model_architecture == "GPTRefactForCausalLM":
            return RefactModel
        if model_architecture == "PersimmonForCausalLM":
            return PersimmonModel
        if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
            return StableLMModel
        return Model

    def _is_model_safetensors(self) -> bool:
        return Model.count_model_parts(self.dir_model, ".safetensors") > 0

    def _get_part_names(self):
        if self.is_safetensors:
            if self.num_parts == 1:  # there's only one .safetensors file
                return ("model.safetensors",)
            return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1))

        if self.num_parts == 1:  # there's only one .bin file
            return ("pytorch_model.bin",)
        return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))

    def _get_model_architecture(self) -> gguf.MODEL_ARCH:
        arch = self.hparams["architectures"][0]
        if arch == "GPTNeoXForCausalLM":
            return gguf.MODEL_ARCH.GPTNEOX
        if arch == "BloomForCausalLM":
            return gguf.MODEL_ARCH.BLOOM
        if arch == "MPTForCausalLM":
            return gguf.MODEL_ARCH.MPT
        if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
            return gguf.MODEL_ARCH.BAICHUAN
        if arch in ("FalconForCausalLM", "RWForCausalLM"):
            return gguf.MODEL_ARCH.FALCON
        if arch == "GPTBigCodeForCausalLM":
            return gguf.MODEL_ARCH.STARCODER
        if arch == "GPTRefactForCausalLM":
            return gguf.MODEL_ARCH.REFACT
        if arch == "PersimmonForCausalLM":
            return gguf.MODEL_ARCH.PERSIMMON
        if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
            return gguf.MODEL_ARCH.STABLELM

        raise NotImplementedError(f'Architecture "{arch}" not supported!')

    def _set_vocab_gpt2(self):
        dir_model = self.dir_model
        hparams = self.hparams
        tokens: list[bytearray] = []
        toktypes: list[int] = []

        from transformers import AutoTokenizer  # type: ignore[attr-defined]
        tokenizer = AutoTokenizer.from_pretrained(dir_model)
        vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
        assert max(tokenizer.vocab.values()) < vocab_size

        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
        added_vocab = tokenizer.get_added_vocab()

        for i in range(vocab_size):
            if i not in reverse_vocab:
                pad_token = f"[PAD{i}]".encode('utf-8')
                tokens.append(bytearray(pad_token))
                toktypes.append(gguf.TokenType.USER_DEFINED)
            elif reverse_vocab[i] in added_vocab:
                tokens.append(reverse_vocab[i])
                if tokenizer.added_tokens_decoder[i].special:
                    toktypes.append(gguf.TokenType.CONTROL)
                else:
                    toktypes.append(gguf.TokenType.USER_DEFINED)
            else:
                tokens.append(reverse_vocab[i])
                toktypes.append(gguf.TokenType.NORMAL)

        self.gguf_writer.add_tokenizer_model("gpt2")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
        special_vocab.add_to_gguf(self.gguf_writer)

    def _set_vocab_sentencepiece(self):
        from sentencepiece import SentencePieceProcessor

        tokenizer_path = self.dir_model / 'tokenizer.model'

        tokens: list[bytes] = []
        scores: list[float] = []
        toktypes: list[int] = []

        if not tokenizer_path.is_file():
            print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
            sys.exit(1)

        tokenizer = SentencePieceProcessor(str(tokenizer_path))
        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

        for token_id in range(vocab_size):
            piece = tokenizer.id_to_piece(token_id)
            text = piece.encode("utf-8")
            score = tokenizer.get_score(token_id)

            toktype = SentencePieceTokenTypes.NORMAL
            if tokenizer.is_unknown(token_id):
                toktype = SentencePieceTokenTypes.UNKNOWN
            elif tokenizer.is_control(token_id):
                toktype = SentencePieceTokenTypes.CONTROL
            elif tokenizer.is_unused(token_id):
                toktype = SentencePieceTokenTypes.UNUSED
            elif tokenizer.is_byte(token_id):
                toktype = SentencePieceTokenTypes.BYTE

            tokens.append(text)
            scores.append(score)
            toktypes.append(toktype)

        added_tokens_file = self.dir_model / 'added_tokens.json'
        if added_tokens_file.is_file():
            with open(added_tokens_file, "r", encoding="utf-8") as f:
                added_tokens_json = json.load(f)

                for key in added_tokens_json:
                    tokens.append(key.encode("utf-8"))
                    scores.append(-1000.0)
                    toktypes.append(SentencePieceTokenTypes.USER_DEFINED)

        self.gguf_writer.add_tokenizer_model("llama")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        special_vocab.add_to_gguf(self.gguf_writer)


class GPTNeoXModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["num_hidden_layers"]

        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(
            int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
        )
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])


class BloomModel(Model):
    def set_gguf_parameters(self):
        self.gguf_writer.add_name("Bloom")
        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
        self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
        self.gguf_writer.add_embedding_length(n_embed)
        self.gguf_writer.add_feed_forward_length(4 * n_embed)
        self.gguf_writer.add_block_count(self.hparams["n_layer"])
        self.gguf_writer.add_head_count(n_head)
        self.gguf_writer.add_head_count_kv(n_head)
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        block_count = self.hparams["n_layer"]
        tensors = dict(self.get_tensors())
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        has_lm_head = True
        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))

        for name, data_torch in tensors.items():
            if "lm_head.weight" not in tensors.keys() and "output.weight" not in tensors.keys():
                has_lm_head = False

            name = re.sub(r'transformer\.', '', name)

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name):
                # Map bloom-style qkv_linear to gpt-style qkv_linear
                # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252  # noqa
                # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312  # noqa
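                # Shape sketch with an illustrative head count (not from any real config):
                # with n_head = 4, Bloom stores the fused rows in per-head groups
                # [q_h0, k_h0, v_h0, q_h1, k_h1, v_h1, ...]; the reshape/concat below
                # reorders them to [q_h0..q_h3, k_h0..k_h3, v_h0..v_h3], the GPT-2-style layout.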
                qkv_weights = data.reshape((n_head, 3, n_embed // n_head, n_embed))
                data = np.concatenate(
                    (
                        qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
                        qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
                        qkv_weights[:, 2, :, :].reshape((-1, n_embed)),
                    ),
                    axis=0,
                )
                print("re-format attention.linear_qkv.weight")
            elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name):
                qkv_bias = data.reshape((n_head, 3, n_embed // n_head))
                data = np.concatenate(
                    (
                        qkv_bias[:, 0, :].reshape((n_embed,)),
                        qkv_bias[:, 1, :].reshape((n_embed,)),
                        qkv_bias[:, 2, :].reshape((n_embed,)),
                    ),
                    axis=0,
                )
                print("re-format attention.linear_qkv.bias")

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Cannot map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"=> {new_name}, shape = {data.shape}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

            if not has_lm_head and name == "word_embeddings.weight":
                self.gguf_writer.add_tensor("output.weight", data)
                print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")


class MPTModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["n_layers"]
        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
        self.gguf_writer.add_embedding_length(self.hparams["d_model"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"])
        self.gguf_writer.add_head_count(self.hparams["n_heads"])
        if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"):
            self.gguf_writer.add_head_count_kv(kv_n_heads)
        self.gguf_writer.add_layer_norm_eps(1e-5)
        if self.hparams["attn_config"]["clip_qkv"] is not None:
            self.gguf_writer.add_clamp_kqv(self.hparams["attn_config"]["clip_qkv"])
        self.gguf_writer.add_max_alibi_bias(self.hparams["attn_config"]["alibi_bias_max"])

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers"))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Cannot map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

            # note: MPT output is tied to (same as) wte in original model;
            # for easier implementation in llama.cpp it's duplicated in GGUF, though :/
            if new_name == "token_embd.weight":
                self.gguf_writer.add_tensor("output.weight", data)


class BaichuanModel(Model):
    def set_vocab(self):
        self._set_vocab_sentencepiece()

    def set_gguf_parameters(self):
        block_count = self.hparams["num_hidden_layers"]
        head_count = self.hparams["num_attention_heads"]
        head_count_kv = self.hparams.get("num_key_value_heads", head_count)
        hf_repo = self.hparams.get("_name_or_path", "")

        ctx_length = 0
        if "max_sequence_length" in self.hparams:
            ctx_length = self.hparams["max_sequence_length"]
        elif "max_position_embeddings" in self.hparams:
            ctx_length = self.hparams["max_position_embeddings"]
        elif "model_max_length" in self.hparams:
            ctx_length = self.hparams["model_max_length"]
        else:
            print("gguf: cannot find ctx length parameter.")
            sys.exit()

        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_source_hf_repo(hf_repo)
        self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
        self.gguf_writer.add_context_length(ctx_length)
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count(head_count)
        self.gguf_writer.add_head_count_kv(head_count_kv)
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])

        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
            if self.hparams["rope_scaling"].get("type") == "linear":
                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])

    def write_tensors(self):
        # Collect tensors from generator object
        model_kv = dict(self.get_tensors())
        block_count = self.hparams["num_hidden_layers"]
        head_count = self.hparams["num_attention_heads"]
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        head_count_kv = self.hparams.get("num_key_value_heads", head_count)
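
        # Baichuan checkpoints fuse Q/K/V into a single W_pack tensor per layer;
        # split it into separate q/k/v projections (un-permuting Q and K) so the
        # llama-style tensor name map below can handle them.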
        for i in range(block_count):
            if (w := model_kv.get(f"model.layers.{i}.self_attn.W_pack.weight")) is not None:
                print(f"Unpacking and permuting layer {i}")
                model_kv[f"model.layers.{i}.self_attn.q_proj.weight"] = \
                    self._reverse_hf_permute_part(w, 0, head_count, head_count)
                model_kv[f"model.layers.{i}.self_attn.k_proj.weight"] = \
                    self._reverse_hf_permute_part(w, 1, head_count, head_count_kv)
                model_kv[f"model.layers.{i}.self_attn.v_proj.weight"] = \
                    self._reverse_hf_part(w, 2)
                del model_kv[f"model.layers.{i}.self_attn.W_pack.weight"]

        for name, data_torch in model_kv.items():
            # we don't need these
            if name.endswith(".rotary_emb.inv_freq"):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Cannot map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
            self.gguf_writer.add_tensor(new_name, data)
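
    # Note (assumption, based on the analogous permute logic in llama.cpp's convert.py):
    # the HF export keeps the two rotary half-dimensions of each Q/K head as two
    # contiguous blocks of rows; the helpers below re-interleave those rows pairwise,
    # which is the ordering the llama.cpp RoPE implementation expects.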
    def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
        if n_kv_head is not None and n_head != n_kv_head:
            n_head //= n_kv_head

        return (
            weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape)
        )

    def _reverse_hf_permute_part(
        self, weights: Tensor, n_part: int, n_head: int, n_head_kv: int | None = None,
    ) -> Tensor:
        r = weights.shape[0] // 3
        return self._reverse_hf_permute(weights[r * n_part:r * n_part + r, ...], n_head, n_head_kv)

    def _reverse_hf_part(self, weights: Tensor, n_part: int) -> Tensor:
        r = weights.shape[0] // 3
        return weights[r * n_part:r * n_part + r, ...]


class FalconModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams.get("num_hidden_layers")
        if block_count is None:
            block_count = self.hparams["n_layer"]  # old name

        n_head = self.hparams.get("num_attention_heads")
        if n_head is None:
            n_head = self.hparams["n_head"]  # old name

        n_head_kv = self.hparams.get("num_kv_heads")
        if n_head_kv is None:
            n_head_kv = self.hparams.get("n_head_kv", 1)  # old name

        self.gguf_writer.add_name("Falcon")
        self.gguf_writer.add_context_length(2048)  # not in config.json
        self.gguf_writer.add_tensor_data_layout("jploski")  # qkv tensor transform
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(n_head)
        self.gguf_writer.add_head_count_kv(n_head_kv)
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        block_count = self.hparams.get("num_hidden_layers")
        if block_count is None:
            block_count = self.hparams["n_layer"]  # old name

        n_head = self.hparams.get("num_attention_heads")
        if n_head is None:
            n_head = self.hparams["n_head"]  # old name

        n_head_kv = self.hparams.get("num_kv_heads")
        if n_head_kv is None:
            n_head_kv = self.hparams.get("n_head_kv", 1)  # old name

        head_dim = self.hparams["hidden_size"] // n_head
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in self.get_tensors():
            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            # QKV tensor transform
            # The original query_key_value tensor contains n_head_kv "kv groups",
            # each consisting of n_head/n_head_kv query weights followed by one key
            # and one value weight (shared by all query heads in the kv group).
            # This layout makes it a big pain to work with in GGML.
            # So we rearrange them here, so that we have n_head query weights
            # followed by n_head_kv key weights followed by n_head_kv value weights,
            # in contiguous fashion.
            # ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py
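            # Shape sketch with illustrative numbers (not from any real config):
            # n_head = 8, n_head_kv = 2, head_dim = 64 -> view(2, 8/2 + 2, 64, 512);
            # q takes the first 4 slots of each kv group (8 * 64 rows in total),
            # while k and v take one slot per group (2 * 64 rows each).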
            if "query_key_value" in name:
                qkv = data_torch.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
                q = qkv[:, :-2].reshape(n_head * head_dim, head_dim * n_head)
                k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
                v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
                data_torch = torch.cat((q, k, v)).reshape_as(data_torch)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Cannot map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


class StarCoderModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["n_layer"]

        self.gguf_writer.add_name("StarCoder")
        self.gguf_writer.add_context_length(self.hparams["n_positions"])
        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(self.hparams["n_head"])
        self.gguf_writer.add_head_count_kv(1)
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)


class RefactModel(Model):
    def set_gguf_parameters(self):
        hidden_dim = self.hparams["n_embd"]
        inner_dim = 4 * hidden_dim
        hidden_dim = int(2 * inner_dim / 3)
        multiple_of = 256
        ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        block_count = self.hparams["n_layer"]

        self.gguf_writer.add_name("Refact")
        # Refact uses ALiBi, so this context length comes straight from config.json
        # and may just reflect the training setup.
        self.gguf_writer.add_context_length(self.hparams["n_positions"])
        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
        self.gguf_writer.add_feed_forward_length(ff_dim)
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(self.hparams["n_head"])
        self.gguf_writer.add_head_count_kv(1)
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        hidden_dim = self.hparams["n_embd"]
        inner_dim = 4 * hidden_dim
        hidden_dim = int(2 * inner_dim / 3)
        multiple_of = 256
        ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        n_head = self.hparams["n_head"]
        n_head_kv = 1
        head_dim = self.hparams["n_embd"] // n_head
        block_count = self.hparams["n_layer"]

        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        tensors = dict(self.get_tensors())
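
        # Refact fuses K/V into attn.kv.weight and the gate/up projections into
        # mlp.gate_up_proj.weight; split them into the separate tensors the
        # llama-style tensor name map below expects.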
        for i in range(block_count):
            if (w := tensors.get(f"transformer.h.{i}.attn.kv.weight")) is not None:
                tensors[f"model.layers.{i}.self_attn.k_proj.weight"] = w[:n_head_kv * head_dim]
                tensors[f"model.layers.{i}.self_attn.v_proj.weight"] = w[n_head_kv * head_dim:]
                del tensors[f"transformer.h.{i}.attn.kv.weight"]
            if (w := tensors.get(f"transformer.h.{i}.attn.q.weight")) is not None:
                tensors[f"model.layers.{i}.self_attn.q_proj.weight"] = w
                del tensors[f"transformer.h.{i}.attn.q.weight"]
            if (w := tensors.get(f"transformer.h.{i}.mlp.gate_up_proj.weight")) is not None:
                tensors[f"model.layers.{i}.mlp.gate_proj.weight"] = w[:ff_dim]
                tensors[f"model.layers.{i}.mlp.up_proj.weight"] = w[ff_dim:]
                del tensors[f"transformer.h.{i}.mlp.gate_up_proj.weight"]

        for name, data_torch in tensors.items():
            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight",))
            if new_name is None:
                print(f"Cannot map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


class PersimmonModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
        head_count = self.hparams["num_attention_heads"]
        head_count_kv = head_count
        hidden_size = self.hparams["hidden_size"]

        self.gguf_writer.add_name('persimmon-8b-chat')
        self.gguf_writer.add_embedding_length(hidden_size)
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(hidden_size // head_count)
        self.gguf_writer.add_head_count(head_count)
        self.gguf_writer.add_head_count_kv(head_count_kv)
        self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])

    def set_vocab(self):
        self._set_vocab_sentencepiece()
        # self.gguf_writer.add_bos_token_id(71013)
        # self.gguf_writer.add_eos_token_id(71013)

    def write_tensors(self):
        block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in self.get_tensors():
            if name.endswith(".self_attention.rotary_emb.inv_freq"):
                continue
            old_dtype = data_torch.dtype
            # TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?)
            data = data_torch.to(torch.float32).squeeze().numpy()
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Cannot map tensor {name!r}")
                sys.exit()
            n_dims = len(data.shape)
            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
            self.gguf_writer.add_tensor(new_name, data)


class StableLMModel(Model):
    def set_gguf_parameters(self):
        hparams = self.hparams
        block_count = hparams["num_hidden_layers"]

        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(int(hparams["rope_pct"] * (hparams["hidden_size"] // hparams["num_attention_heads"])))
        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
        self.gguf_writer.add_parallel_residual(hparams.get("use_parallel_residual", True))
        self.gguf_writer.add_layer_norm_eps(1e-5)


###### CONVERSION LOGIC ######

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Convert a HuggingFace model to a GGML-compatible GGUF file")
    parser.add_argument(
        "--vocab-only", action="store_true",
        help="extract only the vocab",
    )
    parser.add_argument(
        "--outfile", type=Path,
        help="path to write to; default: based on input",
    )
    parser.add_argument(
        "--outtype", type=str, choices=["f32", "f16"], default="f16",
        help="output format - use f32 for float32, f16 for float16",
    )
    parser.add_argument("--bigendian", action="store_true", help="model is executed on a big-endian machine")
    parser.add_argument(
        "model", type=Path,
        help="directory containing the model files",
    )
    return parser.parse_args()


args = parse_args()

dir_model = args.model
if not dir_model.is_dir():
    print(f'Error: {args.model} is not a directory', file=sys.stderr)
    sys.exit(1)

ftype_map = {
    "f32": gguf.GGMLQuantizationType.F32,
    "f16": gguf.GGMLQuantizationType.F16,
}

if args.outfile is not None:
    fname_out = args.outfile
else:
    # output in the same directory as the model by default
    fname_out = dir_model / f'ggml-model-{args.outtype}.gguf'

print(f"Loading model: {dir_model.name}")

hparams = Model.load_hparams(dir_model)

model_class = Model.from_model_architecture(hparams["architectures"][0])
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian)

print("Set model parameters")
model_instance.set_gguf_parameters()

print("Set model tokenizer")
model_instance.set_vocab()

if args.vocab_only:
    print(f"Exporting model vocab to '{fname_out}'")
    model_instance.write_vocab()
else:
    print(f"Exporting model to '{fname_out}'")
    model_instance.write()

print(f"Model successfully exported to '{fname_out}'")
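

# Example invocation (illustrative paths; any HF model directory whose config.json
# names one of the supported architectures above should work the same way):
#
#   python convert-hf-to-gguf.py ./models/my-hf-model --outtype f16 \
#       --outfile ./models/my-hf-model/ggml-model-f16.gguf
#
# Pass --vocab-only to write just the tokenizer metadata, or --bigendian when the
# resulting GGUF file will be consumed on a big-endian machine.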