convert.py 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295
  1. #!/usr/bin/env python3
  2. from __future__ import annotations
  3. import argparse
  4. import concurrent.futures
  5. import enum
  6. import faulthandler
  7. import functools
  8. import itertools
  9. import json
  10. import math
  11. import mmap
  12. import os
  13. import pickle
  14. import re
  15. import signal
  16. import struct
  17. import sys
  18. import time
  19. import zipfile
  20. from abc import ABCMeta, abstractmethod
  21. from collections import OrderedDict
  22. from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
  23. from dataclasses import dataclass
  24. from pathlib import Path
  25. from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, Optional, TypeVar, cast
  26. import numpy as np
  27. from sentencepiece import SentencePieceProcessor
  28. if 'NO_LOCAL_GGUF' not in os.environ:
  29. sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
  30. import gguf
  31. if TYPE_CHECKING:
  32. from typing import TypeAlias
# Allow dumping the Python tracebacks of all threads on demand (send SIGUSR1 to
# the process); skipped where faulthandler.register or SIGUSR1 is unavailable.
if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'):
    faulthandler.register(signal.SIGUSR1)
# Annotation alias for a numpy array of any shape and dtype.
NDArray: TypeAlias = 'np.ndarray[Any, Any]'
# GGUF model architecture targeted by this converter.
ARCH = gguf.MODEL_ARCH.LLAMA
# Default concurrency level; NOTE(review): its consumer is outside this chunk — confirm what it bounds.
DEFAULT_CONCURRENCY = 8
  38. #
  39. # data types
  40. #
  41. @dataclass(frozen=True)
  42. class DataType:
  43. name: str
  44. dtype: np.dtype[Any]
  45. valid_conversions: list[str]
  46. def elements_to_bytes(self, n_elements: int) -> int:
  47. return n_elements * self.dtype.itemsize
  48. @dataclass(frozen=True)
  49. class UnquantizedDataType(DataType):
  50. pass
  51. DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
  52. DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
  53. DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = [])
  54. DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])
  55. @dataclass(frozen=True)
  56. class QuantizedDataType(DataType):
  57. block_size: int
  58. quantized_dtype: np.dtype[Any]
  59. ggml_type: gguf.GGMLQuantizationType
  60. def quantize(self, arr: NDArray) -> NDArray:
  61. raise NotImplementedError(f'Quantization for {self.name} not implemented')
  62. def elements_to_bytes(self, n_elements: int) -> int:
  63. assert n_elements % self.block_size == 0, f'Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}'
  64. return self.quantized_dtype.itemsize * (n_elements // self.block_size)
  65. @dataclass(frozen=True)
  66. class Q8_0QuantizedDataType(QuantizedDataType):
  67. # Mini Q8_0 quantization in Python!
  68. def quantize(self, arr: NDArray) -> NDArray:
  69. assert arr.size % self.block_size == 0 and arr.size != 0, f'Bad array size {arr.size}'
  70. assert arr.dtype == np.float32, f'Bad array type {arr.dtype}'
  71. n_blocks = arr.size // self.block_size
  72. blocks = arr.reshape((n_blocks, self.block_size))
  73. # Much faster implementation of block quantization contributed by @Cebtenzzre
  74. def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[tuple[Any, Any]]:
  75. d = abs(blocks).max(axis = 1) / np.float32(127)
  76. with np.errstate(divide = 'ignore'):
  77. qs = (blocks / d[:, None]).round()
  78. qs[d == 0] = 0
  79. yield from zip(d, qs)
  80. return np.fromiter(quantize_blocks_q8_0(blocks), count = n_blocks, dtype = self.quantized_dtype)
  81. DT_Q8_0 = Q8_0QuantizedDataType('Q8_0',
  82. dtype = np.dtype(np.float32), valid_conversions = [],
  83. ggml_type = gguf.GGMLQuantizationType.Q8_0, block_size = 32,
  84. quantized_dtype = np.dtype([('d', '<f2'), ('qs', 'i1', (32,))]))
  85. # Quantized types skipped here because they may also map to np.float32
  86. NUMPY_TYPE_TO_DATA_TYPE: dict[np.dtype[Any], DataType] = {}
  87. for dt in (DT_BF16, DT_F16, DT_F32, DT_I32):
  88. if dt.dtype in NUMPY_TYPE_TO_DATA_TYPE:
  89. raise ValueError(f'Invalid duplicate data type {dt}')
  90. NUMPY_TYPE_TO_DATA_TYPE[dt.dtype] = dt
  91. SAFETENSORS_DATA_TYPES: dict[str, DataType] = {
  92. 'BF16': DT_BF16,
  93. 'F16': DT_F16,
  94. 'F32': DT_F32,
  95. 'I32': DT_I32,
  96. }
  97. # TODO: match this with `llama_ftype`
  98. # TODO: rename to LLAMAFileType
  99. # TODO: move to `gguf.py`
  100. class GGMLFileType(enum.IntEnum):
  101. AllF32 = 0
  102. MostlyF16 = 1 # except 1d tensors
  103. MostlyQ8_0 = 7 # except 1d tensors
  104. def type_for_tensor(self, name: str, tensor: LazyTensor) -> DataType:
  105. dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(self)
  106. if dt is None:
  107. raise ValueError(self)
  108. # 1D tensors are always F32.
  109. return dt if len(tensor.shape) > 1 else DT_F32
  110. GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = {
  111. GGMLFileType.AllF32 : DT_F32,
  112. GGMLFileType.MostlyF16 : DT_F16,
  113. GGMLFileType.MostlyQ8_0: DT_Q8_0,
  114. }
  115. #
  116. # hparams loading
  117. #
  118. @dataclass
  119. class Params:
  120. n_vocab: int
  121. n_embd: int
  122. n_layer: int
  123. n_ctx: int
  124. n_ff: int
  125. n_head: int
  126. n_head_kv: int
  127. n_experts: int | None = None
  128. n_experts_used: int | None = None
  129. f_norm_eps: float | None = None
  130. rope_scaling_type: gguf.RopeScalingType | None = None
  131. f_rope_freq_base: float | None = None
  132. f_rope_scale: float | None = None
  133. n_orig_ctx: int | None = None
  134. rope_finetuned: bool | None = None
  135. ftype: GGMLFileType | None = None
  136. # path to the directory containing the model files
  137. path_model: Path | None = None
  138. @staticmethod
  139. def guessed(model: LazyModel) -> Params:
  140. # try transformer naming first
  141. n_vocab, n_embd = model["model.embed_tokens.weight"].shape if "model.embed_tokens.weight" in model else model["tok_embeddings.weight"].shape
  142. # try transformer naming first
  143. if "model.layers.0.self_attn.q_proj.weight" in model:
  144. n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.q_proj.weight" not in model)
  145. elif "model.layers.0.self_attn.W_pack.weight" in model: # next: try baichuan naming
  146. n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.W_pack.weight" not in model)
  147. else:
  148. n_layer = next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model)
  149. if n_layer < 1:
  150. raise Exception("failed to guess 'n_layer'. This model is unknown or unsupported.\n"
  151. "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
  152. n_head = n_embd // 128 # guessed
  153. n_mult = 256 # guessed
  154. # TODO: verify this
  155. n_ff = int(2 * (4 * n_embd) / 3)
  156. n_ff = n_mult * ((n_ff + n_mult - 1) // n_mult)
  157. return Params(
  158. n_vocab = n_vocab,
  159. n_embd = n_embd,
  160. n_layer = n_layer,
  161. n_ctx = -1,
  162. n_ff = n_ff,
  163. n_head = n_head,
  164. n_head_kv = n_head,
  165. f_norm_eps = 1e-5,
  166. )
  167. @staticmethod
  168. def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
  169. config = json.load(open(config_path))
  170. rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None
  171. rope_scaling = config.get("rope_scaling")
  172. if rope_scaling is not None and (typ := rope_scaling.get("type")):
  173. rope_factor = rope_scaling.get("factor")
  174. f_rope_scale = rope_factor
  175. if typ == "linear":
  176. rope_scaling_type = gguf.RopeScalingType.LINEAR
  177. elif typ == "yarn":
  178. rope_scaling_type = gguf.RopeScalingType.YARN
  179. n_orig_ctx = rope_scaling['original_max_position_embeddings']
  180. rope_finetuned = rope_scaling['finetuned']
  181. else:
  182. raise NotImplementedError(f'Unknown rope scaling type: {typ}')
  183. if "max_sequence_length" in config:
  184. n_ctx = config["max_sequence_length"]
  185. elif "max_position_embeddings" in config:
  186. n_ctx = config["max_position_embeddings"]
  187. else:
  188. raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n"
  189. "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
  190. n_experts = None
  191. n_experts_used = None
  192. if "num_local_experts" in config:
  193. n_experts = config["num_local_experts"]
  194. n_experts_used = config["num_experts_per_tok"]
  195. return Params(
  196. n_vocab = config["vocab_size"],
  197. n_embd = config["hidden_size"],
  198. n_layer = config["num_hidden_layers"],
  199. n_ctx = n_ctx,
  200. n_ff = config["intermediate_size"],
  201. n_head = (n_head := config["num_attention_heads"]),
  202. n_head_kv = config.get("num_key_value_heads", n_head),
  203. n_experts = n_experts,
  204. n_experts_used = n_experts_used,
  205. f_norm_eps = config["rms_norm_eps"],
  206. f_rope_freq_base = config.get("rope_theta"),
  207. rope_scaling_type = rope_scaling_type,
  208. f_rope_scale = f_rope_scale,
  209. n_orig_ctx = n_orig_ctx,
  210. rope_finetuned = rope_finetuned,
  211. )
  212. # LLaMA v2 70B params.json
  213. # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1}
  214. @staticmethod
  215. def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
  216. config = json.load(open(config_path))
  217. n_experts = None
  218. n_experts_used = None
  219. f_rope_freq_base = None
  220. # hack to determine LLaMA v1 vs v2 vs CodeLlama
  221. if config.get("moe"):
  222. # Mixtral
  223. n_ctx = 32768
  224. elif config.get("rope_theta") == 1000000:
  225. # CodeLlama
  226. n_ctx = 16384
  227. elif config["norm_eps"] == 1e-05:
  228. # LLaMA v2
  229. n_ctx = 4096
  230. else:
  231. # LLaMA v1
  232. n_ctx = 2048
  233. if "layers.0.feed_forward.w1.weight" in model:
  234. n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
  235. if config.get("moe"):
  236. n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0]
  237. n_experts = config["moe"]["num_experts"]
  238. n_experts_used = config["moe"]["num_experts_per_tok"]
  239. f_rope_freq_base = 1e6
  240. return Params(
  241. n_vocab = model["tok_embeddings.weight"].shape[0],
  242. n_embd = config["dim"],
  243. n_layer = config["n_layers"],
  244. n_ctx = n_ctx,
  245. n_ff = n_ff,
  246. n_head = (n_head := config["n_heads"]),
  247. n_head_kv = config.get("n_kv_heads", n_head),
  248. n_experts = n_experts,
  249. n_experts_used = n_experts_used,
  250. f_norm_eps = config["norm_eps"],
  251. f_rope_freq_base = config.get("rope_theta", f_rope_freq_base),
  252. )
  253. @staticmethod
  254. def load(model_plus: ModelPlus) -> Params:
  255. hf_config_path = model_plus.paths[0].parent / "config.json"
  256. orig_config_path = model_plus.paths[0].parent / "params.json"
  257. if hf_config_path.exists():
  258. params = Params.loadHFTransformerJson(model_plus.model, hf_config_path)
  259. elif orig_config_path.exists():
  260. params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path)
  261. elif model_plus.format != 'none':
  262. params = Params.guessed(model_plus.model)
  263. else:
  264. raise ValueError('Cannot guess params when model format is none')
  265. params.path_model = model_plus.paths[0].parent
  266. return params
  267. class VocabLoader:
  268. def __init__(self, params: Params, fname_tokenizer: Path) -> None:
  269. try:
  270. from transformers import AutoTokenizer
  271. except ImportError as e:
  272. raise ImportError(
  273. "To use VocabLoader, please install the `transformers` package. "
  274. "You can install it with `pip install transformers`."
  275. ) from e
  276. try:
  277. self.tokenizer = AutoTokenizer.from_pretrained(str(fname_tokenizer), trust_remote_code=True)
  278. except ValueError:
  279. self.tokenizer = AutoTokenizer.from_pretrained(str(fname_tokenizer), use_fast=False, trust_remote_code=True)
  280. self.added_tokens_dict: OrderedDict[str, int] = OrderedDict()
  281. for tok, tokidx in sorted(self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]):
  282. if tokidx >= params.n_vocab or tokidx < self.tokenizer.vocab_size:
  283. continue
  284. self.added_tokens_dict[tok] = tokidx
  285. self.unk_token_id: int = self.tokenizer.unk_token_id
  286. self.specials: dict[str, int] = {
  287. tok: self.tokenizer.get_vocab()[tok]
  288. for tok in self.tokenizer.all_special_tokens
  289. }
  290. self.special_ids: set[int] = set(self.tokenizer.all_special_ids)
  291. self.reverse_vocab = {id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()}
  292. self.vocab_size_base: int = self.tokenizer.vocab_size
  293. self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_dict)
  294. self.fname_tokenizer: Path = fname_tokenizer
  295. vocab_file = "tokenizer.model"
  296. path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
  297. if path_candidate is not None:
  298. self.spm = SentencePieceProcessor(str(path_candidate))
  299. print(self.spm.vocab_size(), self.vocab_size_base)
  300. else:
  301. self.spm = None
  302. def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
  303. added_tokens_ids = set(self.added_tokens_dict.values())
  304. for i in range(self.vocab_size_base):
  305. if i in added_tokens_ids:
  306. continue
  307. text = self.reverse_vocab[i].encode("utf-8")
  308. yield text, self.get_token_score(i), self.get_token_type(i)
  309. def get_token_type(self, token_id: int) -> gguf.TokenType:
  310. toktype = gguf.TokenType.NORMAL
  311. if self.spm is not None and token_id < self.spm.vocab_size():
  312. if self.spm.is_unknown(token_id):
  313. toktype = gguf.TokenType.UNKNOWN
  314. if self.spm.is_control(token_id):
  315. toktype = gguf.TokenType.CONTROL
  316. if self.spm.is_unused(token_id):
  317. toktype = gguf.TokenType.UNUSED
  318. if self.spm.is_byte(token_id):
  319. toktype = gguf.TokenType.BYTE
  320. else:
  321. token = self.reverse_vocab[token_id]
  322. if token_id == self.unk_token_id:
  323. toktype = gguf.TokenType.UNKNOWN
  324. elif token_id in self.special_ids:
  325. toktype = gguf.TokenType.CONTROL
  326. elif len(token) == 6 and token.startswith("<0x") and token.endswith(">"):
  327. toktype = gguf.TokenType.BYTE
  328. return toktype
  329. def get_token_score(self, token_id: int) -> float:
  330. if self.spm is not None and token_id < self.spm.vocab_size():
  331. return cast(float, self.spm.get_score(token_id))
  332. return 0.0
  333. def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
  334. for text in self.added_tokens_dict:
  335. if text in self.specials:
  336. toktype = self.get_token_type(self.specials[text])
  337. score = self.get_token_score(self.specials[text])
  338. else:
  339. toktype = gguf.TokenType.USER_DEFINED
  340. score = -1000.0
  341. yield text.encode("utf-8"), score, toktype
  342. def has_newline_token(self) -> bool:
  343. return '<0x0A>' in self.tokenizer.vocab or '\n' in self.tokenizer.vocab
  344. def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
  345. yield from self.hf_tokens()
  346. yield from self.added_tokens()
  347. def get_vocab_type(self) -> str:
  348. path_candidates = []
  349. vocab_file = "tokenizer.model"
  350. path_candidates.append(vocab_file)
  351. path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
  352. if path_candidate is not None:
  353. return "llama"
  354. vocab_file = "vocab.json"
  355. path_candidates.append(vocab_file)
  356. path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
  357. if path_candidate is not None:
  358. return "gpt2"
  359. vocab_file = "tokenizer.json"
  360. path_candidates.append(vocab_file)
  361. path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
  362. if path_candidate:
  363. if not self.has_newline_token():
  364. return "gpt2"
  365. return "llama"
  366. raise FileNotFoundError(
  367. f"Could not find {path_candidates} in {self.fname_tokenizer} or its parent; "
  368. "if it's in another directory, pass the directory as --vocab-dir"
  369. )
  370. def __repr__(self) -> str:
  371. return f"<VocabLoader with {self.vocab_size_base} base tokens and {len(self.added_tokens_dict)} added tokens>"
