1
0

convert_lora_to_gguf.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. from __future__ import annotations
  4. from dataclasses import dataclass
  5. import logging
  6. import argparse
  7. import os
  8. import sys
  9. import json
  10. from math import prod
  11. from pathlib import Path
  12. from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
  13. from transformers import AutoConfig, AutoTokenizer
  14. import torch
  15. if TYPE_CHECKING:
  16. from torch import Tensor
  17. if 'NO_LOCAL_GGUF' not in os.environ:
  18. sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
  19. import gguf
  20. # reuse model definitions from convert_hf_to_gguf.py
  21. from convert_hf_to_gguf import LazyTorchTensor, ModelBase
  22. from gguf.constants import GGUFValueType
  23. logger = logging.getLogger("lora-to-gguf")
  24. @dataclass
  25. class PartialLoraTensor:
  26. A: Tensor | None = None
  27. B: Tensor | None = None
  28. # magic to support tensor shape modifications and splitting
  29. class LoraTorchTensor:
  30. _lora_A: Tensor # (n_rank, row_size)
  31. _lora_B: Tensor # (col_size, n_rank)
  32. _rank: int
  33. def __init__(self, A: Tensor, B: Tensor):
  34. assert len(A.shape) == len(B.shape)
  35. assert A.shape[-2] == B.shape[-1]
  36. if A.dtype != B.dtype:
  37. A = A.to(torch.float32)
  38. B = B.to(torch.float32)
  39. self._lora_A = A
  40. self._lora_B = B
  41. self._rank = B.shape[-1]
  42. def get_lora_A_B(self) -> tuple[Tensor, Tensor]:
  43. return (self._lora_A, self._lora_B)
  44. def __getitem__(
  45. self,
  46. indices: (
  47. SupportsIndex
  48. | slice
  49. | tuple[SupportsIndex | slice | Tensor, ...] # TODO: add ellipsis in the type signature
  50. ),
  51. ) -> LoraTorchTensor:
  52. shape = self.shape
  53. if isinstance(indices, SupportsIndex):
  54. if len(shape) > 2:
  55. return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
  56. else:
  57. raise NotImplementedError # can't return a vector
  58. elif isinstance(indices, slice):
  59. if len(shape) > 2:
  60. return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
  61. else:
  62. return LoraTorchTensor(self._lora_A, self._lora_B[indices])
  63. elif isinstance(indices, tuple):
  64. assert len(indices) > 0
  65. if indices[-1] is Ellipsis:
  66. return self[indices[:-1]]
  67. # expand ellipsis
  68. indices = tuple(
  69. u
  70. for v in (
  71. (
  72. (slice(None, None) for _ in range(len(indices) - 1))
  73. if i is Ellipsis
  74. else (i,)
  75. )
  76. for i in indices
  77. )
  78. for u in v
  79. )
  80. if len(indices) < len(shape):
  81. indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape))))
  82. # TODO: make sure this is correct
  83. indices_A = (
  84. *(
  85. (
  86. j.__index__() % self._lora_A.shape[i]
  87. if isinstance(j, SupportsIndex)
  88. else slice(None, None)
  89. )
  90. for i, j in enumerate(indices[:-2])
  91. ),
  92. slice(None, None),
  93. indices[-1],
  94. )
  95. indices_B = indices[:-1]
  96. return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B])
  97. else:
  98. raise NotImplementedError # unknown indice type
  99. @property
  100. def dtype(self) -> torch.dtype:
  101. assert self._lora_A.dtype == self._lora_B.dtype
  102. return self._lora_A.dtype
  103. @property
  104. def shape(self) -> tuple[int, ...]:
  105. assert len(self._lora_A.shape) == len(self._lora_B.shape)
  106. return (*self._lora_B.shape[:-1], self._lora_A.shape[-1])
  107. def size(self, dim=None):
  108. assert dim is None
  109. return self.shape
  110. def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor:
  111. if isinstance(shape[0], tuple):
  112. new_shape: tuple[int, ...] = shape[0]
  113. else:
  114. new_shape = cast(tuple[int, ...], shape)
  115. orig_shape = self.shape
  116. if len(new_shape) < 2:
  117. raise NotImplementedError # can't become a vector
  118. # expand -1 in the shape
  119. if any(dim == -1 for dim in new_shape):
  120. n_elems = prod(orig_shape)
  121. n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape)
  122. assert n_elems % n_new_elems == 0
  123. new_shape = (*(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),)
  124. if new_shape[-1] != orig_shape[-1]:
  125. raise NotImplementedError # can't reshape the row size trivially
  126. shape_A = (*(1 for _ in new_shape[:-2]), self._rank, orig_shape[-1])
  127. shape_B = (*new_shape[:-1], self._rank)
  128. return LoraTorchTensor(
  129. self._lora_A.reshape(shape_A),
  130. self._lora_B.reshape(shape_B),
  131. )
  132. def reshape_as(self, other: Tensor) -> LoraTorchTensor:
  133. return self.reshape(*other.shape)
  134. def view(self, *size: int) -> LoraTorchTensor:
  135. return self.reshape(*size)
  136. def permute(self, *dims: int) -> LoraTorchTensor:
  137. shape = self.shape
  138. dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims)
  139. if dims[-1] == -1:
  140. # TODO: support higher dimensional A shapes bigger than 1
  141. assert all(dim == 1 for dim in self._lora_A.shape[:-2])
  142. return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims))
  143. if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1:
  144. return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
  145. else:
  146. # TODO: compose the above two
  147. raise NotImplementedError
  148. def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor:
  149. shape = self.shape
  150. dims = [i for i in range(len(shape))]
  151. dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
  152. return self.permute(*dims)
  153. def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor:
  154. return self.transpose(axis0, axis1)
  155. def to(self, *args, **kwargs):
  156. return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))
  157. @classmethod
  158. def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
  159. del types # unused
  160. if kwargs is None:
  161. kwargs = {}
  162. if func is torch.permute:
  163. return type(args[0]).permute(*args, **kwargs)
  164. elif func is torch.reshape:
  165. return type(args[0]).reshape(*args, **kwargs)
  166. elif func is torch.stack:
  167. assert isinstance(args[0], Sequence)
  168. dim = kwargs.get("dim", 0)
  169. assert dim == 0
  170. return LoraTorchTensor(
  171. torch.stack([a._lora_A for a in args[0]], dim),
  172. torch.stack([b._lora_B for b in args[0]], dim),
  173. )
  174. elif func is torch.cat:
  175. assert isinstance(args[0], Sequence)
  176. dim = kwargs.get("dim", 0)
  177. assert dim == 0
  178. if len(args[0][0].shape) > 2:
  179. return LoraTorchTensor(
  180. torch.cat([a._lora_A for a in args[0]], dim),
  181. torch.cat([b._lora_B for b in args[0]], dim),
  182. )
  183. elif all(torch.equal(args[0][0]._lora_A, t._lora_A) for t in args[0][1:]):
  184. return LoraTorchTensor(
  185. args[0][0]._lora_A,
  186. torch.cat([b._lora_B for b in args[0]], dim),
  187. )
  188. else:
  189. raise NotImplementedError
  190. else:
  191. raise NotImplementedError
  192. def get_base_tensor_name(lora_tensor_name: str) -> str:
  193. base_name = lora_tensor_name.replace("base_model.model.", "")
  194. base_name = base_name.replace(".lora_A.weight", ".weight")
  195. base_name = base_name.replace(".lora_B.weight", ".weight")
  196. # models produced by mergekit-extract-lora have token embeddings in the adapter
  197. base_name = base_name.replace(".lora_embedding_A", ".weight")
  198. base_name = base_name.replace(".lora_embedding_B", ".weight")
  199. return base_name
  200. def parse_args() -> argparse.Namespace:
  201. parser = argparse.ArgumentParser(
  202. description="Convert a Hugging Face PEFT LoRA adapter to a GGUF file")
  203. parser.add_argument(
  204. "--outfile", type=Path,
  205. help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
  206. )
  207. parser.add_argument(
  208. "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
  209. help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
  210. )
  211. parser.add_argument(
  212. "--bigendian", action="store_true",
  213. help="model is executed on big endian machine",
  214. )
  215. parser.add_argument(
  216. "--no-lazy", action="store_true",
  217. help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
  218. )
  219. parser.add_argument(
  220. "--verbose", action="store_true",
  221. help="increase output verbosity",
  222. )
  223. parser.add_argument(
  224. "--dry-run", action="store_true",
  225. help="only print out what will be done, without writing any new files",
  226. )
  227. parser.add_argument(
  228. "--base", type=Path,
  229. help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required. If base model is unspecified, it will be loaded from Hugging Face hub based on the adapter config",
  230. )
  231. parser.add_argument(
  232. "--base-model-id", type=str,
  233. help="the model ID of the base model, if it is not available locally or in the adapter config. If specified, it will ignore --base and load the base model config from the Hugging Face hub (Example: 'meta-llama/Llama-3.2-1B-Instruct')",
  234. )
  235. parser.add_argument(
  236. "lora_path", type=Path,
  237. help="directory containing Hugging Face PEFT LoRA config (adapter_model.json) and weights (adapter_model.safetensors or adapter_model.bin)",
  238. )
  239. return parser.parse_args()
  240. def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
  241. # normally, adapter does not come with base model config, we need to load it from AutoConfig
  242. config = AutoConfig.from_pretrained(hf_model_id)
  243. return config.to_dict()
  244. if __name__ == '__main__':
  245. args = parse_args()
  246. logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
  247. ftype_map: dict[str, gguf.LlamaFileType] = {
  248. "f32": gguf.LlamaFileType.ALL_F32,
  249. "f16": gguf.LlamaFileType.MOSTLY_F16,
  250. "bf16": gguf.LlamaFileType.MOSTLY_BF16,
  251. "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
  252. "auto": gguf.LlamaFileType.GUESSED,
  253. }
  254. ftype = ftype_map[args.outtype]
  255. dir_base_model: Path | None = args.base
  256. dir_lora: Path = args.lora_path
  257. base_model_id: str | None = args.base_model_id
  258. lora_config = dir_lora / "adapter_config.json"
  259. input_model = dir_lora / "adapter_model.safetensors"
  260. if args.outfile is not None:
  261. fname_out = args.outfile
  262. else:
  263. # output in the same directory as the model by default
  264. fname_out = dir_lora
  265. if os.path.exists(input_model):
  266. # lazy import load_file only if lora is in safetensors format.
  267. from safetensors.torch import load_file
  268. lora_model = load_file(input_model, device="cpu")
  269. else:
  270. input_model = os.path.join(dir_lora, "adapter_model.bin")
  271. lora_model = torch.load(input_model, map_location="cpu", weights_only=True)
  272. # load LoRA config
  273. with open(lora_config, "r") as f:
  274. lparams: dict[str, Any] = json.load(f)
  275. # load base model
  276. if base_model_id is not None:
  277. logger.info(f"Loading base model from Hugging Face: {base_model_id}")
  278. hparams = load_hparams_from_hf(base_model_id)
  279. elif dir_base_model is None:
  280. if "base_model_name_or_path" in lparams:
  281. model_id = lparams["base_model_name_or_path"]
  282. logger.info(f"Loading base model from Hugging Face: {model_id}")
  283. try:
  284. hparams = load_hparams_from_hf(model_id)
  285. except OSError as e:
  286. logger.error(f"Failed to load base model config: {e}")
  287. logger.error("Please try downloading the base model and add its path to --base")
  288. sys.exit(1)
  289. else:
  290. logger.error("'base_model_name_or_path' is not found in adapter_config.json")
  291. logger.error("Base model config is required. Please download the base model and add its path to --base")
  292. sys.exit(1)
  293. else:
  294. logger.info(f"Loading base model: {dir_base_model.name}")
  295. hparams = ModelBase.load_hparams(dir_base_model, False)
  296. with torch.inference_mode():
  297. try:
  298. model_class = ModelBase.from_model_architecture(hparams["architectures"][0])
  299. except NotImplementedError:
  300. logger.error(f"Model {hparams['architectures'][0]} is not supported")
  301. sys.exit(1)
  302. class LoraModel(model_class):
  303. model_arch = model_class.model_arch
  304. lora_alpha: float
  305. def __init__(self, *args, dir_lora_model: Path, lora_alpha: float, **kwargs):
  306. super().__init__(*args, **kwargs)
  307. self.dir_model_card = dir_lora_model
  308. self.lora_alpha = float(lora_alpha)
  309. def set_vocab(self):
  310. pass
  311. def set_type(self):
  312. self.gguf_writer.add_type(gguf.GGUFType.ADAPTER)
  313. self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
  314. def set_gguf_parameters(self):
  315. logger.debug("GGUF KV: %s = %d", gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
  316. self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
  317. alora_invocation_tokens = lparams.get("alora_invocation_tokens")
  318. invocation_string = lparams.get("invocation_string")
  319. if invocation_string and not alora_invocation_tokens:
  320. logger.debug("Tokenizing invocation_string -> alora_invocation_tokens")
  321. base_model_path_or_id = hparams.get("_name_or_path")
  322. try:
  323. tokenizer = AutoTokenizer.from_pretrained(base_model_path_or_id)
  324. except ValueError:
  325. logger.error("Unable to load tokenizer from %s", base_model_path_or_id)
  326. raise
  327. # NOTE: There's an off-by-one with the older aLoRAs where
  328. # the invocation string includes the "<|start_of_turn|>"
  329. # token, but the adapters themselves were trained to
  330. # activate _after_ that first token, so we drop it here.
  331. alora_invocation_tokens = tokenizer(invocation_string)["input_ids"][1:]
  332. if alora_invocation_tokens:
  333. logger.debug("GGUF KV: %s = %s", gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS, alora_invocation_tokens)
  334. self.gguf_writer.add_key_value(
  335. gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS,
  336. alora_invocation_tokens,
  337. GGUFValueType.ARRAY,
  338. GGUFValueType.UINT32,
  339. )
  340. def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
  341. # Never add extra tensors (e.g. rope_freqs) for LoRA adapters
  342. return ()
  343. def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
  344. tensor_map: dict[str, PartialLoraTensor] = {}
  345. for name, tensor in lora_model.items():
  346. if self.lazy:
  347. tensor = LazyTorchTensor.from_eager(tensor)
  348. base_name = get_base_tensor_name(name)
  349. # note: mergekit-extract-lora also adds token embeddings to the adapter
  350. is_lora_a = ".lora_A.weight" in name or ".lora_embedding_A" in name
  351. is_lora_b = ".lora_B.weight" in name or ".lora_embedding_B" in name
  352. if not is_lora_a and not is_lora_b:
  353. if ".base_layer.weight" in name:
  354. continue
  355. # mergekit-extract-lora add these layernorm to the adapter, we need to keep them
  356. if "_layernorm" in name or ".norm" in name:
  357. yield (base_name, tensor)
  358. continue
  359. logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
  360. if ".embed_tokens.weight" in name or ".lm_head.weight" in name:
  361. logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning")
  362. logger.error("Please refer to https://github.com/ggml-org/llama.cpp/pull/9948")
  363. sys.exit(1)
  364. if base_name in tensor_map:
  365. if is_lora_a:
  366. tensor_map[base_name].A = tensor
  367. else:
  368. tensor_map[base_name].B = tensor
  369. else:
  370. if is_lora_a:
  371. tensor_map[base_name] = PartialLoraTensor(A=tensor)
  372. else:
  373. tensor_map[base_name] = PartialLoraTensor(B=tensor)
  374. for name, tensor in tensor_map.items():
  375. assert tensor.A is not None
  376. assert tensor.B is not None
  377. yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B)))
  378. def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
  379. dest = list(super().modify_tensors(data_torch, name, bid))
  380. # some archs may have the same tensor for lm_head and output (tie word embeddings)
  381. # in this case, adapters targeting lm_head will fail when using llama-export-lora
  382. # therefore, we ignore them for now
  383. # see: https://github.com/ggml-org/llama.cpp/issues/9065
  384. if name == "lm_head.weight" and len(dest) == 0:
  385. raise ValueError("lm_head is present in adapter, but is ignored in base model")
  386. for dest_name, dest_data in dest:
  387. # mergekit-extract-lora add these layernorm to the adapter
  388. if "_norm" in dest_name:
  389. assert dest_data.dim() == 1
  390. yield (dest_name, dest_data)
  391. continue
  392. # otherwise, we must get the lora_A and lora_B tensors
  393. assert isinstance(dest_data, LoraTorchTensor)
  394. lora_a, lora_b = dest_data.get_lora_A_B()
  395. # note: mergekit-extract-lora flip and transpose A and B
  396. # here we only need to transpose token_embd.lora_a, see llm_build_inp_embd()
  397. if "token_embd.weight" in dest_name:
  398. lora_a = lora_a.T
  399. yield (dest_name + ".lora_a", lora_a)
  400. yield (dest_name + ".lora_b", lora_b)
  401. alpha: float = lparams["lora_alpha"]
  402. model_instance = LoraModel(
  403. dir_base_model,
  404. ftype,
  405. fname_out,
  406. is_big_endian=args.bigendian,
  407. use_temp_file=False,
  408. eager=args.no_lazy,
  409. dry_run=args.dry_run,
  410. dir_lora_model=dir_lora,
  411. lora_alpha=alpha,
  412. hparams=hparams,
  413. )
  414. logger.info("Exporting model...")
  415. model_instance.write()
  416. logger.info(f"Model successfully exported to {model_instance.fname_out}")