convert.py 61 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548
  1. #!/usr/bin/env python3
  2. from __future__ import annotations
  3. import argparse
  4. import concurrent.futures
  5. import enum
  6. import faulthandler
  7. import functools
  8. import itertools
  9. import json
  10. import math
  11. import mmap
  12. import os
  13. import pickle
  14. import re
  15. import signal
  16. import struct
  17. import sys
  18. import textwrap
  19. import time
  20. import zipfile
  21. from abc import ABC, abstractmethod
  22. from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
  23. from dataclasses import dataclass
  24. from pathlib import Path
  25. from typing import TYPE_CHECKING, Any, Callable, ClassVar, IO, Iterable, Literal, Protocol, TypeVar, runtime_checkable
  26. import numpy as np
  27. from sentencepiece import SentencePieceProcessor
  28. if 'NO_LOCAL_GGUF' not in os.environ:
  29. sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
  30. import gguf
  31. if TYPE_CHECKING:
  32. from typing_extensions import Self, TypeAlias
# Register a handler that dumps the tracebacks of all threads when the process
# receives SIGUSR1 (only where this faulthandler build and platform support it).
if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'):
    faulthandler.register(signal.SIGUSR1)

# String alias used in annotations; only evaluated by type checkers.
NDArray: TypeAlias = 'np.ndarray[Any, Any]'

ARCH = gguf.MODEL_ARCH.LLAMA

DEFAULT_CONCURRENCY = 8

