convert_lora_to_gguf.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. from __future__ import annotations
  4. from dataclasses import dataclass
  5. import logging
  6. import argparse
  7. import os
  8. import sys
  9. import json
  10. from math import prod
  11. from pathlib import Path
  12. from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
  13. from transformers import AutoConfig
  14. import torch
  15. if TYPE_CHECKING:
  16. from torch import Tensor
  17. if 'NO_LOCAL_GGUF' not in os.environ:
  18. sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
  19. import gguf
  20. # reuse model definitions from convert_hf_to_gguf.py
  21. from convert_hf_to_gguf import LazyTorchTensor, Model
  22. logger = logging.getLogger("lora-to-gguf")
  23. @dataclass
  24. class PartialLoraTensor:
  25. A: Tensor | None = None
  26. B: Tensor | None = None
  27. # magic to support tensor shape modifications and splitting
  28. class LoraTorchTensor:
  29. _lora_A: Tensor # (n_rank, row_size)
  30. _lora_B: Tensor # (col_size, n_rank)
  31. _rank: int
  32. def __init__(self, A: Tensor, B: Tensor):
  33. assert len(A.shape) == len(B.shape)
  34. assert A.shape[-2] == B.shape[-1]
  35. if A.dtype != B.dtype:
  36. A = A.to(torch.float32)
  37. B = B.to(torch.float32)
  38. self._lora_A = A
  39. self._lora_B = B
  40. self._rank = B.shape[-1]
  41. def get_lora_A_B(self) -> tuple[Tensor, Tensor]:
  42. return (self._lora_A, self._lora_B)
  43. def __getitem__(
  44. self,
  45. indices: (
  46. SupportsIndex
  47. | slice
  48. | tuple[SupportsIndex | slice | Tensor, ...] # TODO: add ellipsis in the type signature
  49. ),
  50. ) -> LoraTorchTensor:
  51. shape = self.shape
  52. if isinstance(indices, SupportsIndex):
  53. if len(shape) > 2:
  54. return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
  55. else:
  56. raise NotImplementedError # can't return a vector
  57. elif isinstance(indices, slice):
  58. if len(shape) > 2:
  59. return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
  60. else:
  61. return LoraTorchTensor(self._lora_A, self._lora_B[indices])
  62. elif isinstance(indices, tuple):
  63. assert len(indices) > 0
  64. if indices[-1] is Ellipsis:
  65. return self[indices[:-1]]
  66. # expand ellipsis
  67. indices = tuple(
  68. u
  69. for v in (
  70. (
  71. (slice(None, None) for _ in range(len(indices) - 1))
  72. if i is Ellipsis
  73. else (i,)
  74. )
  75. for i in indices
  76. )
  77. for u in v
  78. )
  79. if len(indices) < len(shape):
  80. indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape))))
  81. # TODO: make sure this is correct
  82. indices_A = (
  83. *(
  84. (
  85. j.__index__() % self._lora_A.shape[i]
  86. if isinstance(j, SupportsIndex)
  87. else slice(None, None)
  88. )
  89. for i, j in enumerate(indices[:-2])
  90. ),
  91. slice(None, None),
  92. indices[-1],
  93. )
  94. indices_B = indices[:-1]
  95. return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B])
  96. else:
  97. raise NotImplementedError # unknown indice type
  98. @property
  99. def dtype(self) -> torch.dtype:
  100. assert self._lora_A.dtype == self._lora_B.dtype
  101. return self._lora_A.dtype
  102. @property
  103. def shape(self) -> tuple[int, ...]:
  104. assert len(self._lora_A.shape) == len(self._lora_B.shape)
  105. return (*self._lora_B.shape[:-1], self._lora_A.shape[-1])
  106. def size(self, dim=None):
  107. assert dim is None
  108. return self.shape
  109. def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor:
  110. if isinstance(shape[0], tuple):
  111. new_shape: tuple[int, ...] = shape[0]
  112. else:
  113. new_shape = cast(tuple[int, ...], shape)
  114. orig_shape = self.shape
  115. if len(new_shape) < 2:
  116. raise NotImplementedError # can't become a vector
  117. # expand -1 in the shape
  118. if any(dim == -1 for dim in new_shape):
  119. n_elems = prod(orig_shape)
  120. n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape)
  121. assert n_elems % n_new_elems == 0
  122. new_shape = (*(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),)
  123. if new_shape[-1] != orig_shape[-1]:
  124. raise NotImplementedError # can't reshape the row size trivially
  125. shape_A = (*(1 for _ in new_shape[:-2]), self._rank, orig_shape[-1])
  126. shape_B = (*new_shape[:-1], self._rank)
  127. return LoraTorchTensor(
  128. self._lora_A.reshape(shape_A),
  129. self._lora_B.reshape(shape_B),
  130. )
  131. def reshape_as(self, other: Tensor) -> LoraTorchTensor:
  132. return self.reshape(*other.shape)
  133. def view(self, *size: int) -> LoraTorchTensor:
  134. return self.reshape(*size)
  135. def permute(self, *dims: int) -> LoraTorchTensor:
  136. shape = self.shape
  137. dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims)
  138. if dims[-1] == -1:
  139. # TODO: support higher dimensional A shapes bigger than 1
  140. assert all(dim == 1 for dim in self._lora_A.shape[:-2])
  141. return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims))
  142. if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1:
  143. return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
  144. else:
  145. # TODO: compose the above two
  146. raise NotImplementedError
  147. def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor:
  148. shape = self.shape
  149. dims = [i for i in range(len(shape))]
  150. dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
  151. return self.permute(*dims)
  152. def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor:
  153. return self.transpose(axis0, axis1)
  154. def to(self, *args, **kwargs):
  155. return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))
  156. @classmethod
  157. def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
  158. del types # unused
  159. if kwargs is None:
  160. kwargs = {}
  161. if func is torch.permute:
  162. return type(args[0]).permute(*args, **kwargs)
  163. elif func is torch.reshape:
  164. return type(args[0]).reshape(*args, **kwargs)
  165. elif func is torch.stack:
  166. assert isinstance(args[0], Sequence)
  167. dim = kwargs.get("dim", 0)
  168. assert dim == 0
  169. return LoraTorchTensor(
  170. torch.stack([a._lora_A for a in args[0]], dim),
  171. torch.stack([b._lora_B for b in args[0]], dim),
  172. )
  173. elif func is torch.cat:
  174. assert isinstance(args[0], Sequence)
  175. dim = kwargs.get("dim", 0)
  176. assert dim == 0
  177. if len(args[0][0].shape) > 2:
  178. return LoraTorchTensor(
  179. torch.cat([a._lora_A for a in args[0]], dim),
  180. torch.cat([b._lora_B for b in args[0]], dim),
  181. )
  182. elif all(torch.equal(args[0][0]._lora_A, t._lora_A) for t in args[0][1:]):
  183. return LoraTorchTensor(
  184. args[0][0]._lora_A,
  185. torch.cat([b._lora_B for b in args[0]], dim),
  186. )
  187. else:
  188. raise NotImplementedError
  189. else:
  190. raise NotImplementedError
  191. def get_base_tensor_name(lora_tensor_name: str) -> str:
  192. base_name = lora_tensor_name.replace("base_model.model.", "")
  193. base_name = base_name.replace(".lora_A.weight", ".weight")
  194. base_name = base_name.replace(".lora_B.weight", ".weight")
  195. # models produced by mergekit-extract-lora have token embeddings in the adapter
  196. base_name = base_name.replace(".lora_embedding_A", ".weight")
  197. base_name = base_name.replace(".lora_embedding_B", ".weight")
  198. return base_name
  199. def parse_args() -> argparse.Namespace:
  200. parser = argparse.ArgumentParser(
  201. description="Convert a Hugging Face PEFT LoRA adapter to a GGUF file")
  202. parser.add_argument(
  203. "--outfile", type=Path,
  204. help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
  205. )
  206. parser.add_argument(
  207. "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
  208. help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
  209. )
  210. parser.add_argument(
  211. "--bigendian", action="store_true",
  212. help="model is executed on big endian machine",
  213. )
  214. parser.add_argument(
  215. "--no-lazy", action="store_true",
  216. help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
  217. )
  218. parser.add_argument(
  219. "--verbose", action="store_true",
  220. help="increase output verbosity",
  221. )
  222. parser.add_argument(
  223. "--dry-run", action="store_true",
  224. help="only print out what will be done, without writing any new files",
  225. )
  226. parser.add_argument(
  227. "--base", type=Path,
  228. help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required. If base model is unspecified, it will be loaded from Hugging Face hub based on the adapter config",
  229. )
  230. parser.add_argument(
  231. "--base-model-id", type=str,
  232. help="the model ID of the base model, if it is not available locally or in the adapter config. If specified, it will ignore --base and load the base model config from the Hugging Face hub (Example: 'meta-llama/Llama-3.2-1B-Instruct')",
  233. )
  234. parser.add_argument(
  235. "lora_path", type=Path,
  236. help="directory containing Hugging Face PEFT LoRA config (adapter_model.json) and weights (adapter_model.safetensors or adapter_model.bin)",
  237. )
  238. return parser.parse_args()
  239. def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
  240. # normally, adapter does not come with base model config, we need to load it from AutoConfig
  241. config = AutoConfig.from_pretrained(hf_model_id)
  242. return config.to_dict()
  243. if __name__ == '__main__':
  244. args = parse_args()
  245. logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
  246. ftype_map: dict[str, gguf.LlamaFileType] = {
  247. "f32": gguf.LlamaFileType.ALL_F32,
  248. "f16": gguf.LlamaFileType.MOSTLY_F16,
  249. "bf16": gguf.LlamaFileType.MOSTLY_BF16,
  250. "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
  251. "auto": gguf.LlamaFileType.GUESSED,
  252. }
  253. ftype = ftype_map[args.outtype]
  254. dir_base_model: Path | None = args.base
  255. dir_lora: Path = args.lora_path
  256. base_model_id: str | None = args.base_model_id
  257. lora_config = dir_lora / "adapter_config.json"
  258. input_model = dir_lora / "adapter_model.safetensors"
  259. if args.outfile is not None:
  260. fname_out = args.outfile
  261. else:
  262. # output in the same directory as the model by default
  263. fname_out = dir_lora
  264. if os.path.exists(input_model):
  265. # lazy import load_file only if lora is in safetensors format.
  266. from safetensors.torch import load_file
  267. lora_model = load_file(input_model, device="cpu")
  268. else:
  269. input_model = os.path.join(dir_lora, "adapter_model.bin")
  270. lora_model = torch.load(input_model, map_location="cpu", weights_only=True)
  271. # load LoRA config
  272. with open(lora_config, "r") as f:
  273. lparams: dict[str, Any] = json.load(f)
  274. # load base model
  275. if base_model_id is not None:
  276. logger.info(f"Loading base model from Hugging Face: {base_model_id}")
  277. hparams = load_hparams_from_hf(base_model_id)
  278. elif dir_base_model is None:
  279. if "base_model_name_or_path" in lparams:
  280. model_id = lparams["base_model_name_or_path"]
  281. logger.info(f"Loading base model from Hugging Face: {model_id}")
  282. try:
  283. hparams = load_hparams_from_hf(model_id)
  284. except OSError as e:
  285. logger.error(f"Failed to load base model config: {e}")
  286. logger.error("Please try downloading the base model and add its path to --base")
  287. sys.exit(1)
  288. else:
  289. logger.error("'base_model_name_or_path' is not found in adapter_config.json")
  290. logger.error("Base model config is required. Please download the base model and add its path to --base")
  291. sys.exit(1)
  292. else:
  293. logger.info(f"Loading base model: {dir_base_model.name}")
  294. hparams = Model.load_hparams(dir_base_model)
  295. with torch.inference_mode():
  296. try:
  297. model_class = Model.from_model_architecture(hparams["architectures"][0])
  298. except NotImplementedError:
  299. logger.error(f"Model {hparams['architectures'][0]} is not supported")
  300. sys.exit(1)
  301. class LoraModel(model_class):
  302. model_arch = model_class.model_arch
  303. lora_alpha: float
  304. def __init__(self, *args, dir_lora_model: Path, lora_alpha: float, **kwargs):
  305. super().__init__(*args, **kwargs)
  306. self.dir_model_card = dir_lora_model
  307. self.lora_alpha = float(lora_alpha)
  308. def set_vocab(self):
  309. pass
  310. def set_type(self):
  311. self.gguf_writer.add_type(gguf.GGUFType.ADAPTER)
  312. self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
  313. def set_gguf_parameters(self):
  314. self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
  315. def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
  316. # Never add extra tensors (e.g. rope_freqs) for LoRA adapters
  317. return ()
  318. def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
  319. tensor_map: dict[str, PartialLoraTensor] = {}
  320. for name, tensor in lora_model.items():
  321. if self.lazy:
  322. tensor = LazyTorchTensor.from_eager(tensor)
  323. base_name = get_base_tensor_name(name)
  324. # note: mergekit-extract-lora also adds token embeddings to the adapter
  325. is_lora_a = ".lora_A.weight" in name or ".lora_embedding_A" in name
  326. is_lora_b = ".lora_B.weight" in name or ".lora_embedding_B" in name
  327. if not is_lora_a and not is_lora_b:
  328. if ".base_layer.weight" in name:
  329. continue
  330. # mergekit-extract-lora add these layernorm to the adapter, we need to keep them
  331. if "_layernorm" in name or ".norm" in name:
  332. yield (base_name, tensor)
  333. continue
  334. logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
  335. if ".embed_tokens.weight" in name or ".lm_head.weight" in name:
  336. logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning")
  337. logger.error("Please refer to https://github.com/ggerganov/llama.cpp/pull/9948")
  338. sys.exit(1)
  339. if base_name in tensor_map:
  340. if is_lora_a:
  341. tensor_map[base_name].A = tensor
  342. else:
  343. tensor_map[base_name].B = tensor
  344. else:
  345. if is_lora_a:
  346. tensor_map[base_name] = PartialLoraTensor(A=tensor)
  347. else:
  348. tensor_map[base_name] = PartialLoraTensor(B=tensor)
  349. for name, tensor in tensor_map.items():
  350. assert tensor.A is not None
  351. assert tensor.B is not None
  352. yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B)))
  353. def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
  354. dest = list(super().modify_tensors(data_torch, name, bid))
  355. # some archs may have the same tensor for lm_head and output (tie word embeddings)
  356. # in this case, adapters targeting lm_head will fail when using llama-export-lora
  357. # therefore, we ignore them for now
  358. # see: https://github.com/ggerganov/llama.cpp/issues/9065
  359. if name == "lm_head.weight" and len(dest) == 0:
  360. raise ValueError("lm_head is present in adapter, but is ignored in base model")
  361. for dest_name, dest_data in dest:
  362. # mergekit-extract-lora add these layernorm to the adapter
  363. if "_norm" in dest_name:
  364. assert dest_data.dim() == 1
  365. yield (dest_name, dest_data)
  366. continue
  367. # otherwise, we must get the lora_A and lora_B tensors
  368. assert isinstance(dest_data, LoraTorchTensor)
  369. lora_a, lora_b = dest_data.get_lora_A_B()
  370. # note: mergekit-extract-lora flip and transpose A and B
  371. # here we only need to transpose token_embd.lora_a, see llm_build_inp_embd()
  372. if "token_embd.weight" in dest_name:
  373. lora_a = lora_a.T
  374. yield (dest_name + ".lora_a", lora_a)
  375. yield (dest_name + ".lora_b", lora_b)
  376. alpha: float = lparams["lora_alpha"]
  377. model_instance = LoraModel(
  378. dir_base_model,
  379. ftype,
  380. fname_out,
  381. is_big_endian=args.bigendian,
  382. use_temp_file=False,
  383. eager=args.no_lazy,
  384. dry_run=args.dry_run,
  385. dir_lora_model=dir_lora,
  386. lora_alpha=alpha,
  387. hparams=hparams,
  388. )
  389. logger.info("Exporting model...")
  390. model_instance.write()
  391. logger.info(f"Model successfully exported to {model_instance.fname_out}")