#!/usr/bin/env python3

from __future__ import annotations

import argparse
import contextlib
import json
import os
import re
import sys
from enum import IntEnum
from pathlib import Path
from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast

import numpy as np
import torch

if TYPE_CHECKING:
    from torch import Tensor

if 'NO_LOCAL_GGUF' not in os.environ:
    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf


###### MODEL DEFINITIONS ######
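
# Token type ids attached to sentencepiece vocabs by _set_vocab_sentencepiece();
# the values mirror gguf.TokenType (1 = normal ... 6 = byte).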
class SentencePieceTokenTypes(IntEnum):
    NORMAL = 1
    UNKNOWN = 2
    CONTROL = 3
    USER_DEFINED = 4
    UNUSED = 5
    BYTE = 6


class Model:
    def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool):
        self.dir_model = dir_model
        self.ftype = ftype
        self.fname_out = fname_out
        self.is_big_endian = is_big_endian
        self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
        self.is_safetensors = self._is_model_safetensors()
        self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin")
        self.part_names = self._get_part_names()
        self.hparams = Model.load_hparams(self.dir_model)
        self.model_arch = self._get_model_architecture()
        self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess)

    def set_vocab(self):
        self._set_vocab_gpt2()

    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
        for part_name in self.part_names:
            print(f"gguf: loading model part '{part_name}'")
            ctx: ContextManager[Any]
            if self.is_safetensors:
                from safetensors import safe_open
                ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
            else:
                ctx = contextlib.nullcontext(torch.load(self.dir_model / part_name, map_location="cpu"))

            with ctx as model_part:
                for name in model_part.keys():
                    data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
                    yield name, data

    def set_gguf_parameters(self):
        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_block_count(self.hparams.get(
            "n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")),
        ))
        if (n_ctx := self.hparams.get("max_position_embeddings")) is not None:
            self.gguf_writer.add_context_length(n_ctx)
        if (n_embd := self.hparams.get("hidden_size")) is not None:
            self.gguf_writer.add_embedding_length(n_embd)
        if (n_ff := self.hparams.get("intermediate_size")) is not None:
            self.gguf_writer.add_feed_forward_length(n_ff)
        if (n_head := self.hparams.get("num_attention_heads")) is not None:
            self.gguf_writer.add_head_count(n_head)
        self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
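
    # Rough dtype policy shared by the write_tensors() implementations below:
    # with ftype == 0 everything is written as float32; with ftype == 1, 2-D
    # *.weight tensors are written as float16 while 1-D tensors (norms, biases)
    # stay float32. Any other source dtype is first converted to float32.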
    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

    def write(self):
        self.write_tensors()
        self.gguf_writer.write_header_to_file()
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.write_tensors_to_file()
        self.gguf_writer.close()

    def write_vocab(self):
        self.gguf_writer.write_header_to_file()
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.close()

    @staticmethod
    def count_model_parts(dir_model: Path, prefix: str) -> int:
        num_parts = 0
        for filename in os.listdir(dir_model):
            if filename.endswith(prefix):
                num_parts += 1
        return num_parts
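
    # Note: despite its name, `prefix` is matched with str.endswith(), i.e. it is a
    # filename suffix such as ".safetensors" or ".bin".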

    @staticmethod
    def load_hparams(dir_model):
        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def from_model_architecture(model_architecture):
        if model_architecture == "StableLMEpochForCausalLM":
            return StableLMModel
        if model_architecture == "GPTNeoXForCausalLM":
            return GPTNeoXModel
        if model_architecture == "BloomForCausalLM":
            return BloomModel
        if model_architecture == "MPTForCausalLM":
            return MPTModel
        if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
            return BaichuanModel
        if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
            return FalconModel
        if model_architecture == "GPTBigCodeForCausalLM":
            return StarCoderModel
        if model_architecture == "GPTRefactForCausalLM":
            return RefactModel
        if model_architecture == "PersimmonForCausalLM":
            return PersimmonModel
        return Model
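
    # Supporting a new architecture means adding a branch above that returns a new
    # Model subclass, plus a matching case in _get_model_architecture() below, e.g.
    # (hypothetical):
    #     if model_architecture == "FooForCausalLM":
    #         return FooModel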

    def _is_model_safetensors(self) -> bool:
        return Model.count_model_parts(self.dir_model, ".safetensors") > 0

    def _get_part_names(self):
        if self.is_safetensors:
            if self.num_parts == 1:  # there's only one .safetensors file
                return ("model.safetensors",)
            return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1))

        if self.num_parts == 1:  # there's only one .bin file
            return ("pytorch_model.bin",)
        return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))

    def _get_model_architecture(self) -> gguf.MODEL_ARCH:
        arch = self.hparams["architectures"][0]
        if arch == "GPTNeoXForCausalLM":
            return gguf.MODEL_ARCH.GPTNEOX
        if arch == "BloomForCausalLM":
            return gguf.MODEL_ARCH.BLOOM
        if arch == "MPTForCausalLM":
            return gguf.MODEL_ARCH.MPT
        if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
            return gguf.MODEL_ARCH.BAICHUAN
        if arch == "FalconForCausalLM":
            return gguf.MODEL_ARCH.FALCON
        if arch == "GPTBigCodeForCausalLM":
            return gguf.MODEL_ARCH.STARCODER
        if arch == "GPTRefactForCausalLM":
            return gguf.MODEL_ARCH.REFACT
        if arch == "PersimmonForCausalLM":
            return gguf.MODEL_ARCH.PERSIMMON

        raise NotImplementedError(f'Architecture "{arch}" not supported!')

    def _set_vocab_gpt2(self):
        dir_model = self.dir_model
        hparams = self.hparams
        tokens: list[bytearray] = []
        toktypes: list[int] = []

        from transformers import AutoTokenizer  # type: ignore[attr-defined]
        tokenizer = AutoTokenizer.from_pretrained(dir_model)
        vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
        assert max(tokenizer.vocab.values()) < vocab_size

        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
        added_vocab = tokenizer.get_added_vocab()

        for i in range(vocab_size):
            if i not in reverse_vocab:
                pad_token = f"[PAD{i}]".encode('utf-8')
                tokens.append(bytearray(pad_token))
                toktypes.append(gguf.TokenType.USER_DEFINED)
            elif reverse_vocab[i] in added_vocab:
                tokens.append(reverse_vocab[i])
                if tokenizer.added_tokens_decoder[i].special:
                    toktypes.append(gguf.TokenType.CONTROL)
                else:
                    toktypes.append(gguf.TokenType.USER_DEFINED)
            else:
                tokens.append(reverse_vocab[i])
                toktypes.append(gguf.TokenType.NORMAL)

        self.gguf_writer.add_tokenizer_model("gpt2")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
        special_vocab.add_to_gguf(self.gguf_writer)
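
    # Token ids missing from the tokenizer vocab are filled above with "[PAD<id>]"
    # placeholders so the exported token list length matches vocab_size from
    # config.json.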

    def _set_vocab_sentencepiece(self):
        from sentencepiece import SentencePieceProcessor

        tokenizer_path = self.dir_model / 'tokenizer.model'

        tokens: list[bytes] = []
        scores: list[float] = []
        toktypes: list[int] = []

        if not tokenizer_path.is_file():
            print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
            sys.exit(1)

        tokenizer = SentencePieceProcessor(str(tokenizer_path))
        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

        for token_id in range(vocab_size):
            piece = tokenizer.id_to_piece(token_id)
            text = piece.encode("utf-8")
            score = tokenizer.get_score(token_id)

            toktype = SentencePieceTokenTypes.NORMAL
            if tokenizer.is_unknown(token_id):
                toktype = SentencePieceTokenTypes.UNKNOWN
            elif tokenizer.is_control(token_id):
                toktype = SentencePieceTokenTypes.CONTROL
            elif tokenizer.is_unused(token_id):
                toktype = SentencePieceTokenTypes.UNUSED
            elif tokenizer.is_byte(token_id):
                toktype = SentencePieceTokenTypes.BYTE

            tokens.append(text)
            scores.append(score)
            toktypes.append(toktype)

        added_tokens_file = self.dir_model / 'added_tokens.json'
        if added_tokens_file.is_file():
            with open(added_tokens_file, "r", encoding="utf-8") as f:
                added_tokens_json = json.load(f)

                for key in added_tokens_json:
                    tokens.append(key.encode("utf-8"))
                    scores.append(-1000.0)
                    toktypes.append(SentencePieceTokenTypes.USER_DEFINED)

        self.gguf_writer.add_tokenizer_model("llama")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        special_vocab.add_to_gguf(self.gguf_writer)


class StableLMModel(Model):
    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        self.gguf_writer.add_rope_dimension_count(
            int(self.hparams["rope_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
        )
        self.gguf_writer.add_layer_norm_eps(1e-5)


class GPTNeoXModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["num_hidden_layers"]

        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(
            int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
        )
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])


class BloomModel(Model):
    def set_gguf_parameters(self):
        self.gguf_writer.add_name("Bloom")
        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
        self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
        self.gguf_writer.add_embedding_length(n_embed)
        self.gguf_writer.add_feed_forward_length(4 * n_embed)
        self.gguf_writer.add_block_count(self.hparams["n_layer"])
        self.gguf_writer.add_head_count(n_head)
        self.gguf_writer.add_head_count_kv(n_head)
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)
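
    # BLOOM packs query_key_value per attention head (the q, k and v rows for head 0,
    # then head 1, ...). write_tensors() below regroups the rows into all Q, then all
    # K, then all V, i.e. the GPT-2-style layout referenced in the comments there.
    # For example, with n_head = 2 and n_embed = 8 the (24, 8) weight is reshaped to
    # (2, 3, 4, 8) and re-concatenated as Q(8, 8), K(8, 8), V(8, 8) along dim 0.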
    def write_tensors(self):
        block_count = self.hparams["n_layer"]
        tensors = dict(self.get_tensors())
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        has_lm_head = True
        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))

        for name, data_torch in tensors.items():
            if "lm_head.weight" not in tensors.keys() and "output.weight" not in tensors.keys():
                has_lm_head = False

            name = re.sub(r'transformer\.', '', name)

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name):
                # Map bloom-style qkv_linear to gpt-style qkv_linear
                # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252  # noqa
                # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312  # noqa
                qkv_weights = data.reshape((n_head, 3, n_embed // n_head, n_embed))
                data = np.concatenate(
                    (
                        qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
                        qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
                        qkv_weights[:, 2, :, :].reshape((-1, n_embed)),
                    ),
                    axis=0,
                )
                print("re-format attention.linear_qkv.weight")
            elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name):
                qkv_bias = data.reshape((n_head, 3, n_embed // n_head))
                data = np.concatenate(
                    (
                        qkv_bias[:, 0, :].reshape((n_embed,)),
                        qkv_bias[:, 1, :].reshape((n_embed,)),
                        qkv_bias[:, 2, :].reshape((n_embed,)),
                    ),
                    axis=0,
                )
                print("re-format attention.linear_qkv.bias")

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"=> {new_name}, shape = {data.shape}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

            if not has_lm_head and name == "word_embeddings.weight":
                self.gguf_writer.add_tensor("output.weight", data)
                print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")


class MPTModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["n_layers"]
        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
        self.gguf_writer.add_embedding_length(self.hparams["d_model"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"])
        self.gguf_writer.add_head_count(self.hparams["n_heads"])
        if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"):
            self.gguf_writer.add_head_count_kv(kv_n_heads)
        self.gguf_writer.add_layer_norm_eps(1e-5)
        if self.hparams["attn_config"]["clip_qkv"] is not None:
            self.gguf_writer.add_clamp_kqv(self.hparams["attn_config"]["clip_qkv"])
        self.gguf_writer.add_max_alibi_bias(self.hparams["attn_config"]["alibi_bias_max"])

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers"))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

            # note: MPT output is tied to (same as) wte in original model;
            # for easier implementation in llama.cpp it's duplicated in GGUF, though :/
            if new_name == "token_embd.weight":
                self.gguf_writer.add_tensor("output.weight", data)


class BaichuanModel(Model):
    def set_vocab(self):
        self._set_vocab_sentencepiece()

    def set_gguf_parameters(self):
        block_count = self.hparams["num_hidden_layers"]
        head_count = self.hparams["num_attention_heads"]
        head_count_kv = self.hparams.get("num_key_value_heads", head_count)
        hf_repo = self.hparams.get("_name_or_path", "")

        ctx_length = 0
        if "max_sequence_length" in self.hparams:
            ctx_length = self.hparams["max_sequence_length"]
        elif "max_position_embeddings" in self.hparams:
            ctx_length = self.hparams["max_position_embeddings"]
        elif "model_max_length" in self.hparams:
            ctx_length = self.hparams["model_max_length"]
        else:
            print("gguf: can not find ctx length parameter.")
            sys.exit()

        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_source_hf_repo(hf_repo)
        self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
        self.gguf_writer.add_context_length(ctx_length)
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count(head_count)
        self.gguf_writer.add_head_count_kv(head_count_kv)
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])

        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
            if self.hparams["rope_scaling"].get("type") == "linear":
                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
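
    # Baichuan fuses the Q, K and V projections into a single W_pack weight stacked
    # along dim 0. write_tensors() below slices it back into q_proj / k_proj / v_proj
    # and applies _reverse_hf_permute() to the Q and K slices to undo the HF rotary
    # permutation, restoring the original (non-HF) LLaMA row order.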
    def write_tensors(self):
        # Collect tensors from generator object
        model_kv = dict(self.get_tensors())
        block_count = self.hparams["num_hidden_layers"]
        head_count = self.hparams["num_attention_heads"]
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        head_count_kv = self.hparams.get("num_key_value_heads", head_count)

        for i in range(block_count):
            if (w := model_kv.get(f"model.layers.{i}.self_attn.W_pack.weight")) is not None:
                print(f"Unpacking and permuting layer {i}")
                model_kv[f"model.layers.{i}.self_attn.q_proj.weight"] = \
                    self._reverse_hf_permute_part(w, 0, head_count, head_count)
                model_kv[f"model.layers.{i}.self_attn.k_proj.weight"] = \
                    self._reverse_hf_permute_part(w, 1, head_count, head_count_kv)
                model_kv[f"model.layers.{i}.self_attn.v_proj.weight"] = \
                    self._reverse_hf_part(w, 2)
                del model_kv[f"model.layers.{i}.self_attn.W_pack.weight"]

        for name, data_torch in model_kv.items():
            # we don't need these
            if name.endswith(".rotary_emb.inv_freq"):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
            self.gguf_writer.add_tensor(new_name, data)

    def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
        if n_kv_head is not None and n_head != n_kv_head:
            n_head //= n_kv_head

        return (
            weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape)
        )

    def _reverse_hf_permute_part(
        self, weights: Tensor, n_part: int, n_head: int, n_head_kv: int | None = None,
    ) -> Tensor:
        r = weights.shape[0] // 3
        return self._reverse_hf_permute(weights[r * n_part:r * n_part + r, ...], n_head, n_head_kv)

    def _reverse_hf_part(self, weights: Tensor, n_part: int) -> Tensor:
        r = weights.shape[0] // 3
        return weights[r * n_part:r * n_part + r, ...]


class FalconModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams.get("num_hidden_layers")
        if block_count is None:
            block_count = self.hparams["n_layer"]  # old name

        n_head = self.hparams.get("num_attention_heads")
        if n_head is None:
            n_head = self.hparams["n_head"]  # old name

        n_head_kv = self.hparams.get("num_kv_heads")
        if n_head_kv is None:
            n_head_kv = self.hparams.get("n_head_kv", 1)  # old name

        self.gguf_writer.add_name("Falcon")
        self.gguf_writer.add_context_length(2048)  # not in config.json
        self.gguf_writer.add_tensor_data_layout("jploski")  # qkv tensor transform
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(n_head)
        self.gguf_writer.add_head_count_kv(n_head_kv)
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        block_count = self.hparams.get("num_hidden_layers")
        if block_count is None:
            block_count = self.hparams["n_layer"]  # old name

        n_head = self.hparams.get("num_attention_heads")
        if n_head is None:
            n_head = self.hparams["n_head"]  # old name

        n_head_kv = self.hparams.get("num_kv_heads")
        if n_head_kv is None:
            n_head_kv = self.hparams.get("n_head_kv", 1)  # old name

        head_dim = self.hparams["hidden_size"] // n_head
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in self.get_tensors():
            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            # QKV tensor transform
            # The original query_key_value tensor contains n_head_kv "kv groups",
            # each consisting of n_head/n_head_kv query weights followed by one key
            # and one value weight (shared by all query heads in the kv group).
            # This layout makes it a big pain to work with in GGML.
            # So we rearrange them here, so that we have n_head query weights
            # followed by n_head_kv key weights followed by n_head_kv value weights,
            # in contiguous fashion.
            # ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py
            if "query_key_value" in name:
                qkv = data_torch.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
                q = qkv[:, :-2].reshape(n_head * head_dim, head_dim * n_head)
                k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
                v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
                data_torch = torch.cat((q, k, v)).reshape_as(data_torch)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


class StarCoderModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams["n_layer"]

        self.gguf_writer.add_name("StarCoder")
        self.gguf_writer.add_context_length(self.hparams["n_positions"])
        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(self.hparams["n_head"])
        self.gguf_writer.add_head_count_kv(1)
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)


class RefactModel(Model):
    def set_gguf_parameters(self):
        hidden_dim = self.hparams["n_embd"]
        inner_dim = 4 * hidden_dim
        hidden_dim = int(2 * inner_dim / 3)
        multiple_of = 256
        ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        block_count = self.hparams["n_layer"]

        self.gguf_writer.add_name("Refact")
        # Refact uses ALiBi, so this context length comes from config.json and may
        # only reflect the training setup.
        self.gguf_writer.add_context_length(self.hparams["n_positions"])
        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
        self.gguf_writer.add_feed_forward_length(ff_dim)
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(self.hparams["n_head"])
        self.gguf_writer.add_head_count_kv(1)
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)
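
    # The feed-forward width above follows a gated-MLP (LLaMA-style) sizing:
    # ff_dim = ceil((2/3 * 4 * n_embd) / 256) * 256. write_tensors() below reuses the
    # same ff_dim to split the fused mlp.gate_up_proj weight into gate_proj / up_proj,
    # and splits the fused attn.kv weight into k_proj / v_proj.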
    def write_tensors(self):
        hidden_dim = self.hparams["n_embd"]
        inner_dim = 4 * hidden_dim
        hidden_dim = int(2 * inner_dim / 3)
        multiple_of = 256
        ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        n_head = self.hparams["n_head"]
        n_head_kv = 1
        head_dim = self.hparams["n_embd"] // n_head
        block_count = self.hparams["n_layer"]

        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        tensors = dict(self.get_tensors())
        for i in range(block_count):
            if (w := tensors.get(f"transformer.h.{i}.attn.kv.weight")) is not None:
                tensors[f"model.layers.{i}.self_attn.k_proj.weight"] = w[:n_head_kv * head_dim]
                tensors[f"model.layers.{i}.self_attn.v_proj.weight"] = w[n_head_kv * head_dim:]
                del tensors[f"transformer.h.{i}.attn.kv.weight"]
            if (w := tensors.get(f"transformer.h.{i}.attn.q.weight")) is not None:
                tensors[f"model.layers.{i}.self_attn.q_proj.weight"] = w
                del tensors[f"transformer.h.{i}.attn.q.weight"]
            if (w := tensors.get(f"transformer.h.{i}.mlp.gate_up_proj.weight")) is not None:
                tensors[f"model.layers.{i}.mlp.gate_proj.weight"] = w[:ff_dim]
                tensors[f"model.layers.{i}.mlp.up_proj.weight"] = w[ff_dim:]
                del tensors[f"transformer.h.{i}.mlp.gate_up_proj.weight"]

        for name, data_torch in tensors.items():
            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight",))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


class PersimmonModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
        head_count = self.hparams["num_attention_heads"]
        head_count_kv = head_count
        hidden_size = self.hparams["hidden_size"]

        self.gguf_writer.add_name('persimmon-8b-chat')
        self.gguf_writer.add_embedding_length(hidden_size)
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(hidden_size // head_count)
        self.gguf_writer.add_head_count(head_count)
        self.gguf_writer.add_head_count_kv(head_count_kv)
        self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])

    def set_vocab(self):
        self._set_vocab_sentencepiece()
        # self.gguf_writer.add_bos_token_id(71013)
        # self.gguf_writer.add_eos_token_id(71013)

    def write_tensors(self):
        block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        for name, data_torch in self.get_tensors():
            if name.endswith(".self_attention.rotary_emb.inv_freq"):
                continue
            old_dtype = data_torch.dtype
            # TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?)
            data = data_torch.to(torch.float32).squeeze().numpy()
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()
            n_dims = len(data.shape)
            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
            self.gguf_writer.add_tensor(new_name, data)


###### CONVERSION LOGIC ######


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Convert a huggingface model to a GGML compatible file")
    parser.add_argument(
        "--vocab-only", action="store_true",
        help="extract only the vocab",
    )
    parser.add_argument(
        "--outfile", type=Path,
        help="path to write to; default: based on input",
    )
    parser.add_argument(
        "--outtype", type=str, choices=["f32", "f16"], default="f16",
        help="output format - use f32 for float32, f16 for float16",
    )
    parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
    parser.add_argument(
        "model", type=Path,
        help="directory containing model file",
    )
    return parser.parse_args()
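

# Example invocation (paths and filenames are illustrative):
#   python convert-hf-to-gguf.py /path/to/hf-model --outtype f16 --outfile model-f16.gguf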

args = parse_args()

dir_model = args.model
if not dir_model.is_dir():
    print(f'Error: {args.model} is not a directory', file=sys.stderr)
    sys.exit(1)

ftype_map = {
    "f32": gguf.GGMLQuantizationType.F32,
    "f16": gguf.GGMLQuantizationType.F16,
}

if args.outfile is not None:
    fname_out = args.outfile
else:
    # output in the same directory as the model by default
    fname_out = dir_model / f'ggml-model-{args.outtype}.gguf'

print(f"Loading model: {dir_model.name}")

hparams = Model.load_hparams(dir_model)

model_class = Model.from_model_architecture(hparams["architectures"][0])
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian)

print("Set model parameters")
model_instance.set_gguf_parameters()

print("Set model tokenizer")
model_instance.set_vocab()

if args.vocab_only:
    print(f"Exporting model vocab to '{fname_out}'")
    model_instance.write_vocab()
else:
    print(f"Exporting model to '{fname_out}'")
    model_instance.write()

print(f"Model successfully exported to '{fname_out}'")