# File names looked up next to the model by the vocabulary loaders below.
ADDED_TOKENS_FILE = 'added_tokens.json'
FAST_TOKENIZER_FILE = 'tokenizer.json'
  40. #
  41. # data types
  42. #
  43. @dataclass(frozen=True)
  44. class DataType:
  45. name: str
  46. dtype: np.dtype[Any]
  47. valid_conversions: list[str]
  48. def elements_to_bytes(self, n_elements: int) -> int:
  49. return n_elements * self.dtype.itemsize
  50. @dataclass(frozen=True)
  51. class UnquantizedDataType(DataType):
  52. pass
  53. DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
  54. DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
  55. DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = [])
  56. DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])
  57. @dataclass(frozen=True)
  58. class QuantizedDataType(DataType):
  59. block_size: int
  60. quantized_dtype: np.dtype[Any]
  61. ggml_type: gguf.GGMLQuantizationType
  62. def quantize(self, arr: NDArray) -> NDArray:
  63. raise NotImplementedError(f'Quantization for {self.name} not implemented')
  64. def elements_to_bytes(self, n_elements: int) -> int:
  65. assert n_elements % self.block_size == 0, f'Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}'
  66. return self.quantized_dtype.itemsize * (n_elements // self.block_size)
  67. @dataclass(frozen=True)
  68. class Q8_0QuantizedDataType(QuantizedDataType):
  69. # Mini Q8_0 quantization in Python!
  70. def quantize(self, arr: NDArray) -> NDArray:
  71. assert arr.size % self.block_size == 0 and arr.size != 0, f'Bad array size {arr.size}'
  72. assert arr.dtype == np.float32, f'Bad array type {arr.dtype}'
  73. n_blocks = arr.size // self.block_size
  74. blocks = arr.reshape((n_blocks, self.block_size))
  75. # Much faster implementation of block quantization contributed by @Cebtenzzre
  76. def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[tuple[Any, Any]]:
  77. d = abs(blocks).max(axis = 1) / np.float32(127)
  78. with np.errstate(divide = 'ignore'):
  79. qs = (blocks / d[:, None]).round()
  80. qs[d == 0] = 0
  81. yield from zip(d, qs)
  82. return np.fromiter(quantize_blocks_q8_0(blocks), count = n_blocks, dtype = self.quantized_dtype)
# Q8_0: blocks of 32 elements, each stored as a little-endian float16 scale 'd'
# followed by 32 signed 8-bit values 'qs' (34 bytes per block). Input to
# quantize() is float32.
DT_Q8_0 = Q8_0QuantizedDataType('Q8_0',
    dtype = np.dtype(np.float32), valid_conversions = [],
    ggml_type = gguf.GGMLQuantizationType.Q8_0, block_size = 32,
    quantized_dtype = np.dtype([('d', '<f2'), ('qs', 'i1', (32,))]))
# Reverse lookup from NumPy dtype to unquantized DataType (used when wrapping
# an ndarray in UnquantizedTensor).
# Quantized types skipped here because they may also map to np.float32
NUMPY_TYPE_TO_DATA_TYPE: dict[np.dtype[Any], DataType] = {}
for dt in (DT_BF16, DT_F16, DT_F32, DT_I32):
    # Two DataTypes sharing one NumPy dtype would make the lookup ambiguous.
    if dt.dtype in NUMPY_TYPE_TO_DATA_TYPE:
        raise ValueError(f'Invalid duplicate data type {dt}')
    NUMPY_TYPE_TO_DATA_TYPE[dt.dtype] = dt
# Maps safetensors dtype strings to DataTypes; these four are the only
# entries this table provides.
SAFETENSORS_DATA_TYPES: dict[str, DataType] = {
    'BF16': DT_BF16,
    'F16': DT_F16,
    'F32': DT_F32,
    'I32': DT_I32,
}
  99. # TODO: match this with `llama_ftype`
  100. # TODO: rename to LLAMAFileType
  101. # TODO: move to `gguf.py`
  102. class GGMLFileType(enum.IntEnum):
  103. AllF32 = 0
  104. MostlyF16 = 1 # except 1d tensors
  105. MostlyQ8_0 = 7 # except 1d tensors
  106. def type_for_tensor(self, name: str, tensor: LazyTensor) -> DataType:
  107. dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(self)
  108. if dt is None:
  109. raise ValueError(self)
  110. # Convert all 1D tensors to F32. Most of the codebase that takes in 1D tensors only handles F32 tensors, and most of the outputs tensors are F32.
  111. # Also The 1d tensors aren't much of a performance/size issue. So instead of having to have separate F32 and F16 implementations of both, just convert everything to F32 for now.
  112. return dt if len(tensor.shape) > 1 else DT_F32
# Element type used for multi-dimensional tensors under each file type
# (GGMLFileType.type_for_tensor still forces 1-D tensors to F32).
GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = {
    GGMLFileType.AllF32    : DT_F32,
    GGMLFileType.MostlyF16 : DT_F16,
    GGMLFileType.MostlyQ8_0: DT_Q8_0,
}
  118. #
  119. # hparams loading
  120. #
  121. @dataclass
  122. class Params:
  123. n_vocab: int
  124. n_embd: int
  125. n_layer: int
  126. n_ctx: int
  127. n_ff: int
  128. n_head: int
  129. n_head_kv: int
  130. n_experts: int | None = None
  131. n_experts_used: int | None = None
  132. f_norm_eps: float | None = None
  133. rope_scaling_type: gguf.RopeScalingType | None = None
  134. f_rope_freq_base: float | None = None
  135. f_rope_scale: float | None = None
  136. n_orig_ctx: int | None = None
  137. rope_finetuned: bool | None = None
  138. ftype: GGMLFileType | None = None
  139. # path to the directory containing the model files
  140. path_model: Path | None = None
  141. @staticmethod
  142. def guessed(model: LazyModel) -> Params:
  143. # try transformer naming first
  144. n_vocab, n_embd = model["model.embed_tokens.weight"].shape if "model.embed_tokens.weight" in model else model["tok_embeddings.weight"].shape
  145. # try transformer naming first
  146. if "model.layers.0.self_attn.q_proj.weight" in model:
  147. n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.q_proj.weight" not in model)
  148. elif "model.layers.0.self_attn.W_pack.weight" in model: # next: try baichuan naming
  149. n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.W_pack.weight" not in model)
  150. else:
  151. n_layer = next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model)
  152. if n_layer < 1:
  153. msg = """\
  154. failed to guess 'n_layer'. This model is unknown or unsupported.
  155. Suggestion: provide 'config.json' of the model in the same directory containing model files."""
  156. raise KeyError(textwrap.dedent(msg))
  157. n_head = n_embd // 128 # guessed
  158. n_mult = 256 # guessed
  159. # TODO: verify this
  160. n_ff = int(2 * (4 * n_embd) / 3)
  161. n_ff = n_mult * ((n_ff + n_mult - 1) // n_mult)
  162. return Params(
  163. n_vocab = n_vocab,
  164. n_embd = n_embd,
  165. n_layer = n_layer,
  166. n_ctx = -1,
  167. n_ff = n_ff,
  168. n_head = n_head,
  169. n_head_kv = n_head,
  170. f_norm_eps = 1e-5,
  171. )
  172. @staticmethod
  173. def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
  174. with open(config_path) as f:
  175. config = json.load(f)
  176. rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None
  177. rope_scaling = config.get("rope_scaling")
  178. if rope_scaling is not None and (typ := rope_scaling.get("type")):
  179. rope_factor = rope_scaling.get("factor")
  180. f_rope_scale = rope_factor
  181. if typ == "linear":
  182. rope_scaling_type = gguf.RopeScalingType.LINEAR
  183. elif typ == "yarn":
  184. rope_scaling_type = gguf.RopeScalingType.YARN
  185. n_orig_ctx = rope_scaling['original_max_position_embeddings']
  186. rope_finetuned = rope_scaling['finetuned']
  187. else:
  188. raise NotImplementedError(f'Unknown rope scaling type: {typ}')
  189. if "max_sequence_length" in config:
  190. n_ctx = config["max_sequence_length"]
  191. elif "max_position_embeddings" in config:
  192. n_ctx = config["max_position_embeddings"]
  193. else:
  194. msg = """\
  195. failed to guess 'n_ctx'. This model is unknown or unsupported.
  196. Suggestion: provide 'config.json' of the model in the same directory containing model files."""
  197. raise KeyError(textwrap.dedent(msg))
  198. n_experts = None
  199. n_experts_used = None
  200. if "num_local_experts" in config:
  201. n_experts = config["num_local_experts"]
  202. n_experts_used = config["num_experts_per_tok"]
  203. return Params(
  204. n_vocab = config["vocab_size"],
  205. n_embd = config["hidden_size"],
  206. n_layer = config["num_hidden_layers"],
  207. n_ctx = n_ctx,
  208. n_ff = config["intermediate_size"],
  209. n_head = (n_head := config["num_attention_heads"]),
  210. n_head_kv = config.get("num_key_value_heads", n_head),
  211. n_experts = n_experts,
  212. n_experts_used = n_experts_used,
  213. f_norm_eps = config["rms_norm_eps"],
  214. f_rope_freq_base = config.get("rope_theta"),
  215. rope_scaling_type = rope_scaling_type,
  216. f_rope_scale = f_rope_scale,
  217. n_orig_ctx = n_orig_ctx,
  218. rope_finetuned = rope_finetuned,
  219. )
  220. # LLaMA v2 70B params.json
  221. # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1}
  222. @staticmethod
  223. def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
  224. with open(config_path) as f:
  225. config = json.load(f)
  226. n_experts = None
  227. n_experts_used = None
  228. f_rope_freq_base = None
  229. # hack to determine LLaMA v1 vs v2 vs CodeLlama
  230. if config.get("moe"):
  231. # Mixtral
  232. n_ctx = 32768
  233. elif config.get("rope_theta") == 1000000:
  234. # CodeLlama
  235. n_ctx = 16384
  236. elif config["norm_eps"] == 1e-05:
  237. # LLaMA v2
  238. n_ctx = 4096
  239. else:
  240. # LLaMA v1
  241. n_ctx = 2048
  242. if "layers.0.feed_forward.w1.weight" in model:
  243. n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
  244. if config.get("moe"):
  245. n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0]
  246. n_experts = config["moe"]["num_experts"]
  247. n_experts_used = config["moe"]["num_experts_per_tok"]
  248. f_rope_freq_base = 1e6
  249. return Params(
  250. n_vocab = model["tok_embeddings.weight"].shape[0],
  251. n_embd = config["dim"],
  252. n_layer = config["n_layers"],
  253. n_ctx = n_ctx,
  254. n_ff = n_ff,
  255. n_head = (n_head := config["n_heads"]),
  256. n_head_kv = config.get("n_kv_heads", n_head),
  257. n_experts = n_experts,
  258. n_experts_used = n_experts_used,
  259. f_norm_eps = config["norm_eps"],
  260. f_rope_freq_base = config.get("rope_theta", f_rope_freq_base),
  261. )
  262. @staticmethod
  263. def load(model_plus: ModelPlus) -> Params:
  264. hf_config_path = model_plus.paths[0].parent / "config.json"
  265. orig_config_path = model_plus.paths[0].parent / "params.json"
  266. if hf_config_path.exists():
  267. params = Params.loadHFTransformerJson(model_plus.model, hf_config_path)
  268. elif orig_config_path.exists():
  269. params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path)
  270. elif model_plus.format != 'none':
  271. params = Params.guessed(model_plus.model)
  272. else:
  273. raise ValueError('Cannot guess params when model format is none')
  274. params.path_model = model_plus.paths[0].parent
  275. return params
  276. #
  277. # vocab
  278. #
  279. @runtime_checkable
  280. class BaseVocab(Protocol):
  281. tokenizer_model: ClassVar[str]
  282. name: ClassVar[str]
  283. class NoVocab(BaseVocab):
  284. tokenizer_model = "no_vocab"
  285. name = "no_vocab"
  286. def __repr__(self) -> str:
  287. return "<NoVocab for a model without integrated vocabulary>"
  288. @runtime_checkable
  289. class Vocab(BaseVocab, Protocol):
  290. vocab_size: int
  291. added_tokens_dict: dict[str, int]
  292. added_tokens_list: list[str]
  293. fname_tokenizer: Path
  294. def __init__(self, base_path: Path): ...
  295. def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ...
  296. class BpeVocab(Vocab):
  297. tokenizer_model = "gpt2"
  298. name = "bpe"
  299. def __init__(self, base_path: Path):
  300. added_tokens: dict[str, int] = {}
  301. if (fname_tokenizer := base_path / 'vocab.json').exists():
  302. # "slow" tokenizer
  303. with open(fname_tokenizer, encoding="utf-8") as f:
  304. self.vocab = json.load(f)
  305. try:
  306. # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
  307. with open(base_path / ADDED_TOKENS_FILE, encoding="utf-8") as f:
  308. added_tokens = json.load(f)
  309. except FileNotFoundError:
  310. pass
  311. else:
  312. # "fast" tokenizer
  313. fname_tokenizer = base_path / FAST_TOKENIZER_FILE
  314. # if this fails, FileNotFoundError propagates to caller
  315. with open(fname_tokenizer, encoding="utf-8") as f:
  316. tokenizer_json = json.load(f)
  317. tokenizer_model: dict[str, Any] = tokenizer_json['model']
  318. if (
  319. tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False)
  320. or tokenizer_json['decoder']['type'] != 'ByteLevel'
  321. ):
  322. raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer')
  323. self.vocab = tokenizer_model["vocab"]
  324. if (added := tokenizer_json.get('added_tokens')) is not None:
  325. # Added tokens here can be duplicates of the main vocabulary.
  326. added_tokens = {item['content']: item['id']
  327. for item in added
  328. if item['content'] not in self.vocab}
  329. vocab_size = len(self.vocab)
  330. expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
  331. actual_ids = sorted(added_tokens.values())
  332. if expected_ids != actual_ids:
  333. expected_end_id = vocab_size + len(actual_ids) - 1
  334. raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range "
  335. f"{vocab_size} - {expected_end_id}; got {actual_ids}")
  336. items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
  337. self.added_tokens_dict = added_tokens
  338. self.added_tokens_list = [text for (text, idx) in items]
  339. self.vocab_size_base = vocab_size
  340. self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
  341. self.fname_tokenizer = fname_tokenizer
  342. def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
  343. reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()}
  344. for i, _ in enumerate(self.vocab):
  345. yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL
  346. def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
  347. for text in self.added_tokens_list:
  348. score = -1000.0
  349. yield text.encode("utf-8"), score, gguf.TokenType.CONTROL
  350. def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
  351. yield from self.bpe_tokens()
  352. yield from self.added_tokens()
  353. def __repr__(self) -> str:
  354. return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
  355. class SentencePieceVocab(Vocab):
  356. tokenizer_model = "llama"
  357. name = "spm"
  358. def __init__(self, base_path: Path):
  359. added_tokens: dict[str, int] = {}
  360. if (fname_tokenizer := base_path / 'tokenizer.model').exists():
  361. # normal location
  362. try:
  363. with open(base_path / ADDED_TOKENS_FILE, encoding="utf-8") as f:
  364. added_tokens = json.load(f)
  365. except FileNotFoundError:
  366. pass
  367. elif not (fname_tokenizer := base_path.parent / 'tokenizer.model').exists():
  368. # not found in alternate location either
  369. raise FileNotFoundError('Cannot find tokenizer.model')
  370. self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
  371. vocab_size = self.sentencepiece_tokenizer.vocab_size()
  372. new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
  373. expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
  374. actual_new_ids = sorted(new_tokens.keys())
  375. if expected_new_ids != actual_new_ids:
  376. raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")
  377. # Token pieces that were added to the base vocabulary.
  378. self.added_tokens_dict = added_tokens
  379. self.added_tokens_list = [new_tokens[id] for id in actual_new_ids]
  380. self.vocab_size_base = vocab_size
  381. self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
  382. self.fname_tokenizer = fname_tokenizer
  383. def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
  384. tokenizer = self.sentencepiece_tokenizer
  385. for i in range(tokenizer.vocab_size()):
  386. piece = tokenizer.id_to_piece(i)
  387. text = piece.encode("utf-8")
  388. score: float = tokenizer.get_score(i)
  389. toktype = gguf.TokenType.NORMAL
  390. if tokenizer.is_unknown(i):
  391. toktype = gguf.TokenType.UNKNOWN
  392. if tokenizer.is_control(i):
  393. toktype = gguf.TokenType.CONTROL
  394. # NOTE: I think added_tokens are user defined.
  395. # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
  396. # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED
  397. if tokenizer.is_unused(i):
  398. toktype = gguf.TokenType.UNUSED
  399. if tokenizer.is_byte(i):
  400. toktype = gguf.TokenType.BYTE
  401. yield text, score, toktype
  402. def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
  403. for text in self.added_tokens_list:
  404. score = -1000.0
  405. yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED
  406. def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
  407. yield from self.sentencepiece_tokens()
  408. yield from self.added_tokens()
  409. def __repr__(self) -> str:
  410. return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
  411. class LlamaHfVocab(Vocab):
  412. tokenizer_model = "llama"
  413. name = "hfft"
  414. def __init__(self, base_path: Path):
  415. fname_tokenizer = base_path / FAST_TOKENIZER_FILE
  416. # if this fails, FileNotFoundError propagates to caller
  417. with open(fname_tokenizer, encoding='utf-8') as f:
  418. tokenizer_json = json.load(f)
  419. # pre-check so we know if we need transformers
  420. tokenizer_model: dict[str, Any] = tokenizer_json['model']
  421. if (
  422. tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
  423. or tokenizer_json['decoder']['type'] != 'Sequence'
  424. ):
  425. raise FileNotFoundError('Cannot find Llama BPE tokenizer')
  426. try:
  427. from transformers import AutoTokenizer
  428. except ImportError as e:
  429. raise ImportError(
  430. "To use LlamaHfVocab, please install the `transformers` package. "
  431. "You can install it with `pip install transformers`."
  432. ) from e
  433. # Allow the tokenizer to default to slow or fast versions.
  434. # Explicitly set tokenizer to use local paths.
  435. self.tokenizer = AutoTokenizer.from_pretrained(
  436. base_path,
  437. cache_dir=base_path,
  438. local_files_only=True,
  439. )
  440. assert self.tokenizer.is_fast # assume tokenizer.json is used
  441. # Initialize lists and dictionaries for added tokens
  442. self.added_tokens_list = []
  443. self.added_tokens_dict = dict()
  444. self.added_tokens_ids = set()
  445. # Process added tokens
  446. for tok, tokidx in sorted(
  447. self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
  448. ):
  449. # Only consider added tokens that are not in the base vocabulary
  450. if tokidx >= self.tokenizer.vocab_size:
  451. self.added_tokens_list.append(tok)
  452. self.added_tokens_dict[tok] = tokidx
  453. self.added_tokens_ids.add(tokidx)
  454. # Store special tokens and their IDs
  455. self.specials = {
  456. tok: self.tokenizer.get_vocab()[tok]
  457. for tok in self.tokenizer.all_special_tokens
  458. }
  459. self.special_ids = set(self.tokenizer.all_special_ids)
  460. # Set vocabulary sizes
  461. self.vocab_size_base = self.tokenizer.vocab_size
  462. self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
  463. self.fname_tokenizer = fname_tokenizer
  464. def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
  465. reverse_vocab = {
  466. id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
  467. }
  468. for token_id in range(self.vocab_size_base):
  469. # Skip processing added tokens here
  470. if token_id in self.added_tokens_ids:
  471. continue
  472. # Convert token text to bytes
  473. token_text = reverse_vocab[token_id].encode("utf-8")
  474. # Yield token text, score, and type
  475. yield token_text, self.get_token_score(token_id), self.get_token_type(
  476. token_id, token_text, self.special_ids # Reuse already stored special IDs
  477. )
  478. def get_token_type(self, token_id: int, token_text: bytes, special_ids: set[int]) -> gguf.TokenType:
  479. # Special case for byte tokens
  480. if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
  481. return gguf.TokenType.BYTE
  482. # Determine token type based on whether it's a special token
  483. return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL
  484. def get_token_score(self, token_id: int) -> float:
  485. # Placeholder for actual logic to determine the token's score
  486. # This needs to be implemented based on specific requirements
  487. return -1000.0 # Default score
  488. def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
  489. for text in self.added_tokens_list:
  490. if text in self.specials:
  491. toktype = self.get_token_type(self.specials[text], b'', self.special_ids)
  492. score = self.get_token_score(self.specials[text])
  493. else:
  494. toktype = gguf.TokenType.USER_DEFINED
  495. score = -1000.0
  496. yield text.encode("utf-8"), score, toktype
  497. def has_newline_token(self):
  498. return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
  499. def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
  500. yield from self.hf_tokens()
  501. yield from self.added_tokens()
  502. def __repr__(self) -> str:
  503. return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
  504. #
  505. # data loading
  506. # TODO: reuse (probably move to gguf.py?)
  507. #
  508. def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
  509. # print( "permute debug " + str(weights.shape[0]) + " x " + str(weights.shape[1]) + " nhead " + str(n_head) + " nheadkv " + str(n_kv_head) )
  510. if n_head_kv is not None and n_head != n_head_kv:
  511. n_head = n_head_kv
  512. return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
  513. .swapaxes(1, 2)
  514. .reshape(weights.shape))
  515. class Tensor(ABC):
  516. ndarray: NDArray
  517. data_type: DataType
  518. @abstractmethod
  519. def astype(self, data_type: DataType) -> Self: ...
  520. @abstractmethod
  521. def permute(self, n_head: int, n_head_kv: int) -> Self: ...
  522. @abstractmethod
  523. def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> Self: ...
  524. @abstractmethod
  525. def part(self, n_part: int) -> Self: ...
  526. @abstractmethod
  527. def to_ggml(self) -> GGMLCompatibleTensor: ...
  528. def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray:
  529. assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}"
  530. fp32_arr = bf16_arr.astype(np.uint32) << 16
  531. return fp32_arr.view(np.float32)
  532. class UnquantizedTensor(Tensor):
  533. def __init__(self, ndarray: NDArray):
  534. assert isinstance(ndarray, np.ndarray)
  535. self.ndarray = ndarray
  536. self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
  537. def astype(self, data_type: DataType) -> UnquantizedTensor:
  538. dtype = data_type.dtype
  539. if self.data_type == DT_BF16:
  540. self.ndarray = bf16_to_fp32(self.ndarray)
  541. return UnquantizedTensor(self.ndarray.astype(dtype))
  542. def to_ggml(self) -> Self:
  543. return self
  544. def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor:
  545. r = self.ndarray.shape[0] // 3
  546. return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head_kv))
  547. def part(self, n_part: int) -> UnquantizedTensor:
  548. r = self.ndarray.shape[0] // 3
  549. return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])
  550. def permute(self, n_head: int, n_head_kv: int) -> UnquantizedTensor:
  551. return UnquantizedTensor(permute(self.ndarray, n_head, n_head_kv))
  552. def load_unquantized(lazy_tensor: LazyTensor, expected_dtype: Any = None, convert: bool = False) -> NDArray:
  553. tensor = lazy_tensor.load()
  554. assert isinstance(tensor, UnquantizedTensor)
  555. # double-check:
  556. actual_shape = list(tensor.ndarray.shape)
  557. assert actual_shape == lazy_tensor.shape, (actual_shape, lazy_tensor.shape)
  558. if expected_dtype is not None and expected_dtype != tensor.ndarray.dtype:
  559. if convert:
  560. tensor.ndarray = tensor.ndarray.astype(expected_dtype)
  561. else:
  562. raise ValueError(f'expected this tensor to have dtype {expected_dtype}, got {tensor.ndarray.dtype}')
  563. return tensor.ndarray
