utility.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from pathlib import Path
  4. from typing import Literal
  5. import os
  6. import json
  7. import numpy as np
  8. def fill_templated_filename(filename: str, output_type: str | None) -> str:
  9. # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
  10. ftype_lowercase: str = output_type.lower() if output_type is not None else ""
  11. ftype_uppercase: str = output_type.upper() if output_type is not None else ""
  12. return filename.format(ftype_lowercase,
  13. outtype=ftype_lowercase, ftype=ftype_lowercase,
  14. OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
  15. def model_weight_count_rounded_notation(model_params_count: int, min_digits: int = 2) -> str:
  16. if model_params_count > 1e12 :
  17. # Trillions Of Parameters
  18. scaled_model_params = model_params_count * 1e-12
  19. scale_suffix = "T"
  20. elif model_params_count > 1e9 :
  21. # Billions Of Parameters
  22. scaled_model_params = model_params_count * 1e-9
  23. scale_suffix = "B"
  24. elif model_params_count > 1e6 :
  25. # Millions Of Parameters
  26. scaled_model_params = model_params_count * 1e-6
  27. scale_suffix = "M"
  28. else:
  29. # Thousands Of Parameters
  30. scaled_model_params = model_params_count * 1e-3
  31. scale_suffix = "K"
  32. fix = max(min_digits - len(str(round(scaled_model_params)).lstrip('0')), 0)
  33. return f"{scaled_model_params:.{fix}f}{scale_suffix}"
  34. def size_label(total_params: int, shared_params: int, expert_params: int, expert_count: int) -> str:
  35. if expert_count > 0:
  36. pretty_size = model_weight_count_rounded_notation(abs(shared_params) + abs(expert_params), min_digits=2)
  37. size_class = f"{expert_count}x{pretty_size}"
  38. else:
  39. size_class = model_weight_count_rounded_notation(abs(total_params), min_digits=2)
  40. return size_class
  41. def naming_convention(model_name: str | None, base_name: str | None, finetune_string: str | None, version_string: str | None, size_label: str | None, output_type: str | None, model_type: Literal['vocab', 'LoRA'] | None = None) -> str:
  42. # Reference: https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#gguf-naming-convention
  43. if base_name is not None:
  44. name = base_name.strip().replace(' ', '-').replace('/', '-')
  45. elif model_name is not None:
  46. name = model_name.strip().replace(' ', '-').replace('/', '-')
  47. else:
  48. name = "ggml-model"
  49. parameters = f"-{size_label}" if size_label is not None else ""
  50. finetune = f"-{finetune_string.strip().replace(' ', '-')}" if finetune_string is not None else ""
  51. version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else ""
  52. encoding = f"-{output_type.strip().replace(' ', '-').upper()}" if output_type is not None else ""
  53. kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
  54. return f"{name}{parameters}{finetune}{version}{encoding}{kind}"
  55. @dataclass
  56. class RemoteTensor:
  57. dtype: str
  58. shape: tuple[int, ...]
  59. offset_start: int
  60. size: int
  61. url: str
  62. def data(self) -> bytearray:
  63. # TODO: handle request errors (maybe with limited retries?)
  64. # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
  65. data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
  66. return data
  67. class SafetensorRemote:
  68. """
  69. Uility class to handle remote safetensor files.
  70. This class is designed to work with Hugging Face model repositories.
  71. Example (one model has single safetensor file, the other has multiple):
  72. for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]:
  73. tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
  74. print(tensors)
  75. Example reading tensor data:
  76. tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
  77. for name, meta in tensors.items():
  78. dtype, shape, offset_start, size, remote_safetensor_url = meta
  79. # read the tensor data
  80. data = SafetensorRemote.get_data_by_range(remote_safetensor_url, offset_start, size)
  81. print(data)
  82. """
  83. BASE_DOMAIN = "https://huggingface.co"
  84. ALIGNMENT = 8 # bytes
  85. @classmethod
  86. def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
  87. """
  88. Get list of tensors from a Hugging Face model repository.
  89. Returns a dictionary of tensor names and their metadata.
  90. Each tensor is represented as a tuple of (dtype, shape, offset_start, size, remote_safetensor_url)
  91. """
  92. # case 1: model has only one single model.safetensor file
  93. is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors")
  94. if is_single_file:
  95. url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors"
  96. return cls.get_list_tensors(url)
  97. # case 2: model has multiple files
  98. index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json"
  99. is_multiple_files = cls.check_file_exist(index_url)
  100. if is_multiple_files:
  101. # read the index file
  102. index_data = cls.get_data_by_range(index_url, 0)
  103. index_str = index_data.decode('utf-8')
  104. index_json = json.loads(index_str)
  105. assert index_json.get("weight_map") is not None, "weight_map not found in index file"
  106. weight_map = index_json["weight_map"]
  107. # get the list of files
  108. all_files = list(set(weight_map.values()))
  109. all_files.sort() # make sure we load shard files in order
  110. # get the list of tensors
  111. tensors: dict[str, RemoteTensor] = {}
  112. for file in all_files:
  113. url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}"
  114. for key, val in cls.get_list_tensors(url).items():
  115. tensors[key] = val
  116. return tensors
  117. raise ValueError(
  118. f"No safetensor file has been found for model {model_id}."
  119. "If the repo has safetensor files, make sure the model is public or you have a "
  120. "valid Hugging Face token set in the environment variable HF_TOKEN."
  121. )
  122. @classmethod
  123. def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
  124. """
  125. Get list of tensors from a remote safetensor file.
  126. Returns a dictionary of tensor names and their metadata.
  127. Each tensor is represented as a tuple of (dtype, shape, offset_start, size)
  128. """
  129. metadata, data_start_offset = cls.get_metadata(url)
  130. res: dict[str, RemoteTensor] = {}
  131. for name, meta in metadata.items():
  132. if name == "__metadata__":
  133. continue
  134. if not isinstance(meta, dict):
  135. raise ValueError(f"Invalid metadata for tensor '{name}': {meta}")
  136. try:
  137. dtype = meta["dtype"]
  138. shape = meta["shape"]
  139. offset_start_relative, offset_end_relative = meta["data_offsets"]
  140. size = offset_end_relative - offset_start_relative
  141. offset_start = data_start_offset + offset_start_relative
  142. res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
  143. except KeyError as e:
  144. raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
  145. # order by name (same as default safetensors behavior)
  146. # ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
  147. res = dict(sorted(res.items(), key=lambda t: t[0]))
  148. return res
  149. @classmethod
  150. def get_metadata(cls, url: str) -> tuple[dict, int]:
  151. """
  152. Get JSON metadata from a remote safetensor file.
  153. Returns tuple of (metadata, data_start_offset)
  154. """
  155. # Request first 5MB of the file (hopefully enough for metadata)
  156. read_size = 5 * 1024 * 1024
  157. raw_data = cls.get_data_by_range(url, 0, read_size)
  158. # Parse header
  159. # First 8 bytes contain the metadata length as u64 little-endian
  160. if len(raw_data) < 8:
  161. raise ValueError("Not enough data to read metadata size")
  162. metadata_length = int.from_bytes(raw_data[:8], byteorder='little')
  163. # Calculate the data start offset
  164. data_start_offset = 8 + metadata_length
  165. alignment = SafetensorRemote.ALIGNMENT
  166. if data_start_offset % alignment != 0:
  167. data_start_offset += alignment - (data_start_offset % alignment)
  168. # Check if we have enough data to read the metadata
  169. if len(raw_data) < 8 + metadata_length:
  170. raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}")
  171. # Extract metadata bytes and parse as JSON
  172. metadata_bytes = raw_data[8:8 + metadata_length]
  173. metadata_str = metadata_bytes.decode('utf-8')
  174. try:
  175. metadata = json.loads(metadata_str)
  176. return metadata, data_start_offset
  177. except json.JSONDecodeError as e:
  178. raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}")
  179. @classmethod
  180. def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
  181. """
  182. Get raw byte data from a remote file by range.
  183. If size is not specified, it will read the entire file.
  184. """
  185. import requests
  186. from urllib.parse import urlparse
  187. parsed_url = urlparse(url)
  188. if not parsed_url.scheme or not parsed_url.netloc:
  189. raise ValueError(f"Invalid URL: {url}")
  190. headers = cls._get_request_headers()
  191. if size > -1:
  192. headers["Range"] = f"bytes={start}-{start + size}"
  193. response = requests.get(url, allow_redirects=True, headers=headers)
  194. response.raise_for_status()
  195. # Get raw byte data
  196. return response.content[slice(size if size > -1 else None)]
  197. @classmethod
  198. def check_file_exist(cls, url: str) -> bool:
  199. """
  200. Check if a file exists at the given URL.
  201. Returns True if the file exists, False otherwise.
  202. """
  203. import requests
  204. from urllib.parse import urlparse
  205. parsed_url = urlparse(url)
  206. if not parsed_url.scheme or not parsed_url.netloc:
  207. raise ValueError(f"Invalid URL: {url}")
  208. try:
  209. headers = cls._get_request_headers()
  210. headers["Range"] = "bytes=0-0"
  211. response = requests.head(url, allow_redirects=True, headers=headers)
  212. # Success (2xx) or redirect (3xx)
  213. return 200 <= response.status_code < 400
  214. except requests.RequestException:
  215. return False
  216. @classmethod
  217. def _get_request_headers(cls) -> dict[str, str]:
  218. """Prepare common headers for requests."""
  219. headers = {"User-Agent": "convert_hf_to_gguf"}
  220. if os.environ.get("HF_TOKEN"):
  221. headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
  222. return headers