# Alias for the vocabulary type used in annotations such as ModelPlus.vocab;
# currently VocabLoader.
Vocab: TypeAlias = 'VocabLoader'
  373. #
  374. # data loading
  375. # TODO: reuse (probably move to gguf.py?)
  376. #
  377. def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
  378. # print( "permute debug " + str(weights.shape[0]) + " x " + str(weights.shape[1]) + " nhead " + str(n_head) + " nheadkv " + str(n_kv_head) )
  379. if n_head_kv is not None and n_head != n_head_kv:
  380. n_head = n_head_kv
  381. return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
  382. .swapaxes(1, 2)
  383. .reshape(weights.shape))
  384. class Tensor(metaclass=ABCMeta):
  385. data_type: DataType
  386. @abstractmethod
  387. def astype(self, data_type: DataType) -> Tensor: ...
  388. @abstractmethod
  389. def permute(self, n_head: int, n_head_kv: int) -> Tensor: ...
  390. @abstractmethod
  391. def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: ...
  392. @abstractmethod
  393. def part(self, n_part: int) -> UnquantizedTensor: ...
  394. @abstractmethod
  395. def to_ggml(self) -> GGMLCompatibleTensor: ...
  396. def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray:
  397. assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}"
  398. fp32_arr = bf16_arr.astype(np.uint32) << 16
  399. return fp32_arr.view(np.float32)
  400. class UnquantizedTensor(Tensor):
  401. def __init__(self, ndarray: NDArray) -> None:
  402. assert isinstance(ndarray, np.ndarray)
  403. self.ndarray = ndarray
  404. self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
  405. def astype(self, data_type: DataType) -> Tensor:
  406. dtype = data_type.dtype
  407. if self.data_type == DT_BF16:
  408. self.ndarray = bf16_to_fp32(self.ndarray)
  409. return UnquantizedTensor(self.ndarray.astype(dtype))
  410. def to_ggml(self) -> UnquantizedTensor:
  411. return self
  412. def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor:
  413. r = self.ndarray.shape[0] // 3
  414. return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head_kv))
  415. def part(self, n_part: int) -> UnquantizedTensor:
  416. r = self.ndarray.shape[0] // 3
  417. return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])
  418. def permute(self, n_head: int, n_head_kv: int) -> UnquantizedTensor:
  419. return UnquantizedTensor(permute(self.ndarray, n_head, n_head_kv))
  420. def load_unquantized(lazy_tensor: LazyTensor, expected_dtype: Any = None, convert: bool = False) -> NDArray:
  421. tensor = lazy_tensor.load()
  422. assert isinstance(tensor, UnquantizedTensor)
  423. # double-check:
  424. actual_shape = list(tensor.ndarray.shape)
  425. assert actual_shape == lazy_tensor.shape, (actual_shape, lazy_tensor.shape)
  426. if expected_dtype is not None and expected_dtype != tensor.ndarray.dtype:
  427. if convert:
  428. tensor.ndarray = tensor.ndarray.astype(expected_dtype)
  429. else:
  430. raise ValueError(f'expected this tensor to have dtype {expected_dtype}, got {tensor.ndarray.dtype}')
  431. return tensor.ndarray