# Type returned by Tensor.to_ggml(); currently only unquantized tensors qualify.
GGMLCompatibleTensor = UnquantizedTensor
  565. @dataclass
  566. class LazyTensor:
  567. _load: Callable[[], Tensor]
  568. shape: list[int]
  569. data_type: DataType
  570. description: str
  571. def load(self) -> Tensor:
  572. ret = self._load()
  573. # Should be okay if it maps to the same numpy type?
  574. assert ret.data_type == self.data_type or (self.data_type.dtype == ret.data_type.dtype), \
  575. (self.data_type, ret.data_type, self.description)
  576. return ret
  577. def astype(self, data_type: DataType) -> LazyTensor:
  578. self.validate_conversion_to(data_type)
  579. def load() -> Tensor:
  580. return self.load().astype(data_type)
  581. return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}')
  582. def validate_conversion_to(self, data_type: DataType) -> None:
  583. if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions:
  584. raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.')
# A lazily loaded model: tensor name -> LazyTensor (string form, so the alias
# is never evaluated at runtime).
LazyModel: TypeAlias = 'dict[str, LazyTensor]'
@dataclass
class ModelPlus:
    """A lazily loaded model plus where it came from and any bundled vocab.

    Constructed positionally in places (e.g. ``ModelPlus(model, paths, format,
    vocab)``), so field order matters.
    """
    model: LazyModel
    paths: list[Path] # Where this was read from.
    # Source container format; 'none' is used when no model file was read.
    format: Literal['ggml', 'torch', 'safetensors', 'none']
    vocab: BaseVocab | None # For GGML models (which have vocab built in), the vocab.
  592. def merge_sharded(models: list[LazyModel]) -> LazyModel:
  593. # Original LLaMA models have each file contain one part of each tensor.
  594. # Use a dict instead of a set to preserve order.
  595. names = {name: None for model in models for name in model}
  596. def convert(name: str) -> LazyTensor:
  597. lazy_tensors = [model[name] for model in models]
  598. if len(lazy_tensors) == 1:
  599. # only one file; don't go through this procedure since there might
  600. # be quantized tensors
  601. return lazy_tensors[0]
  602. if len(lazy_tensors[0].shape) == 1:
  603. # the tensor is just duplicated in every file
  604. return lazy_tensors[0]
  605. if name.startswith('tok_embeddings.') or \
  606. name.endswith('.attention.wo.weight') or \
  607. name.endswith('.feed_forward.w2.weight'):
  608. # split by columns
  609. axis = 1
  610. else:
  611. # split by rows
  612. axis = 0
  613. concatenated_shape = list(lazy_tensors[0].shape)
  614. concatenated_shape[axis] = sum(tensor.shape[axis] for tensor in lazy_tensors)
  615. def load() -> UnquantizedTensor:
  616. ndarrays = [load_unquantized(tensor) for tensor in lazy_tensors]
  617. concatenated = np.concatenate(ndarrays, axis=axis)
  618. return UnquantizedTensor(concatenated)
  619. description = 'concatenated[[' + '] | ['.join(lt.description for lt in lazy_tensors) + ']]'
  620. return LazyTensor(load, concatenated_shape, lazy_tensors[0].data_type, description)
  621. return {name: convert(name) for name in names}
  622. def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus:
  623. formats = set(mp.format for mp in models_plus)
  624. assert len(formats) == 1, "different formats?"
  625. format = formats.pop()
  626. paths = [path for mp in models_plus for path in mp.paths]
  627. # Use the first non-None vocab, if any.
  628. try:
  629. vocab = next(mp.vocab for mp in models_plus if mp.vocab is not None)
  630. except StopIteration:
  631. vocab = None
  632. if any("model.embed_tokens.weight" in mp.model for mp in models_plus):
  633. # Transformers models put different tensors in different files, but
  634. # don't split individual tensors between files.
  635. model: LazyModel = {}
  636. for mp in models_plus:
  637. model.update(mp.model)
  638. else:
  639. model = merge_sharded([mp.model for mp in models_plus])
  640. return ModelPlus(model, paths, format, vocab) # pytype: disable=wrong-arg-types
  641. def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor:
  642. def load() -> Tensor:
  643. return lazy_tensor.load().permute(n_head, n_head_kv)
  644. return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description)
  645. def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int, n_head_kv: int) -> LazyTensor:
  646. def load() -> Tensor:
  647. return lazy_tensor.load().permute_part(n_part, n_head, n_head_kv)
  648. s = lazy_tensor.shape.copy()
  649. s[0] = s[0] // 3
  650. return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description)
  651. def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor:
  652. def load() -> Tensor:
  653. return lazy_tensor.load().part(n_part)
  654. s = lazy_tensor.shape.copy()
  655. s[0] = s[0] // 3
  656. return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description)
  657. def pack_experts_lazy(lazy_tensors: list[LazyTensor]) -> LazyTensor:
  658. def load() -> Tensor:
  659. tensors = [lazy_tensor.load() for lazy_tensor in lazy_tensors]
  660. return UnquantizedTensor(np.array([tensor.ndarray for tensor in tensors]))
  661. s = lazy_tensors[0].shape.copy()
  662. s.insert(0, len(lazy_tensors))
  663. return LazyTensor(load, s, lazy_tensors[0].data_type, 'pack_experts ' + ' | '.join(lt.description for lt in lazy_tensors))