@dataclass
class LocalTensorRange:
    # Byte range of one tensor's data inside a local safetensors file.
    filename: Path
    offset: int  # absolute byte offset of the tensor data from the start of the file
    size: int  # length of the tensor data in bytes
  228. @dataclass
  229. class LocalTensor:
  230. dtype: str
  231. shape: tuple[int, ...]
  232. data_range: LocalTensorRange
  233. def mmap_bytes(self) -> np.ndarray:
  234. return np.memmap(self.data_range.filename, mode='c', offset=self.data_range.offset, shape=self.data_range.size)
  235. class SafetensorsLocal:
  236. """
  237. Read a safetensors file from the local filesystem.
  238. Custom parsing gives a bit more control over the memory usage.
  239. The official safetensors library doesn't expose file ranges.
  240. """
  241. ALIGNMENT = 8 # bytes
  242. tensors: dict[str, LocalTensor]
  243. def __init__(self, filename: Path):
  244. with open(filename, "rb") as f:
  245. metadata_length = int.from_bytes(f.read(8), byteorder='little')
  246. file_size = os.stat(filename).st_size
  247. if file_size < 8 + metadata_length:
  248. raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {file_size}")
  249. metadata_str = f.read(metadata_length).decode('utf-8')
  250. try:
  251. metadata = json.loads(metadata_str)
  252. except json.JSONDecodeError as e:
  253. raise ValueError(f"Failed to parse safetensors metadata as JSON: {e}")
  254. data_start_offset = f.tell()
  255. alignment = self.ALIGNMENT
  256. if data_start_offset % alignment != 0:
  257. data_start_offset += alignment - (data_start_offset % alignment)
  258. tensors: dict[str, LocalTensor] = {}
  259. for name, meta in metadata.items():
  260. if name == "__metadata__":
  261. # ignore metadata, it's not a tensor
  262. continue
  263. tensors[name] = LocalTensor(
  264. dtype=meta["dtype"],
  265. shape=tuple(meta["shape"]),
  266. data_range=LocalTensorRange(
  267. filename,
  268. data_start_offset + meta["data_offsets"][0],
  269. meta["data_offsets"][1] - meta["data_offsets"][0],
  270. ),
  271. )
  272. # order by name (same as default safetensors behavior)
  273. # ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
  274. self.tensors = dict(sorted(tensors.items(), key=lambda t: t[0]))
  275. def __enter__(self, *args, **kwargs):
  276. del args, kwargs # unused
  277. return self.tensors
  278. def __exit__(self, *args, **kwargs):
  279. del args, kwargs # unused