# Return type of Tensor.to_ggml(); currently just UnquantizedTensor.
GGMLCompatibleTensor = UnquantizedTensor
  433. @dataclass
  434. class LazyTensor:
  435. _load: Callable[[], Tensor]
  436. shape: list[int]
  437. data_type: DataType
  438. description: str
  439. def load(self) -> Tensor:
  440. ret = self._load()
  441. # Should be okay if it maps to the same numpy type?
  442. assert ret.data_type == self.data_type or (self.data_type.dtype == ret.data_type.dtype), \
  443. (self.data_type, ret.data_type, self.description)
  444. return ret
  445. def astype(self, data_type: DataType) -> LazyTensor:
  446. self.validate_conversion_to(data_type)
  447. def load() -> Tensor:
  448. return self.load().astype(data_type)
  449. return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}')
  450. def validate_conversion_to(self, data_type: DataType) -> None:
  451. if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions:
  452. raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.')
# A lazily loaded model: mapping from tensor name to LazyTensor.
LazyModel: TypeAlias = 'dict[str, LazyTensor]'
  454. @dataclass
  455. class ModelPlus:
  456. model: LazyModel
  457. paths: list[Path] # Where this was read from.
  458. format: Literal['ggml', 'torch', 'safetensors', 'none']
  459. vocab: Vocab | None # For GGML models (which have vocab built in), the vocab.
  460. def merge_sharded(models: list[LazyModel]) -> LazyModel:
  461. # Original LLaMA models have each file contain one part of each tensor.
  462. # Use a dict instead of a set to preserve order.
  463. names = {name: None for model in models for name in model}
  464. def convert(name: str) -> LazyTensor:
  465. lazy_tensors: list[LazyTensor] = [model[name] for model in models]
  466. if len(lazy_tensors) == 1:
  467. # only one file; don't go through this procedure since there might
  468. # be quantized tensors
  469. return lazy_tensors[0]
  470. if len(lazy_tensors[0].shape) == 1:
  471. # the tensor is just duplicated in every file
  472. return lazy_tensors[0]
  473. if name.startswith('tok_embeddings.') or \
  474. name.endswith('.attention.wo.weight') or \
  475. name.endswith('.feed_forward.w2.weight'):
  476. # split by columns
  477. axis = 1
  478. else:
  479. # split by rows
  480. axis = 0
  481. concatenated_shape = list(lazy_tensors[0].shape)
  482. concatenated_shape[axis] = sum(tensor.shape[axis] for tensor in lazy_tensors)
  483. def load() -> UnquantizedTensor:
  484. ndarrays = [load_unquantized(tensor) for tensor in lazy_tensors]
  485. concatenated: NDArray = np.concatenate(ndarrays, axis=axis)
  486. return UnquantizedTensor(concatenated)
  487. description = 'concatenated[[' + '] | ['.join(lt.description for lt in lazy_tensors) + ']]'
  488. return LazyTensor(load, concatenated_shape, lazy_tensors[0].data_type, description)
  489. return {name: convert(name) for name in names}
  490. def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus:
  491. formats = set(mp.format for mp in models_plus)
  492. assert len(formats) == 1, "different formats?"
  493. format = formats.pop()
  494. paths = [path for mp in models_plus for path in mp.paths]
  495. # Use the first non-None vocab, if any.
  496. try:
  497. vocab = next(mp.vocab for mp in models_plus if mp.vocab is not None)
  498. except StopIteration:
  499. vocab = None
  500. if any("model.embed_tokens.weight" in mp.model for mp in models_plus):
  501. # Transformers models put different tensors in different files, but
  502. # don't split individual tensors between files.
  503. model: LazyModel = {}
  504. for mp in models_plus:
  505. model.update(mp.model)
  506. else:
  507. model = merge_sharded([mp.model for mp in models_plus])
  508. return ModelPlus(model, paths, format, vocab)
  509. def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor:
  510. def load() -> Tensor:
  511. return lazy_tensor.load().permute(n_head, n_head_kv)
  512. return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description)
  513. def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int, n_head_kv: int) -> LazyTensor:
  514. def load() -> Tensor:
  515. return lazy_tensor.load().permute_part(n_part, n_head, n_head_kv)
  516. s = lazy_tensor.shape.copy()
  517. s[0] = s[0] // 3
  518. return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description)
  519. def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor:
  520. def load() -> Tensor:
  521. return lazy_tensor.load().part(n_part)
  522. s = lazy_tensor.shape.copy()
  523. s[0] = s[0] // 3
  524. return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description)
  525. # Functionality that simulates `torch.load` but where individual tensors are
  526. # only loaded into memory on demand, not all at once.
  527. # PyTorch can't do this natively as of time of writing:
  528. # - https://github.com/pytorch/pytorch/issues/64327
  529. # This allows us to de-shard without multiplying RAM usage, and also
  530. # conveniently drops the PyTorch dependency (though we still need numpy).
  531. @dataclass
  532. class LazyStorageKind:
  533. data_type: DataType
  534. @dataclass
  535. class LazyStorage:
  536. load: Callable[[int, int], NDArray]
  537. kind: LazyStorageKind
  538. description: str
  539. class LazyUnpickler(pickle.Unpickler):
  540. def __init__(self, fp: IO[bytes], data_base_path: str, zip_file: zipfile.ZipFile):
  541. super().__init__(fp)
  542. self.data_base_path = data_base_path
  543. self.zip_file = zip_file
  544. def persistent_load(self, pid: Any) -> Any:
  545. assert pid[0] == 'storage'
  546. assert isinstance(pid[1], LazyStorageKind)
  547. data_type = pid[1].data_type
  548. filename_stem = pid[2]
  549. filename = f'{self.data_base_path}/{filename_stem}'
  550. info = self.zip_file.getinfo(filename)
  551. def load(offset: int, elm_count: int) -> NDArray:
  552. dtype = data_type.dtype
  553. fp = self.zip_file.open(info)
  554. fp.seek(offset * dtype.itemsize)
  555. size = elm_count * dtype.itemsize
  556. data = fp.read(size)
  557. assert len(data) == size
  558. return np.frombuffer(data, dtype)
  559. description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}'
  560. return LazyStorage(load=load, kind=pid[1], description=description)
  561. @staticmethod
  562. def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any,
  563. requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor:
  564. assert isinstance(storage, LazyStorage)
  565. def load() -> UnquantizedTensor:
  566. elm_count = stride[0] * size[0]
  567. return UnquantizedTensor(storage.load(storage_offset, elm_count).reshape(size))
  568. description = f'pickled storage_offset={storage_offset} in {storage.description}'
  569. return LazyTensor(load, list(size), storage.kind.data_type, description)
  570. @staticmethod
  571. def rebuild_from_type_v2(func, new_type, args, state):
  572. return func(*args)
  573. CLASSES: dict[tuple[str, str], Any] = {
  574. # getattr used here as a workaround for mypy not being smart enough to determine
  575. # the staticmethods have a __func__ attribute.
  576. ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
  577. ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'),
  578. ('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16),
  579. ('torch', 'HalfStorage'): LazyStorageKind(DT_F16),
  580. ('torch', 'FloatStorage'): LazyStorageKind(DT_F32),
  581. ('torch', 'IntStorage'): LazyStorageKind(DT_I32),
  582. ('torch', 'Tensor'): LazyTensor,
  583. }
  584. def find_class(self, module: str, name: str) -> Any:
  585. if not module.startswith('torch'):
  586. return super().find_class(module, name)
  587. return self.CLASSES[(module, name)]
  def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus:
      '''Lazily load a PyTorch zip checkpoint from the open file `outer_fp`.

      The archive must contain exactly one ``*.pkl`` entry. That entry is
      unpickled with LazyUnpickler, which is handed the open ZipFile. If the
      unpickled object has a 'model' key, that entry is used as the state dict.

      The ZipFile is not closed here, presumably because the returned lazy
      tensors read their data from it later.

      Args:
          outer_fp: binary file object positioned at the start of the archive.
          path: path of the checkpoint, recorded in the returned ModelPlus.

      Raises:
          ValueError: if the archive does not contain exactly one ``.pkl`` entry.
      '''
      zf = zipfile.ZipFile(outer_fp)
      pickle_paths = [name for name in zf.namelist() if name.endswith('.pkl')]
      # Explicit check instead of `assert`: this validates an external file and
      # asserts are stripped under `python -O`.
      if len(pickle_paths) != 1:
          raise ValueError(f'{path}: expected exactly one .pkl entry in the archive, found {pickle_paths}')
      pickle_path = pickle_paths[0]
      pickle_fp = zf.open(pickle_path, 'r')
      # NOTE(review): this unpickles data from the checkpoint file; see
      # LazyUnpickler.find_class, which falls back to the default lookup for
      # non-torch globals. Only convert checkpoints from trusted sources.
      unpickler = LazyUnpickler(pickle_fp,
                                data_base_path=pickle_path[:-4],  # strip the '.pkl' suffix
                                zip_file=zf)
      model = unpickler.load()
      if 'model' in model:
          model = model['model']
      as_dict = dict(model.items())
      return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None)
  def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
      '''Lazily load a safetensors file from the open file `fp`.

      The file is parsed as: an 8-byte little-endian header length, a JSON
      header mapping tensor names to {'dtype', 'shape', 'data_offsets'}, then
      the raw data section. Offsets are relative to the start of the data
      section. The '__metadata__' header entry is skipped.

      Args:
          fp: binary file object positioned at the start of the file.
          path: path of the file, recorded in descriptions and the ModelPlus.

      Raises:
          ValueError: if a tensor's data offsets fall outside the data section
              or its byte length does not match its dtype and shape.
      '''
      header_size, = struct.unpack('<Q', fp.read(8))
      header: dict[str, dict[str, Any]] = json.loads(fp.read(header_size))
      # Use mmap for the actual data to avoid race conditions with the file offset.
      mapped = memoryview(mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ))
      byte_buf = mapped[8 + header_size:]

      def convert(name: str, info: dict[str, Any]) -> LazyTensor:
          data_type = SAFETENSORS_DATA_TYPES[info['dtype']]
          numpy_dtype = data_type.dtype
          shape: list[int] = info['shape']
          begin, end = info['data_offsets']
          # Explicit checks rather than `assert`: the header comes from an
          # external file, and asserts are stripped under `python -O`.
          if not 0 <= begin <= end <= len(byte_buf):
              raise ValueError(f'{path}: tensor {name!r} has data offsets [{begin}, {end}) '
                               f'outside the {len(byte_buf)}-byte data section')
          expected_nbytes = math.prod(shape) * numpy_dtype.itemsize
          if end - begin != expected_nbytes:
              raise ValueError(f'{path}: tensor {name!r} spans {end - begin} bytes, '
                               f'but its dtype and shape {shape} require {expected_nbytes}')
          buf = byte_buf[begin:end]

          def load() -> UnquantizedTensor:
              return UnquantizedTensor(np.frombuffer(buf, dtype=numpy_dtype).reshape(shape))

          description = f'safetensors begin={begin} end={end} type={data_type} path={path}'
          return LazyTensor(load, shape, data_type, description)

      model = {name: convert(name, info) for (name, info) in header.items() if name != '__metadata__'}
      return ModelPlus(model=model, paths=[path], format='safetensors', vocab=None)
  def must_read(fp: IO[bytes], length: int) -> bytes:
      '''Read `length` bytes from `fp` with a single `read` call.

      Raises:
          Exception: if that call returns fewer than `length` bytes.
      '''
      data = fp.read(length)
      if len(data) >= length:
          return data
      raise Exception("unexpectedly reached end of file")
  @functools.lru_cache(maxsize=None)
  def lazy_load_file(path: Path) -> ModelPlus:
      '''Detect the format of the checkpoint at `path` and load it lazily.

      Files starting with b'PK' (zip archives) are loaded as PyTorch
      checkpoints. Files whose first 8 bytes, read as a little-endian uint64,
      are below 16 MiB are loaded as safetensors; that value is the JSON
      header length. Results are memoized per path by `lru_cache`.

      On success the file handle is left open, because the loaders build on
      it (a ZipFile over it, or an mmap of its fileno). On any failure it is
      closed before the exception propagates.

      Raises:
          ValueError: if the format is not recognized, including files shorter
              than 8 bytes.
      '''
      fp = open(path, 'rb')
      try:
          first8 = fp.read(8)
          fp.seek(0)
          if first8[:2] == b'PK':
              # A zip file, i.e. PyTorch format
              return lazy_load_torch_file(fp, path)
          # struct.unpack needs exactly 8 bytes; without the length guard a
          # tiny/truncated file raised struct.error instead of ValueError.
          if len(first8) == 8 and struct.unpack('<Q', first8)[0] < 16 * 1024 * 1024:
              # Probably safetensors
              return lazy_load_safetensors_file(fp, path)
          raise ValueError(f"unknown format: {path}")
      except BaseException:
          fp.close()
          raise
  638. In = TypeVar('In')
  639. Out = TypeVar('Out')
  640. def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: int | None = None, use_processpool_executor: bool = False) -> Iterable[Out]:
  641. '''Parallel map, but with backpressure. If the caller doesn't call `next`
  642. fast enough, this will stop calling `func` at some point rather than
  643. letting results pile up in memory. Specifically, there is a max of one
  644. output value buffered per thread.'''
  645. if concurrency < 2:
  646. yield from map(func, iterable)
  647. # Not reached.
  648. iterable = iter(iterable)
  649. executor_class: type[ThreadPoolExecutor] | type[ProcessPoolExecutor]
  650. if use_processpool_executor:
  651. executor_class = ProcessPoolExecutor
  652. else:
  653. executor_class = ThreadPoolExecutor
  654. with executor_class(max_workers = max_workers) as executor:
  655. futures: list[concurrent.futures.Future[Out]] = []
  656. done = False
  657. for _ in range(concurrency):
  658. try:
  659. futures.append(executor.submit(func, next(iterable)))
  660. except StopIteration:
  661. done = True
  662. break
  663. while futures:
  664. result = futures.pop(0).result()
  665. while not done and len(futures) < concurrency:
  666. try:
  667. futures.append(executor.submit(func, next(iterable)))
  668. except StopIteration:
  669. done = True
  670. break
  671. yield result
  def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None:
      '''Ensure the tokenizer vocabulary size matches the model's `n_vocab`.

      When the sizes differ and `pad_vocab` is set while the model expects more
      tokens, the vocabulary is padded in place with `<dummyNNNNN>` added
      tokens (mapped to -1) and `vocab.vocab_size` is updated.

      Raises:
          Exception: on any other size mismatch, with hints about a missing
              added_tokens.json or the --padvocab option.
      '''
      if params.n_vocab == vocab.vocab_size:
          return

      if pad_vocab and params.n_vocab > vocab.vocab_size:
          pad_count = params.n_vocab - vocab.vocab_size
          print(f'Padding vocab with {pad_count} token(s) - <dummy00001> through <dummy{pad_count:05}>')
          for i in range(1, pad_count + 1):
              vocab.added_tokens_dict[f'<dummy{i:05}>'] = -1
          vocab.vocab_size = params.n_vocab
          return

      msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer}"
      msg += f" has {vocab.vocab_size})."
      if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20:
          msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})."
      if vocab.vocab_size < params.n_vocab:
          msg += " Possibly try using the --padvocab option."
      raise Exception(msg)
  class OutputFile:
      '''Writes a converted model to a GGUF file through `gguf.GGUFWriter`.

      The static helpers `write_vocab_only` and `write_all` drive the whole
      sequence. They queue metadata (and, for `write_all`, tensor infos), write
      the header, key/value and tensor-info sections, stream tensor data
      (`write_all` only), then close the file.
      '''

      def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None:
          self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)

      def add_meta_arch(self, params: Params) -> None:
          '''Queue the architecture hyper-parameters from `params` as GGUF key/values.

          Optional fields (expert counts, RoPE settings, file type) are written
          only when set.

          Raises:
              ValueError: if `params.f_norm_eps` is None.
          '''
          name = "LLaMA"

          # TODO: better logic to determine model name
          if params.n_ctx == 4096:
              name = "LLaMA v2"
          elif params.path_model is not None:
              # NOTE(review): splitting str() on '/' assumes POSIX path separators.
              name = str(params.path_model.parent).split('/')[-1]

          self.gguf.add_name(name)
          self.gguf.add_context_length(params.n_ctx)
          self.gguf.add_embedding_length(params.n_embd)
          self.gguf.add_block_count(params.n_layer)
          self.gguf.add_feed_forward_length(params.n_ff)
          self.gguf.add_rope_dimension_count(params.n_embd // params.n_head)
          self.gguf.add_head_count(params.n_head)
          self.gguf.add_head_count_kv(params.n_head_kv)

          if params.n_experts:
              self.gguf.add_expert_count(params.n_experts)

          if params.n_experts_used:
              self.gguf.add_expert_used_count(params.n_experts_used)

          # Test for None explicitly: a truthiness check would also reject an
          # epsilon of 0.0, with an error message wrongly claiming it is None.
          if params.f_norm_eps is None:
              raise ValueError('f_norm_eps is None')
          self.gguf.add_layer_norm_rms_eps(params.f_norm_eps)

          if params.f_rope_freq_base is not None:
              self.gguf.add_rope_freq_base(params.f_rope_freq_base)

          if params.rope_scaling_type:
              assert params.f_rope_scale is not None
              self.gguf.add_rope_scaling_type(params.rope_scaling_type)
              self.gguf.add_rope_scaling_factor(params.f_rope_scale)

          if params.n_orig_ctx is not None:
              self.gguf.add_rope_scaling_orig_ctx_len(params.n_orig_ctx)

          if params.rope_finetuned is not None:
              self.gguf.add_rope_scaling_finetuned(params.rope_finetuned)

          if params.ftype is not None:
              self.gguf.add_file_type(params.ftype)

      def add_meta_vocab(self, vocab: Vocab) -> None:
          '''Queue the tokenizer model type plus token texts, scores and types.'''
          tokens = []
          scores = []
          toktypes = []
          # NOTE: `all_tokens` returns the base vocabulary and added tokens
          for text, score, toktype in vocab.all_tokens():
              tokens.append(text)
              scores.append(score)
              toktypes.append(toktype)

          vocab_type = vocab.get_vocab_type()
          self.gguf.add_tokenizer_model(vocab_type)
          self.gguf.add_token_list(tokens)
          self.gguf.add_token_scores(scores)
          self.gguf.add_token_types(toktypes)

      def add_meta_special_vocab(self, svocab: gguf.SpecialVocab) -> None:
          '''Let `svocab` add its special-token metadata to the writer.'''
          svocab.add_to_gguf(self.gguf)

      def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
          '''Queue the name, shape, dtype and byte size of `tensor`; no data is read.'''
          n_elements = int(np.prod(tensor.shape))
          # Quantized data types expose `ggml_type`/`quantized_type`; plain
          # ones fall back to their numpy dtype.
          raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
          data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype
          data_nbytes = tensor.data_type.elements_to_bytes(n_elements)
          self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype = raw_dtype)

      def write_meta(self) -> None:
          '''Write the GGUF header and the queued key/value metadata.'''
          self.gguf.write_header_to_file()
          self.gguf.write_kv_data_to_file()

      def write_tensor_info(self) -> None:
          '''Write the queued tensor-info section.'''
          self.gguf.write_ti_data_to_file()

      def close(self) -> None:
          self.gguf.close()

      @staticmethod
      def write_vocab_only(
          fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
          endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
          pad_vocab: bool = False,
      ) -> None:
          '''Write a GGUF file containing only architecture and vocabulary metadata.'''
          check_vocab_size(params, vocab, pad_vocab = pad_vocab)

          of = OutputFile(fname_out, endianess=endianess)

          # meta data
          of.add_meta_arch(params)
          of.add_meta_vocab(vocab)
          of.add_meta_special_vocab(svocab)

          of.write_meta()

          of.close()

      @staticmethod
      def do_item(item: tuple[str, LazyTensor]) -> tuple[DataType, NDArray]:
          '''Load one (name, lazy tensor) pair; return its target data type and GGML-ready array.'''
          _, lazy_tensor = item
          tensor = lazy_tensor.load().to_ggml()
          return (lazy_tensor.data_type, tensor.ndarray)

      @staticmethod
      def maybe_do_quantize(item: tuple[DataType, NDArray]) -> NDArray:
          '''Quantize the array if its data type is a QuantizedDataType, else return it as-is.'''
          dt, arr = item
          if not isinstance(dt, QuantizedDataType):
              return arr
          return dt.quantize(arr)

      @staticmethod
      def write_all(
          fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab,
          concurrency: int = DEFAULT_CONCURRENCY,
          endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
          pad_vocab: bool = False,
      ) -> None:
          '''Write a complete GGUF file: metadata, tensor infos, then tensor data.

          Tensors are loaded with bounded parallelism. For Q8_0 output,
          quantization runs in a process pool.
          '''
          check_vocab_size(params, vocab, pad_vocab = pad_vocab)

          of = OutputFile(fname_out, endianess=endianess)

          # meta data
          of.add_meta_arch(params)
          of.add_meta_vocab(vocab)
          of.add_meta_special_vocab(svocab)

          # tensor info
          for name, lazy_tensor in model.items():
              of.add_tensor_info(name, lazy_tensor)

          of.write_meta()
          of.write_tensor_info()

          # tensor data
          ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
          if ftype == GGMLFileType.MostlyQ8_0:
              ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, use_processpool_executor = True)
          else:
              ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)

          n_tensors = len(model)
          padi = len(str(n_tensors))  # field width of the progress counter
          start = time.time()
          for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
              elapsed = time.time() - start
              size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
              print(f"[{i+1:{padi}d}/{n_tensors}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}")
              of.gguf.write_tensor_data(ndarray)

          of.close()
  def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
      '''Choose the output file type.

      An explicit `output_type_str` ("f32", "f16", "q8_0") wins. When it is
      None, the type is inferred from the first layer's attention-Q weight:
      F32 gives AllF32, and F16/BF16 give MostlyF16.

      Raises:
          Exception: if no rule applies, listing every tensor's data type.
      '''
      wq_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"
      wq_type = model[wq_name].data_type
      infer = output_type_str is None

      if output_type_str == "f32" or (infer and wq_type == DT_F32):
          return GGMLFileType.AllF32
      if output_type_str == "f16" or (infer and wq_type in (DT_F16, DT_BF16)):
          return GGMLFileType.MostlyF16
      if output_type_str == "q8_0":
          return GGMLFileType.MostlyQ8_0

      type_by_name = {tensor_name: tensor.data_type for tensor_name, tensor in model.items()}
      raise Exception(f"Unexpected combination of types: {type_by_name}")
  def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
      '''Return a new model where each tensor is lazily cast to the per-tensor type chosen by `output_type`.'''
      converted = {}
      for tensor_name, lazy_tensor in model.items():
          target = output_type.type_for_tensor(tensor_name, lazy_tensor)
          converted[tensor_name] = lazy_tensor.astype(target)
      return converted
  def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
      '''Map checkpoint tensor names to GGUF names, undoing HF Q/K permutation and W_pack packing.

      Per layer, q_proj/k_proj weights are re-permuted. A packed W_pack weight
      is split into q/k/v (q and k permuted) and removed. Every resulting name
      is then mapped through gguf.TensorNameMap, and tensor types listed in
      gguf.MODEL_TENSOR_SKIP for ARCH are dropped.

      The input dict is not modified.

      Raises:
          Exception: for a tensor name the name map does not recognise.
      '''
      tmap = gguf.TensorNameMap(ARCH, params.n_layer)
      should_skip: set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))

      # Work on a shallow copy. The caller's dict may belong to a ModelPlus
      # memoized by lazy_load_file, and rewriting it in place would make a
      # second conversion permute the Q/K weights again.
      tmp: LazyModel = dict(model)

      # HF models permute or pack some of the tensors, so we need to undo that
      for i in itertools.count():
          prefix = f"model.layers.{i}.self_attn"
          if f"{prefix}.q_proj.weight" in tmp:
              print(f"Permuting layer {i}")
              tmp[f"{prefix}.q_proj.weight"] = permute_lazy(tmp[f"{prefix}.q_proj.weight"], params.n_head, params.n_head)
              tmp[f"{prefix}.k_proj.weight"] = permute_lazy(tmp[f"{prefix}.k_proj.weight"], params.n_head, params.n_head_kv)
          elif f"{prefix}.W_pack.weight" in tmp:
              print(f"Unpacking and permuting layer {i}")
              w_pack = tmp[f"{prefix}.W_pack.weight"]
              tmp[f"{prefix}.q_proj.weight"] = permute_part_lazy(w_pack, 0, params.n_head, params.n_head)
              tmp[f"{prefix}.k_proj.weight"] = permute_part_lazy(w_pack, 1, params.n_head, params.n_head_kv)
              tmp[f"{prefix}.v_proj.weight"] = part_lazy(w_pack, 2)
              del tmp[f"{prefix}.W_pack.weight"]
          else:
              break

      out: LazyModel = {}
      for name, lazy_tensor in tmp.items():
          tensor_type, name_new = tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias")) or (None, None)
          if name_new is None:
              raise Exception(f"Unexpected tensor name: {name}")

          if tensor_type in should_skip:
              print(f"skipping tensor {name_new}")
              continue

          print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
          out[name_new] = lazy_tensor

      return out
  860. def nth_multifile_path(path: Path, n: int) -> Path | None:
  861. '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return
  862. the nth path in the model.
  863. '''
  864. # Support the following patterns:
  865. patterns: list[tuple[str, str]] = [
  866. # - x.00.pth, x.01.pth, etc.
  867. (r'\.[0-9]{2}\.pth$', f'.{n:02}.pth'),
  868. # - x-00001-of-00002.bin, x-00002-of-00002.bin, etc.
  869. (r'-[0-9]{5}-of-(.*)$', fr'-{n:05}-of-\1'),
  870. # x.bin, x.bin.1, etc.
  871. (r'(\.[0-9]+)?$', r'\1' if n == 0 else fr'\1.{n}')
  872. ]
  873. for regex, replacement in patterns:
  874. if re.search(regex, path.name):
  875. new_path = path.with_name(re.sub(regex, replacement, path.name))
  876. if new_path.exists():
  877. return new_path
  878. return None
  def find_multifile_paths(path: Path) -> list[Path]:
      '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return
      every path of the model, in order.

      Parts are enumerated with nth_multifile_path for n = 0, 1, 2, ... until
      one is missing. If not even part 0 is found, `path` alone is returned.
      '''
      found: list[Path] = []
      index = 0
      while (part := nth_multifile_path(path, index)) is not None:
          found.append(part)
          index += 1
      # No matches. This should only happen if the file was named, e.g.,
      # foo.0, and there was no file named foo. Oh well, try to process it
      # as a single file.
      return found or [path]
  def load_some_model(path: Path) -> ModelPlus:
      '''Load a model of any supported format.

      `path` may be a model file or a directory. In a directory, safetensors
      file names are searched first and PyTorch names only as a fallback.
      Exactly one match is required. All parts of a multi-file model are then
      loaded lazily and merged.

      Raises:
          Exception: if a directory contains no model file, or more than one.
      '''
      if path.is_dir():
          search_order = (
              # Check if it's a set of safetensors files first
              ["model-00001-of-*.safetensors", "model.safetensors"],
              # Try the PyTorch patterns too, with lower priority
              ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"],
          )
          candidates: list[Path] = []
          for globs in search_order:
              candidates = [match for pattern in globs for match in path.glob(pattern)]
              if candidates:
                  break
          if not candidates:
              raise Exception(f"Can't find model in directory {path}")
          if len(candidates) > 1:
              raise Exception(f"Found multiple models in {path}, not sure which to pick: {candidates}")
          path = candidates[0]

      models_plus: list[ModelPlus] = []
      for part_path in find_multifile_paths(path):
          print(f"Loading model file {part_path}")
          models_plus.append(lazy_load_file(part_path))

      return merge_multifile_models(models_plus)
  def find_vocab_file_path(path: Path, vocab_file: str) -> Optional[Path]:
      '''Return `path / vocab_file` if it exists, else `path.parent / vocab_file`
      if that exists, else None.'''
      # Use `.parent` instead of /.. to handle the symlink case better.
      candidates = (path / vocab_file, path.parent / vocab_file)
      for candidate in candidates:
          if candidate.exists():
              return candidate
      return None
  def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path:
      '''Return ``ggml-model-<type>.gguf`` in the directory of the first input file.

      If that path is itself one of the inputs, an error is written to stderr
      and the process exits with status 1.
      '''
      suffix_for_type = {
          GGMLFileType.AllF32: "f32",
          GGMLFileType.MostlyF16: "f16",
          GGMLFileType.MostlyQ8_0: "q8_0",
      }
      namestr = suffix_for_type[file_type]
      candidate = model_paths[0].parent / f"ggml-model-{namestr}.gguf"
      if candidate not in model_paths:
          return candidate
      sys.stderr.write(
          f"Error: Default output path ({candidate}) would overwrite the input. "
          "Please explicitly specify a path using --outfile.\n")
      sys.exit(1)
  def do_dump_model(model_plus: ModelPlus) -> None:
      '''Print the paths, format and vocab of `model_plus`, then one line per tensor.'''
      for field in ("paths", "format", "vocab"):
          print(f"model_plus.{field} = {getattr(model_plus, field)!r}")
      for tensor_name, tensor in model_plus.model.items():
          print(f"{tensor_name}: shape={tensor.shape} type={tensor.data_type}; {tensor.description}")
  def main(args_in: list[str] | None = None) -> None:
      '''Command-line entry point: convert a LLaMA-style checkpoint to GGUF.

      Args:
          args_in: arguments to parse instead of ``sys.argv[1:]``.
      '''
      output_choices = ["f32", "f16"]
      # We currently only support Q8_0 output on little endian systems.
      # `sys.byteorder` reports the host byte order directly. The former probe,
      # `np.uint32(1).newbyteorder("<")`, relies on NumPy's `newbyteorder`
      # method, which was removed in NumPy 2.0.
      if sys.byteorder == "little":
          output_choices.append("q8_0")

      parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file")
      parser.add_argument("--awq-path", type=Path, help="Path to scale awq cache file", default=None)
      parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
      parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
      parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
      parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
      parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
      parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
      parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
      parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
      parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
      parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
      parser.add_argument("--padvocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")

      args = parser.parse_args(args_in)

      if args.awq_path:
          sys.path.insert(1, str(Path(__file__).parent / 'awq-py'))
          from awq.apply_awq import add_scale_weights
          tmp_model_path = args.model / "weighted_model"
          if tmp_model_path.is_dir():
              print(f"{tmp_model_path} exists as a weighted model.")
          else:
              tmp_model_path.mkdir(parents=True, exist_ok=True)
              print("Saving new weighted model ...")
              add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path))
              print(f"Saved weighted model at {tmp_model_path}.")
          args.model = tmp_model_path

      if args.dump_single:
          model_plus = lazy_load_file(args.model)
          do_dump_model(model_plus)
          return

      if not args.vocab_only:
          model_plus = load_some_model(args.model)
      else:
          model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)

      if args.dump:
          do_dump_model(model_plus)
          return

      endianess = gguf.GGUFEndian.LITTLE
      if args.bigendian:
          endianess = gguf.GGUFEndian.BIG

      params = Params.load(model_plus)
      if params.n_ctx == -1:
          if args.ctx is None:
              raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n"
                              "Please specify one with --ctx:\n"
                              " - LLaMA v1: --ctx 2048\n"
                              " - LLaMA v2: --ctx 4096\n")
          params.n_ctx = args.ctx

      if args.outtype:
          params.ftype = {
              "f32": GGMLFileType.AllF32,
              "f16": GGMLFileType.MostlyF16,
              "q8_0": GGMLFileType.MostlyQ8_0,
          }[args.outtype]

      print(f"params = {params}")

      vocab: Vocab
      if args.vocab_only:
          if not args.outfile:
              raise ValueError("need --outfile if using --vocab-only")
          vocab = VocabLoader(params, args.vocab_dir or args.model)
          # FIXME: gguf.SpecialVocab is still loaded from the model directory,
          # not from --vocab-dir.
          special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
                                            load_merges = True,
                                            n_vocab = vocab.vocab_size)
          outfile = args.outfile
          OutputFile.write_vocab_only(outfile, params, vocab, special_vocab,
                                      endianess = endianess, pad_vocab = args.padvocab)
          print(f"Wrote {outfile}")
          return

      if model_plus.vocab is not None and args.vocab_dir is None:
          vocab = model_plus.vocab
      else:
          vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
          vocab = VocabLoader(params, vocab_dir)

      # FIXME: gguf.SpecialVocab is still loaded from the model directory,
      # not from --vocab-dir.
      print(f"Vocab info: {vocab}")
      special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
                                        load_merges = True,
                                        n_vocab = vocab.vocab_size)

      print(f"Special vocab info: {special_vocab}")
      model = model_plus.model
      model = convert_model_names(model, params)
      ftype = pick_output_type(model, args.outtype)
      model = convert_to_output_type(model, ftype)
      outfile = args.outfile or default_outfile(model_plus.paths, ftype)

      params.ftype = ftype
      print(f"Writing {outfile}, format {ftype}")

      OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
                           concurrency = args.concurrency, endianess = endianess, pad_vocab = args.padvocab)
      print(f"Wrote {outfile}")
  if __name__ == "__main__":
      # Script entry point: main() parses sys.argv when given no arguments.
      main()