# Functionality that simulates `torch.load` but where individual tensors are
# only loaded into memory on demand, not all at once.
# PyTorch can't do this natively as of time of writing:
# - https://github.com/pytorch/pytorch/issues/64327
# This allows us to de-shard without multiplying RAM usage, and also
# conveniently drops the PyTorch dependency (though we still need numpy).


@dataclass
class LazyStorageKind:
    """Stand-in for a ``torch.*Storage`` class referenced by a pickle.

    Only the element type is recorded; LazyUnpickler maps e.g.
    ``torch.HalfStorage`` to ``LazyStorageKind(DT_F16)``.
    """
    data_type: DataType


@dataclass
class LazyStorage:
    """A storage blob inside a checkpoint zip whose bytes are read on demand."""
    # load(offset, elm_count) reads elm_count elements starting at element
    # `offset` and returns them as a 1-D numpy array.
    load: Callable[[int, int], NDArray]
    kind: LazyStorageKind
    # Human-readable provenance, embedded in LazyTensor descriptions.
    description: str
  678. class LazyUnpickler(pickle.Unpickler):
  679. def __init__(self, fp: IO[bytes], data_base_path: str, zip_file: zipfile.ZipFile):
  680. super().__init__(fp)
  681. self.data_base_path = data_base_path
  682. self.zip_file = zip_file
  683. def persistent_load(self, pid: Any) -> Any:
  684. assert pid[0] == 'storage'
  685. assert isinstance(pid[1], LazyStorageKind)
  686. data_type = pid[1].data_type
  687. filename_stem = pid[2]
  688. filename = f'{self.data_base_path}/{filename_stem}'
  689. info = self.zip_file.getinfo(filename)
  690. def load(offset: int, elm_count: int) -> NDArray:
  691. dtype = data_type.dtype
  692. with self.zip_file.open(info) as fp:
  693. fp.seek(offset * dtype.itemsize)
  694. size = elm_count * dtype.itemsize
  695. data = fp.read(size)
  696. assert len(data) == size
  697. return np.frombuffer(data, dtype)
  698. description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}'
  699. return LazyStorage(load=load, kind=pid[1], description=description)
  700. @staticmethod
  701. def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any,
  702. requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor:
  703. assert isinstance(storage, LazyStorage)
  704. def load() -> UnquantizedTensor:
  705. elm_count = stride[0] * size[0]
  706. return UnquantizedTensor(storage.load(storage_offset, elm_count).reshape(size))
  707. description = f'pickled storage_offset={storage_offset} in {storage.description}'
  708. return LazyTensor(load, list(size), storage.kind.data_type, description)
  709. @staticmethod
  710. def rebuild_from_type_v2(func, new_type, args, state):
  711. return func(*args)
  712. CLASSES = {
  713. # getattr used here as a workaround for mypy not being smart enough to determine
  714. # the staticmethods have a __func__ attribute.
  715. ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
  716. ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'),
  717. ('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16),
  718. ('torch', 'HalfStorage'): LazyStorageKind(DT_F16),
  719. ('torch', 'FloatStorage'): LazyStorageKind(DT_F32),
  720. ('torch', 'IntStorage'): LazyStorageKind(DT_I32),
  721. ('torch', 'Tensor'): LazyTensor,
  722. }
  723. def find_class(self, module: str, name: str) -> Any:
  724. if not module.startswith('torch'):
  725. return super().find_class(module, name)
  726. return self.CLASSES[(module, name)]
  727. def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus:
  728. zf = zipfile.ZipFile(outer_fp)
  729. pickle_paths = [name for name in zf.namelist() if name.endswith('.pkl')]
  730. assert len(pickle_paths) == 1, pickle_paths
  731. pickle_fp = zf.open(pickle_paths[0], 'r')
  732. unpickler = LazyUnpickler(pickle_fp,
  733. data_base_path=pickle_paths[0][:-4],
  734. zip_file=zf)
  735. model = unpickler.load()
  736. if 'model' in model: model = model['model']
  737. as_dict = dict(model.items())
  738. return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None)
  739. def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
  740. header_size, = struct.unpack('<Q', fp.read(8))
  741. header: dict[str, dict[str, Any]] = json.loads(fp.read(header_size))
  742. # Use mmap for the actual data to avoid race conditions with the file offset.
  743. mapped = memoryview(mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ))
  744. byte_buf = mapped[8 + header_size:]
  745. def convert(info: dict[str, Any]) -> LazyTensor:
  746. data_type = SAFETENSORS_DATA_TYPES[info['dtype']]
  747. numpy_dtype = data_type.dtype
  748. shape: list[int] = info['shape']
  749. begin, end = info['data_offsets']
  750. assert 0 <= begin <= end <= len(byte_buf)
  751. assert end - begin == math.prod(shape) * numpy_dtype.itemsize
  752. buf = byte_buf[begin:end]
  753. def load() -> UnquantizedTensor:
  754. return UnquantizedTensor(np.frombuffer(buf, dtype=numpy_dtype).reshape(shape))
  755. description = f'safetensors begin={begin} end={end} type={data_type} path={path}'
  756. return LazyTensor(load, shape, data_type, description)
  757. model = {name: convert(info) for (name, info) in header.items() if name != '__metadata__'}
  758. return ModelPlus(model=model, paths=[path], format='safetensors', vocab=None)
  759. def must_read(fp: IO[bytes], length: int) -> bytes:
  760. ret = fp.read(length)
  761. if len(ret) < length:
  762. raise EOFError("unexpectedly reached end of file")
  763. return ret
  764. @functools.lru_cache(maxsize=None)
  765. def lazy_load_file(path: Path) -> ModelPlus:
  766. fp = open(path, 'rb')
  767. first8 = fp.read(8)
  768. fp.seek(0)
  769. if first8[:2] == b'PK':
  770. # A zip file, i.e. PyTorch format
  771. return lazy_load_torch_file(fp, path)
  772. elif struct.unpack('<Q', first8)[0] < 16 * 1024 * 1024:
  773. # Probably safetensors
  774. return lazy_load_safetensors_file(fp, path)
  775. else:
  776. raise ValueError(f"unknown format: {path}")
  777. In = TypeVar('In')
  778. Out = TypeVar('Out')
  779. def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: int | None = None, use_processpool_executor: bool = False) -> Iterable[Out]:
  780. '''Parallel map, but with backpressure. If the caller doesn't call `next`
  781. fast enough, this will stop calling `func` at some point rather than
  782. letting results pile up in memory. Specifically, there is a max of one
  783. output value buffered per thread.'''
  784. if concurrency < 2:
  785. yield from map(func, iterable)
  786. # Not reached.
  787. iterable = iter(iterable)
  788. executor_class: type[ThreadPoolExecutor] | type[ProcessPoolExecutor]
  789. if use_processpool_executor:
  790. executor_class = ProcessPoolExecutor
  791. else:
  792. executor_class = ThreadPoolExecutor
  793. with executor_class(max_workers=max_workers) as executor:
  794. futures: list[concurrent.futures.Future[Out]] = []
  795. done = False
  796. for _ in range(concurrency):
  797. try:
  798. futures.append(executor.submit(func, next(iterable)))
  799. except StopIteration:
  800. done = True
  801. break
  802. while futures:
  803. result = futures.pop(0).result()
  804. while not done and len(futures) < concurrency:
  805. try:
  806. futures.append(executor.submit(func, next(iterable)))
  807. except StopIteration:
  808. done = True
  809. break
  810. yield result
  811. def check_vocab_size(params: Params, vocab: BaseVocab, pad_vocab: bool = False) -> None:
  812. # Handle special case where the model's vocab size is not set
  813. if params.n_vocab == -1:
  814. raise ValueError(
  815. "The model's vocab size is set to -1 in params.json. Please update it manually."
  816. + (f" Maybe {vocab.vocab_size}?" if isinstance(vocab, Vocab) else ""),
  817. )
  818. if not isinstance(vocab, Vocab):
  819. return # model has no vocab
  820. # Check for a vocab size mismatch
  821. if params.n_vocab == vocab.vocab_size:
  822. print("Ignoring added_tokens.json since model matches vocab size without it.")
  823. return
  824. if pad_vocab and params.n_vocab > vocab.vocab_size:
  825. pad_count = params.n_vocab - vocab.vocab_size
  826. print(
  827. f"Padding vocab with {pad_count} token(s) - <dummy00001> through <dummy{pad_count:05}>"
  828. )
  829. for i in range(1, pad_count + 1):
  830. vocab.added_tokens_dict[f"<dummy{i:05}>"] = -1
  831. vocab.added_tokens_list.append(f"<dummy{i:05}>")
  832. vocab.vocab_size = params.n_vocab
  833. return
  834. msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer} has {vocab.vocab_size})."
  835. if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20:
  836. msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})."
  837. if vocab.vocab_size < params.n_vocab:
  838. msg += " Add the --pad-vocab option and try again."
  839. raise ValueError(msg)
  840. class OutputFile:
  841. def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
  842. self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
  843. def add_meta_arch(self, params: Params) -> None:
  844. name = "LLaMA"
  845. # TODO: better logic to determine model name
  846. if params.n_ctx == 4096:
  847. name = "LLaMA v2"
  848. elif params.path_model is not None:
  849. name = str(params.path_model.parent).split('/')[-1]
  850. self.gguf.add_name (name)
  851. self.gguf.add_vocab_size (params.n_vocab)
  852. self.gguf.add_context_length (params.n_ctx)
  853. self.gguf.add_embedding_length (params.n_embd)
  854. self.gguf.add_block_count (params.n_layer)
  855. self.gguf.add_feed_forward_length (params.n_ff)
  856. self.gguf.add_rope_dimension_count(params.n_embd // params.n_head)
  857. self.gguf.add_head_count (params.n_head)
  858. self.gguf.add_head_count_kv (params.n_head_kv)
  859. if params.n_experts:
  860. self.gguf.add_expert_count(params.n_experts)
  861. if params.n_experts_used:
  862. self.gguf.add_expert_used_count(params.n_experts_used)
  863. if params.f_norm_eps:
  864. self.gguf.add_layer_norm_rms_eps(params.f_norm_eps)
  865. else:
  866. raise ValueError('f_norm_eps is None')
  867. if params.f_rope_freq_base is not None:
  868. self.gguf.add_rope_freq_base(params.f_rope_freq_base)
  869. if params.rope_scaling_type:
  870. assert params.f_rope_scale is not None
  871. self.gguf.add_rope_scaling_type(params.rope_scaling_type)
  872. self.gguf.add_rope_scaling_factor(params.f_rope_scale)
  873. if params.n_orig_ctx is not None:
  874. self.gguf.add_rope_scaling_orig_ctx_len(params.n_orig_ctx)
  875. if params.rope_finetuned is not None:
  876. self.gguf.add_rope_scaling_finetuned(params.rope_finetuned)
  877. if params.ftype is not None:
  878. self.gguf.add_file_type(params.ftype)
  879. def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list[float], list[gguf.TokenType]]:
  880. tokens = []
  881. scores = []
  882. toktypes = []
  883. # NOTE: `all_tokens` returns the base vocabulary and added tokens
  884. for text, score, toktype in vocab.all_tokens():
  885. tokens.append(text)
  886. scores.append(score)
  887. toktypes.append(toktype)
  888. assert len(tokens) == vocab.vocab_size
  889. return tokens, scores, toktypes
  890. def add_meta_vocab(self, vocab: Vocab) -> None:
  891. # Ensure that tokenizer_model is added to the GGUF model
  892. self.gguf.add_tokenizer_model(vocab.tokenizer_model)
  893. # Extract model vocabulary for model conversion
  894. tokens, scores, toktypes = self.extract_vocabulary_from_model(vocab)
  895. # Add extracted token information for model conversion
  896. self.gguf.add_token_list(tokens)
  897. self.gguf.add_token_scores(scores)
  898. self.gguf.add_token_types(toktypes)
  899. def add_meta_special_vocab(self, svocab: gguf.SpecialVocab) -> None:
  900. svocab.add_to_gguf(self.gguf)
  901. def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
  902. n_elements = int(np.prod(tensor.shape))
  903. raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
  904. data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype
  905. data_nbytes = tensor.data_type.elements_to_bytes(n_elements)
  906. self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype)
  907. def write_meta(self) -> None:
  908. self.gguf.write_header_to_file()
  909. self.gguf.write_kv_data_to_file()
  910. def write_tensor_info(self) -> None:
  911. self.gguf.write_ti_data_to_file()
  912. def write_tensor_data(self, ftype: GGMLFileType, model: LazyModel, concurrency: int) -> None:
  913. ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency=concurrency)
  914. if ftype == GGMLFileType.MostlyQ8_0:
  915. ndarrays = bounded_parallel_map(
  916. OutputFile.maybe_do_quantize, ndarrays_inner, concurrency=concurrency, max_workers=concurrency,
  917. use_processpool_executor=True,
  918. )
  919. else:
  920. ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
  921. start = time.time()
  922. for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
  923. elapsed = time.time() - start
  924. size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
  925. padi = len(str(len(model)))
  926. print(
  927. f"[{i + 1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}"
  928. )
  929. self.gguf.write_tensor_data(ndarray)
  930. def close(self) -> None:
  931. self.gguf.close()
  932. @staticmethod
  933. def write_vocab_only(
  934. fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
  935. endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False,
  936. ) -> None:
  937. check_vocab_size(params, vocab, pad_vocab=pad_vocab)
  938. of = OutputFile(fname_out, endianess=endianess)
  939. # meta data
  940. of.add_meta_arch(params)
  941. of.add_meta_vocab(vocab)
  942. of.add_meta_special_vocab(svocab)
  943. of.write_meta()
  944. of.close()
  945. @staticmethod
  946. def do_item(item: tuple[str, LazyTensor]) -> tuple[DataType, NDArray]:
  947. name, lazy_tensor = item
  948. tensor = lazy_tensor.load().to_ggml()
  949. return (lazy_tensor.data_type, tensor.ndarray)
  950. @staticmethod
  951. def maybe_do_quantize(item: tuple[DataType, NDArray]) -> NDArray:
  952. dt, arr = item
  953. if not isinstance(dt, QuantizedDataType):
  954. return arr
  955. return dt.quantize(arr)
  956. @staticmethod
  957. def write_all(
  958. fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
  959. concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
  960. pad_vocab: bool = False,
  961. ) -> None:
  962. check_vocab_size(params, vocab, pad_vocab=pad_vocab)
  963. of = OutputFile(fname_out, endianess=endianess)
  964. # meta data
  965. of.add_meta_arch(params)
  966. if isinstance(vocab, Vocab):
  967. of.add_meta_vocab(vocab)
  968. of.add_meta_special_vocab(svocab)
  969. else: # NoVocab
  970. of.gguf.add_tokenizer_model(vocab.tokenizer_model)
  971. # tensor info
  972. for name, lazy_tensor in model.items():
  973. of.add_tensor_info(name, lazy_tensor)
  974. of.write_meta()
  975. of.write_tensor_info()
  976. # tensor data
  977. of.write_tensor_data(ftype, model, concurrency)
  978. of.close()
  979. def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
  980. wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"].data_type
  981. if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)):
  982. return GGMLFileType.AllF32
  983. if output_type_str == "f16" or (output_type_str is None and wq_type == DT_F16):
  984. return GGMLFileType.MostlyF16
  985. if output_type_str == "q8_0":
  986. return GGMLFileType.MostlyQ8_0
  987. name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
  988. raise ValueError(f"Unexpected combination of types: {name_to_type}")
  989. def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
  990. return {name: tensor.astype(output_type.type_for_tensor(name, tensor))
  991. for (name, tensor) in model.items()}
  992. def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) -> LazyModel:
  993. tmap = gguf.TensorNameMap(ARCH, params.n_layer)
  994. should_skip = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))
  995. tmp = model
  996. # merge experts into one tensor
  997. if params.n_experts and params.n_experts > 0:
  998. for i_l in range(params.n_layer):
  999. for w in range(1, 4):
  1000. experts = []
  1001. for e in range(params.n_experts):
  1002. if f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight" in model:
  1003. experts.append(model[f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight"])
  1004. del tmp[f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight"]
  1005. elif f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight" in model:
  1006. experts.append(model[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"])
  1007. del tmp[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"]
  1008. else:
  1009. raise ValueError(f"Expert tensor not found: layers.{i_l}.feed_forward.experts.{e}.w{w}.weight")
  1010. tmp[f"layers.{i_l}.feed_forward.experts.w{w}.weight"] = pack_experts_lazy(experts)
  1011. # HF models permut or pack some of the tensors, so we need to undo that
  1012. for i in itertools.count():
  1013. if f"model.layers.{i}.self_attn.q_proj.weight" in model:
  1014. print(f"Permuting layer {i}")
  1015. tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head, params.n_head)
  1016. tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_head_kv)
  1017. # tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
  1018. elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
  1019. print(f"Unpacking and permuting layer {i}")
  1020. tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head, params.n_head)
  1021. tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head, params.n_head_kv)
  1022. tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = part_lazy (model[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
  1023. del tmp[f"model.layers.{i}.self_attn.W_pack.weight"]
  1024. else:
  1025. break
  1026. out: LazyModel = {}
  1027. for name, lazy_tensor in model.items():
  1028. tensor_type, name_new = tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias")) or (None, None)
  1029. if name_new is None:
  1030. if skip_unknown:
  1031. print(f"Unexpected tensor name: {name} - skipping")
  1032. continue
  1033. raise ValueError(f"Unexpected tensor name: {name}. Use --skip-unknown to ignore it (e.g. LLaVA)")
  1034. if tensor_type in should_skip:
  1035. print(f"skipping tensor {name_new}")
  1036. continue
  1037. print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
  1038. out[name_new] = lazy_tensor
  1039. return out
  1040. def nth_multifile_path(path: Path, n: int) -> Path | None:
  1041. '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return
  1042. the nth path in the model.
  1043. '''
  1044. # Support the following patterns:
  1045. patterns = [
  1046. # - x.00.pth, x.01.pth, etc.
  1047. (r'\.[0-9]{2}\.pth$', f'.{n:02}.pth'),
  1048. # - x-00001-of-00002.bin, x-00002-of-00002.bin, etc.
  1049. (r'-[0-9]{5}-of-(.*)$', fr'-{n:05}-of-\1'),
  1050. # x.bin, x.bin.1, etc.
  1051. (r'(\.[0-9]+)?$', r'\1' if n == 0 else fr'\1.{n}')
  1052. ]
  1053. for regex, replacement in patterns:
  1054. if re.search(regex, path.name):
  1055. new_path = path.with_name(re.sub(regex, replacement, path.name))
  1056. if new_path.exists():
  1057. return new_path
  1058. return None
  1059. def find_multifile_paths(path: Path) -> list[Path]:
  1060. '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return
  1061. the whole list of paths in the model.
  1062. '''
  1063. ret: list[Path] = []
  1064. for i in itertools.count():
  1065. nth_path = nth_multifile_path(path, i)
  1066. if nth_path is None:
  1067. break
  1068. ret.append(nth_path)
  1069. if not ret:
  1070. # No matches. This should only happen if the file was named, e.g.,
  1071. # foo.0, and there was no file named foo. Oh well, try to process it
  1072. # as a single file.
  1073. return [path]
  1074. return ret
  1075. def load_some_model(path: Path) -> ModelPlus:
  1076. '''Load a model of any supported format.'''
  1077. # Be extra-friendly and accept either a file or a directory:
  1078. if path.is_dir():
  1079. # Check if it's a set of safetensors files first
  1080. globs = ["model-00001-of-*.safetensors", "model.safetensors", "consolidated.safetensors"]
  1081. files = [file for glob in globs for file in path.glob(glob)]
  1082. if not files:
  1083. # Try the PyTorch patterns too, with lower priority
  1084. globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"]
  1085. files = [file for glob in globs for file in path.glob(glob)]
  1086. if not files:
  1087. raise FileNotFoundError(f"Can't find model in directory {path}")
  1088. if len(files) > 1:
  1089. raise ValueError(f"Found multiple models in {path}, not sure which to pick: {files}")
  1090. path = files[0]
  1091. paths = find_multifile_paths(path)
  1092. models_plus: list[ModelPlus] = []
  1093. for path in paths:
  1094. print(f"Loading model file {path}")
  1095. models_plus.append(lazy_load_file(path))
  1096. model_plus = merge_multifile_models(models_plus)
  1097. return model_plus
  1098. class VocabFactory:
  1099. _VOCAB_CLASSES: list[type[Vocab]] = [SentencePieceVocab, BpeVocab, LlamaHfVocab]
  1100. def __init__(self, path: Path):
  1101. self.path = path
  1102. def _create_special_vocab(self, vocab: BaseVocab, model_parent_path: Path) -> gguf.SpecialVocab:
  1103. load_merges = vocab.name == "bpe"
  1104. n_vocab = vocab.vocab_size if isinstance(vocab, Vocab) else None
  1105. return gguf.SpecialVocab(
  1106. model_parent_path,
  1107. load_merges=load_merges,
  1108. special_token_types=None, # Predetermined or passed as a parameter
  1109. n_vocab=n_vocab,
  1110. )
  1111. def _create_vocab_by_path(self, vocab_types: list[str]) -> Vocab:
  1112. vocab_classes: dict[str, type[Vocab]] = {cls.name: cls for cls in self._VOCAB_CLASSES}
  1113. selected_vocabs: dict[str, type[Vocab]] = {}
  1114. for vtype in vocab_types:
  1115. try:
  1116. selected_vocabs[vtype] = vocab_classes[vtype]
  1117. except KeyError:
  1118. raise ValueError(f"Unsupported vocabulary type {vtype}") from None
  1119. for vtype, cls in selected_vocabs.items():
  1120. try:
  1121. vocab = cls(self.path)
  1122. break
  1123. except FileNotFoundError:
  1124. pass # ignore unavailable tokenizers
  1125. else:
  1126. raise FileNotFoundError(f"Could not find a tokenizer matching any of {vocab_types}")
  1127. print(f"Loaded vocab file {vocab.fname_tokenizer!r}, type {vocab.name!r}")
  1128. return vocab
  1129. def load_vocab(self, vocab_types: list[str] | None, model_parent_path: Path) -> tuple[BaseVocab, gguf.SpecialVocab]:
  1130. vocab: BaseVocab
  1131. if vocab_types is None:
  1132. vocab = NoVocab()
  1133. else:
  1134. vocab = self._create_vocab_by_path(vocab_types)
  1135. # FIXME: Respect --vocab-dir?
  1136. special_vocab = self._create_special_vocab(
  1137. vocab,
  1138. model_parent_path,
  1139. )
  1140. return vocab, special_vocab
  1141. def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path:
  1142. namestr = {
  1143. GGMLFileType.AllF32: "f32",
  1144. GGMLFileType.MostlyF16: "f16",
  1145. GGMLFileType.MostlyQ8_0:"q8_0",
  1146. }[file_type]
  1147. ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf"
  1148. if ret in model_paths:
  1149. sys.stderr.write(
  1150. f"Error: Default output path ({ret}) would overwrite the input. "
  1151. "Please explicitly specify a path using --outfile.\n")
  1152. sys.exit(1)
  1153. return ret
  1154. def do_dump_model(model_plus: ModelPlus) -> None:
  1155. print(f"model_plus.paths = {model_plus.paths!r}")
  1156. print(f"model_plus.format = {model_plus.format!r}")
  1157. print(f"model_plus.vocab = {model_plus.vocab!r}")
  1158. for name, lazy_tensor in model_plus.model.items():
  1159. print(f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}")
  1160. def main(args_in: list[str] | None = None) -> None:
  1161. output_choices = ["f32", "f16"]
  1162. if np.uint32(1) == np.uint32(1).newbyteorder("<"):
  1163. # We currently only support Q8_0 output on little endian systems.
  1164. output_choices.append("q8_0")
  1165. parser = argparse.ArgumentParser(description="Convert a LLaMA model to a GGML compatible file")
  1166. parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
  1167. parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
  1168. parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
  1169. parser.add_argument("--no-vocab", action="store_true", help="store model without the vocab")
  1170. parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
  1171. parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
  1172. parser.add_argument("--vocab-type", help="vocab types to try in order, choose from 'spm', 'bpe', 'hfft' (default: spm,hfft)", default="spm,hfft")
  1173. parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
  1174. parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
  1175. parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
  1176. parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default=DEFAULT_CONCURRENCY)
  1177. parser.add_argument("--big-endian", action="store_true", help="model is executed on big endian machine")
  1178. parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
  1179. parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
  1180. args = parser.parse_args(args_in)
  1181. if args.no_vocab and args.vocab_only:
  1182. raise ValueError("--vocab-only does not make sense with --no-vocab")
  1183. if args.dump_single:
  1184. model_plus = lazy_load_file(args.model)
  1185. do_dump_model(model_plus)
  1186. return
  1187. if not args.vocab_only:
  1188. model_plus = load_some_model(args.model)
  1189. else:
  1190. model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
  1191. if args.dump:
  1192. do_dump_model(model_plus)
  1193. return
  1194. endianess = gguf.GGUFEndian.LITTLE
  1195. if args.big_endian:
  1196. endianess = gguf.GGUFEndian.BIG
  1197. params = Params.load(model_plus)
  1198. if params.n_ctx == -1:
  1199. if args.ctx is None:
  1200. msg = """\
  1201. The model doesn't have a context size, and you didn't specify one with --ctx
  1202. Please specify one with --ctx:
  1203. - LLaMA v1: --ctx 2048
  1204. - LLaMA v2: --ctx 4096"""
  1205. parser.error(textwrap.dedent(msg))
  1206. params.n_ctx = args.ctx
  1207. if args.outtype:
  1208. params.ftype = {
  1209. "f32": GGMLFileType.AllF32,
  1210. "f16": GGMLFileType.MostlyF16,
  1211. "q8_0": GGMLFileType.MostlyQ8_0,
  1212. }[args.outtype]
  1213. print(f"params = {params}")
  1214. model_parent_path = model_plus.paths[0].parent
  1215. vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
  1216. vocab_factory = VocabFactory(vocab_path)
  1217. vocab_types = None if args.no_vocab else args.vocab_type.split(",")
  1218. vocab, special_vocab = vocab_factory.load_vocab(vocab_types, model_parent_path)
  1219. if args.vocab_only:
  1220. assert isinstance(vocab, Vocab)
  1221. if not args.outfile:
  1222. raise ValueError("need --outfile if using --vocab-only")
  1223. outfile = args.outfile
  1224. OutputFile.write_vocab_only(outfile, params, vocab, special_vocab,
  1225. endianess=endianess, pad_vocab=args.pad_vocab)
  1226. print(f"Wrote {outfile}")
  1227. return
  1228. if model_plus.vocab is not None and args.vocab_dir is None and not args.no_vocab:
  1229. vocab = model_plus.vocab
  1230. print(f"Vocab info: {vocab}")
  1231. print(f"Special vocab info: {special_vocab}")
  1232. model = model_plus.model
  1233. model = convert_model_names(model, params, args.skip_unknown)
  1234. ftype = pick_output_type(model, args.outtype)
  1235. model = convert_to_output_type(model, ftype)
  1236. outfile = args.outfile or default_outfile(model_plus.paths, ftype)
  1237. params.ftype = ftype
  1238. print(f"Writing {outfile}, format {ftype}")
  1239. OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
  1240. concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab)
  1241. print(f"Wrote {outfile}")
  1242. if __name__ == '__main__':
  1243. main()