convert_lora_to_gguf.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Convert a huggingface PEFT LoRA adapter (adapter_config.json plus
# adapter_model.safetensors or adapter_model.bin) into a GGUF adapter file,
# reusing the per-architecture tensor handling from convert_hf_to_gguf.py.

from __future__ import annotations

from dataclasses import dataclass
import logging
import argparse
import os
import sys
import json
from math import prod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast

import torch

if TYPE_CHECKING:
    # Only needed for annotations; with `from __future__ import annotations`
    # they are kept as strings and not evaluated at runtime.
    from torch import Tensor

if 'NO_LOCAL_GGUF' not in os.environ:
    # Put the gguf-py directory next to this script near the front of sys.path
    # (index 1) so it is found before later entries; set NO_LOCAL_GGUF to skip.
    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf

# reuse model definitions from convert_hf_to_gguf.py
from convert_hf_to_gguf import LazyTorchTensor, Model

logger = logging.getLogger("lora-to-gguf")
  22. @dataclass
  23. class PartialLoraTensor:
  24. A: Tensor | None = None
  25. B: Tensor | None = None
  26. # magic to support tensor shape modifications and splitting
  27. class LoraTorchTensor:
  28. _lora_A: Tensor # (n_rank, row_size)
  29. _lora_B: Tensor # (col_size, n_rank)
  30. _rank: int
  31. def __init__(self, A: Tensor, B: Tensor):
  32. assert len(A.shape) == len(B.shape)
  33. assert A.shape[-2] == B.shape[-1]
  34. if A.dtype != B.dtype:
  35. A = A.to(torch.float32)
  36. B = B.to(torch.float32)
  37. self._lora_A = A
  38. self._lora_B = B
  39. self._rank = B.shape[-1]
  40. def get_lora_A_B(self) -> tuple[Tensor, Tensor]:
  41. return (self._lora_A, self._lora_B)
  42. def __getitem__(
  43. self,
  44. indices: (
  45. SupportsIndex
  46. | slice
  47. | tuple[SupportsIndex | slice | Tensor, ...] # TODO: add ellipsis in the type signature
  48. ),
  49. ) -> LoraTorchTensor:
  50. shape = self.shape
  51. if isinstance(indices, SupportsIndex):
  52. if len(shape) > 2:
  53. return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
  54. else:
  55. raise NotImplementedError # can't return a vector
  56. elif isinstance(indices, slice):
  57. if len(shape) > 2:
  58. return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
  59. else:
  60. return LoraTorchTensor(self._lora_A, self._lora_B[indices])
  61. elif isinstance(indices, tuple):
  62. assert len(indices) > 0
  63. if indices[-1] is Ellipsis:
  64. return self[indices[:-1]]
  65. # expand ellipsis
  66. indices = tuple(
  67. u
  68. for v in (
  69. (
  70. (slice(None, None) for _ in range(len(indices) - 1))
  71. if i is Ellipsis
  72. else (i,)
  73. )
  74. for i in indices
  75. )
  76. for u in v
  77. )
  78. if len(indices) < len(shape):
  79. indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape))))
  80. # TODO: make sure this is correct
  81. indices_A = (
  82. *(
  83. (
  84. j.__index__() % self._lora_A.shape[i]
  85. if isinstance(j, SupportsIndex)
  86. else slice(None, None)
  87. )
  88. for i, j in enumerate(indices[:-2])
  89. ),
  90. slice(None, None),
  91. indices[-1],
  92. )
  93. indices_B = indices[:-1]
  94. return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B])
  95. else:
  96. raise NotImplementedError # unknown indice type
  97. @property
  98. def dtype(self) -> torch.dtype:
  99. assert self._lora_A.dtype == self._lora_B.dtype
  100. return self._lora_A.dtype
  101. @property
  102. def shape(self) -> tuple[int, ...]:
  103. assert len(self._lora_A.shape) == len(self._lora_B.shape)
  104. return (*self._lora_B.shape[:-1], self._lora_A.shape[-1])
  105. def size(self, dim=None):
  106. assert dim is None
  107. return self.shape
  108. def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor:
  109. if isinstance(shape[0], tuple):
  110. new_shape: tuple[int, ...] = shape[0]
  111. else:
  112. new_shape = cast(tuple[int, ...], shape)
  113. orig_shape = self.shape
  114. if len(new_shape) < 2:
  115. raise NotImplementedError # can't become a vector
  116. # expand -1 in the shape
  117. if any(dim == -1 for dim in new_shape):
  118. n_elems = prod(orig_shape)
  119. n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape)
  120. assert n_elems % n_new_elems == 0
  121. new_shape = (*(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),)
  122. if new_shape[-1] != orig_shape[-1]:
  123. raise NotImplementedError # can't reshape the row size trivially
  124. shape_A = (*(1 for _ in new_shape[:-2]), self._rank, orig_shape[-1])
  125. shape_B = (*new_shape[:-1], self._rank)
  126. return LoraTorchTensor(
  127. self._lora_A.reshape(shape_A),
  128. self._lora_B.reshape(shape_B),
  129. )
  130. def reshape_as(self, other: Tensor) -> LoraTorchTensor:
  131. return self.reshape(*other.shape)
  132. def view(self, *size: int) -> LoraTorchTensor:
  133. return self.reshape(*size)
  134. def permute(self, *dims: int) -> LoraTorchTensor:
  135. shape = self.shape
  136. dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims)
  137. if dims[-1] == -1:
  138. # TODO: support higher dimensional A shapes bigger than 1
  139. assert all(dim == 1 for dim in self._lora_A.shape[:-2])
  140. return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims))
  141. if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1:
  142. return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
  143. else:
  144. # TODO: compose the above two
  145. raise NotImplementedError
  146. def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor:
  147. shape = self.shape
  148. dims = [i for i in range(len(shape))]
  149. dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
  150. return self.permute(*dims)
  151. def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor:
  152. return self.transpose(axis0, axis1)
  153. def to(self, *args, **kwargs):
  154. return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))
  155. @classmethod
  156. def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
  157. del types # unused
  158. if kwargs is None:
  159. kwargs = {}
  160. if func is torch.permute:
  161. return type(args[0]).permute(*args, **kwargs)
  162. elif func is torch.reshape:
  163. return type(args[0]).reshape(*args, **kwargs)
  164. elif func is torch.stack:
  165. assert isinstance(args[0], Sequence)
  166. dim = kwargs.get("dim", 0)
  167. assert dim == 0
  168. return LoraTorchTensor(
  169. torch.stack([a._lora_A for a in args[0]], dim),
  170. torch.stack([b._lora_B for b in args[0]], dim),
  171. )
  172. elif func is torch.cat:
  173. assert isinstance(args[0], Sequence)
  174. dim = kwargs.get("dim", 0)
  175. assert dim == 0
  176. if len(args[0][0].shape) > 2:
  177. return LoraTorchTensor(
  178. torch.cat([a._lora_A for a in args[0]], dim),
  179. torch.cat([b._lora_B for b in args[0]], dim),
  180. )
  181. elif all(torch.equal(args[0][0]._lora_A, t._lora_A) for t in args[0][1:]):
  182. return LoraTorchTensor(
  183. args[0][0]._lora_A,
  184. torch.cat([b._lora_B for b in args[0]], dim),
  185. )
  186. else:
  187. raise NotImplementedError
  188. else:
  189. raise NotImplementedError
  190. def get_base_tensor_name(lora_tensor_name: str) -> str:
  191. base_name = lora_tensor_name.replace("base_model.model.", "")
  192. base_name = base_name.replace(".lora_A.weight", ".weight")
  193. base_name = base_name.replace(".lora_B.weight", ".weight")
  194. return base_name
  195. def parse_args() -> argparse.Namespace:
  196. parser = argparse.ArgumentParser(
  197. description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file")
  198. parser.add_argument(
  199. "--outfile", type=Path,
  200. help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
  201. )
  202. parser.add_argument(
  203. "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
  204. help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
  205. )
  206. parser.add_argument(
  207. "--bigendian", action="store_true",
  208. help="model is executed on big endian machine",
  209. )
  210. parser.add_argument(
  211. "--no-lazy", action="store_true",
  212. help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
  213. )
  214. parser.add_argument(
  215. "--verbose", action="store_true",
  216. help="increase output verbosity",
  217. )
  218. parser.add_argument(
  219. "--dry-run", action="store_true",
  220. help="only print out what will be done, without writing any new files",
  221. )
  222. parser.add_argument(
  223. "--base", type=Path, required=True,
  224. help="directory containing base model file",
  225. )
  226. parser.add_argument(
  227. "lora_path", type=Path,
  228. help="directory containing LoRA adapter file",
  229. )
  230. return parser.parse_args()
  231. if __name__ == '__main__':
  232. args = parse_args()
  233. logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
  234. ftype_map: dict[str, gguf.LlamaFileType] = {
  235. "f32": gguf.LlamaFileType.ALL_F32,
  236. "f16": gguf.LlamaFileType.MOSTLY_F16,
  237. "bf16": gguf.LlamaFileType.MOSTLY_BF16,
  238. "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
  239. "auto": gguf.LlamaFileType.GUESSED,
  240. }
  241. ftype = ftype_map[args.outtype]
  242. dir_base_model: Path = args.base
  243. dir_lora: Path = args.lora_path
  244. lora_config = dir_lora / "adapter_config.json"
  245. input_model = dir_lora / "adapter_model.safetensors"
  246. if args.outfile is not None:
  247. fname_out = args.outfile
  248. else:
  249. # output in the same directory as the model by default
  250. fname_out = dir_lora
  251. if os.path.exists(input_model):
  252. # lazy import load_file only if lora is in safetensors format.
  253. from safetensors.torch import load_file
  254. lora_model = load_file(input_model, device="cpu")
  255. else:
  256. input_model = os.path.join(dir_lora, "adapter_model.bin")
  257. lora_model = torch.load(input_model, map_location="cpu", weights_only=True)
  258. # load base model
  259. logger.info(f"Loading base model: {dir_base_model.name}")
  260. hparams = Model.load_hparams(dir_base_model)
  261. with torch.inference_mode():
  262. try:
  263. model_class = Model.from_model_architecture(hparams["architectures"][0])
  264. except NotImplementedError:
  265. logger.error(f"Model {hparams['architectures'][0]} is not supported")
  266. sys.exit(1)
  267. class LoraModel(model_class):
  268. model_arch = model_class.model_arch
  269. lora_alpha: float
  270. def __init__(self, *args, dir_lora_model: Path, lora_alpha: float, **kwargs):
  271. super().__init__(*args, **kwargs)
  272. self.dir_model_card = dir_lora_model
  273. self.lora_alpha = float(lora_alpha)
  274. def set_type(self):
  275. self.gguf_writer.add_type(gguf.GGUFType.ADAPTER)
  276. self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
  277. def set_gguf_parameters(self):
  278. self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
  279. super().set_gguf_parameters()
  280. def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
  281. tensor_map: dict[str, PartialLoraTensor] = {}
  282. for name, tensor in lora_model.items():
  283. if self.lazy:
  284. tensor = LazyTorchTensor.from_eager(tensor)
  285. base_name = get_base_tensor_name(name)
  286. is_lora_a = ".lora_A.weight" in name
  287. is_lora_b = ".lora_B.weight" in name
  288. if not is_lora_a and not is_lora_b:
  289. if ".base_layer.weight" in name:
  290. continue
  291. logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
  292. sys.exit(1)
  293. if base_name in tensor_map:
  294. if is_lora_a:
  295. tensor_map[base_name].A = tensor
  296. else:
  297. tensor_map[base_name].B = tensor
  298. else:
  299. if is_lora_a:
  300. tensor_map[base_name] = PartialLoraTensor(A=tensor)
  301. else:
  302. tensor_map[base_name] = PartialLoraTensor(B=tensor)
  303. for name, tensor in tensor_map.items():
  304. assert tensor.A is not None
  305. assert tensor.B is not None
  306. yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B)))
  307. def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
  308. dest = super().modify_tensors(data_torch, name, bid)
  309. for dest_name, dest_data in dest:
  310. assert isinstance(dest_data, LoraTorchTensor)
  311. lora_a, lora_b = dest_data.get_lora_A_B()
  312. yield (dest_name + ".lora_a", lora_a)
  313. yield (dest_name + ".lora_b", lora_b)
  314. with open(lora_config, "r") as f:
  315. lparams: dict[str, Any] = json.load(f)
  316. alpha: float = lparams["lora_alpha"]
  317. model_instance = LoraModel(
  318. dir_base_model,
  319. ftype,
  320. fname_out,
  321. is_big_endian=args.bigendian,
  322. use_temp_file=False,
  323. eager=args.no_lazy,
  324. dry_run=args.dry_run,
  325. dir_lora_model=dir_lora,
  326. lora_alpha=alpha,
  327. is_lora=True,
  328. )
  329. logger.info("Exporting model...")
  330. model_instance.write()
  331. logger.info(f"Model successfully exported to {model_instance.fname_out}")