utility.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from pathlib import Path
  4. from typing import Literal
  5. import os
  6. import json
  7. import numpy as np
  8. def fill_templated_filename(filename: str, output_type: str | None) -> str:
  9. # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
  10. ftype_lowercase: str = output_type.lower() if output_type is not None else ""
  11. ftype_uppercase: str = output_type.upper() if output_type is not None else ""
  12. return filename.format(ftype_lowercase,
  13. outtype=ftype_lowercase, ftype=ftype_lowercase,
  14. OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
  15. def model_weight_count_rounded_notation(model_params_count: int, min_digits: int = 2) -> str:
  16. if model_params_count > 1e12 :
  17. # Trillions Of Parameters
  18. scaled_model_params = model_params_count * 1e-12
  19. scale_suffix = "T"
  20. elif model_params_count > 1e9 :
  21. # Billions Of Parameters
  22. scaled_model_params = model_params_count * 1e-9
  23. scale_suffix = "B"
  24. elif model_params_count > 1e6 :
  25. # Millions Of Parameters
  26. scaled_model_params = model_params_count * 1e-6
  27. scale_suffix = "M"
  28. else:
  29. # Thousands Of Parameters
  30. scaled_model_params = model_params_count * 1e-3
  31. scale_suffix = "K"
  32. fix = max(min_digits - len(str(round(scaled_model_params)).lstrip('0')), 0)
  33. return f"{scaled_model_params:.{fix}f}{scale_suffix}"
  34. def size_label(total_params: int, shared_params: int, expert_params: int, expert_count: int) -> str:
  35. if expert_count > 0:
  36. pretty_size = model_weight_count_rounded_notation(abs(shared_params) + abs(expert_params), min_digits=2)
  37. size_class = f"{expert_count}x{pretty_size}"
  38. else:
  39. size_class = model_weight_count_rounded_notation(abs(total_params), min_digits=2)
  40. return size_class
  41. def naming_convention(model_name: str | None, base_name: str | None, finetune_string: str | None, version_string: str | None, size_label: str | None, output_type: str | None, model_type: Literal['vocab', 'LoRA'] | None = None) -> str:
  42. # Reference: https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#gguf-naming-convention
  43. if base_name is not None:
  44. name = base_name.strip().replace(' ', '-').replace('/', '-')
  45. elif model_name is not None:
  46. name = model_name.strip().replace(' ', '-').replace('/', '-')
  47. else:
  48. name = "ggml-model"
  49. parameters = f"-{size_label}" if size_label is not None else ""
  50. finetune = f"-{finetune_string.strip().replace(' ', '-')}" if finetune_string is not None else ""
  51. version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else ""
  52. encoding = f"-{output_type.strip().replace(' ', '-').upper()}" if output_type is not None else ""
  53. kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
  54. return f"{name}{parameters}{finetune}{version}{encoding}{kind}"
  55. @dataclass
  56. class RemoteTensor:
  57. dtype: str
  58. shape: tuple[int, ...]
  59. offset_start: int
  60. size: int
  61. url: str
  62. def data(self) -> bytearray:
  63. # TODO: handle request errors (maybe with limited retries?)
  64. # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
  65. data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
  66. return data
  67. class SafetensorRemote:
  68. """
  69. Uility class to handle remote safetensor files.
  70. This class is designed to work with Hugging Face model repositories.
  71. Example (one model has single safetensor file, the other has multiple):
  72. for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]:
  73. tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
  74. print(tensors)
  75. Example reading tensor data:
  76. tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
  77. for name, meta in tensors.items():
  78. dtype, shape, offset_start, size, remote_safetensor_url = meta
  79. # read the tensor data
  80. data = SafetensorRemote.get_data_by_range(remote_safetensor_url, offset_start, size)
  81. print(data)
  82. """
  83. BASE_DOMAIN = "https://huggingface.co"
  84. @classmethod
  85. def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
  86. """
  87. Get list of tensors from a Hugging Face model repository.
  88. Returns a dictionary of tensor names and their metadata.
  89. Each tensor is represented as a tuple of (dtype, shape, offset_start, size, remote_safetensor_url)
  90. """
  91. # case 1: model has only one single model.safetensor file
  92. is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors")
  93. if is_single_file:
  94. url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors"
  95. return cls.get_list_tensors(url)
  96. # case 2: model has multiple files
  97. index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json"
  98. is_multiple_files = cls.check_file_exist(index_url)
  99. if is_multiple_files:
  100. # read the index file
  101. index_data = cls.get_data_by_range(index_url, 0)
  102. index_str = index_data.decode('utf-8')
  103. index_json = json.loads(index_str)
  104. assert index_json.get("weight_map") is not None, "weight_map not found in index file"
  105. weight_map = index_json["weight_map"]
  106. # get the list of files
  107. all_files = list(set(weight_map.values()))
  108. all_files.sort() # make sure we load shard files in order
  109. # get the list of tensors
  110. tensors: dict[str, RemoteTensor] = {}
  111. for file in all_files:
  112. url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}"
  113. for key, val in cls.get_list_tensors(url).items():
  114. tensors[key] = val
  115. return tensors
  116. raise ValueError(
  117. f"No safetensor file has been found for model {model_id}."
  118. "If the repo has safetensor files, make sure the model is public or you have a "
  119. "valid Hugging Face token set in the environment variable HF_TOKEN."
  120. )
  121. @classmethod
  122. def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
  123. """
  124. Get list of tensors from a remote safetensor file.
  125. Returns a dictionary of tensor names and their metadata.
  126. Each tensor is represented as a tuple of (dtype, shape, offset_start, size)
  127. """
  128. metadata, data_start_offset = cls.get_metadata(url)
  129. res: dict[str, RemoteTensor] = {}
  130. for name, meta in metadata.items():
  131. if name == "__metadata__":
  132. continue
  133. if not isinstance(meta, dict):
  134. raise ValueError(f"Invalid metadata for tensor '{name}': {meta}")
  135. try:
  136. dtype = meta["dtype"]
  137. shape = meta["shape"]
  138. offset_start_relative, offset_end_relative = meta["data_offsets"]
  139. size = offset_end_relative - offset_start_relative
  140. offset_start = data_start_offset + offset_start_relative
  141. res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
  142. except KeyError as e:
  143. raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
  144. # order by name (same as default safetensors behavior)
  145. # ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
  146. res = dict(sorted(res.items(), key=lambda t: t[0]))
  147. return res
  148. @classmethod
  149. def get_metadata(cls, url: str) -> tuple[dict, int]:
  150. """
  151. Get JSON metadata from a remote safetensor file.
  152. Returns tuple of (metadata, data_start_offset)
  153. """
  154. # Request first 5MB of the file (hopefully enough for metadata)
  155. read_size = 5 * 1024 * 1024
  156. raw_data = cls.get_data_by_range(url, 0, read_size)
  157. # Parse header
  158. # First 8 bytes contain the metadata length as u64 little-endian
  159. if len(raw_data) < 8:
  160. raise ValueError("Not enough data to read metadata size")
  161. metadata_length = int.from_bytes(raw_data[:8], byteorder='little')
  162. # Calculate the data start offset
  163. data_start_offset = 8 + metadata_length
  164. # Check if we have enough data to read the metadata
  165. if len(raw_data) < 8 + metadata_length:
  166. raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}")
  167. # Extract metadata bytes and parse as JSON
  168. metadata_bytes = raw_data[8:8 + metadata_length]
  169. metadata_str = metadata_bytes.decode('utf-8')
  170. try:
  171. metadata = json.loads(metadata_str)
  172. return metadata, data_start_offset
  173. except json.JSONDecodeError as e:
  174. raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}")
  175. @classmethod
  176. def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
  177. """
  178. Get raw byte data from a remote file by range.
  179. If size is not specified, it will read the entire file.
  180. """
  181. import requests
  182. from urllib.parse import urlparse
  183. parsed_url = urlparse(url)
  184. if not parsed_url.scheme or not parsed_url.netloc:
  185. raise ValueError(f"Invalid URL: {url}")
  186. headers = cls._get_request_headers()
  187. if size > -1:
  188. headers["Range"] = f"bytes={start}-{start + size}"
  189. response = requests.get(url, allow_redirects=True, headers=headers)
  190. response.raise_for_status()
  191. # Get raw byte data
  192. return response.content[slice(size if size > -1 else None)]
  193. @classmethod
  194. def check_file_exist(cls, url: str) -> bool:
  195. """
  196. Check if a file exists at the given URL.
  197. Returns True if the file exists, False otherwise.
  198. """
  199. import requests
  200. from urllib.parse import urlparse
  201. parsed_url = urlparse(url)
  202. if not parsed_url.scheme or not parsed_url.netloc:
  203. raise ValueError(f"Invalid URL: {url}")
  204. try:
  205. headers = cls._get_request_headers()
  206. headers["Range"] = "bytes=0-0"
  207. response = requests.head(url, allow_redirects=True, headers=headers)
  208. # Success (2xx) or redirect (3xx)
  209. return 200 <= response.status_code < 400
  210. except requests.RequestException:
  211. return False
  212. @classmethod
  213. def _get_request_headers(cls) -> dict[str, str]:
  214. """Prepare common headers for requests."""
  215. headers = {"User-Agent": "convert_hf_to_gguf"}
  216. if os.environ.get("HF_TOKEN"):
  217. headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
  218. return headers
  219. @dataclass
  220. class LocalTensorRange:
  221. filename: Path
  222. offset: int
  223. size: int
  224. @dataclass
  225. class LocalTensor:
  226. dtype: str
  227. shape: tuple[int, ...]
  228. data_range: LocalTensorRange
  229. def mmap_bytes(self) -> np.ndarray:
  230. return np.memmap(self.data_range.filename, mode='c', offset=self.data_range.offset, shape=self.data_range.size)
  231. class SafetensorsLocal:
  232. """
  233. Read a safetensors file from the local filesystem.
  234. Custom parsing gives a bit more control over the memory usage.
  235. The official safetensors library doesn't expose file ranges.
  236. """
  237. tensors: dict[str, LocalTensor]
  238. def __init__(self, filename: Path):
  239. with open(filename, "rb") as f:
  240. metadata_length = int.from_bytes(f.read(8), byteorder='little')
  241. file_size = os.stat(filename).st_size
  242. if file_size < 8 + metadata_length:
  243. raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {file_size}")
  244. metadata_str = f.read(metadata_length).decode('utf-8')
  245. try:
  246. metadata = json.loads(metadata_str)
  247. except json.JSONDecodeError as e:
  248. raise ValueError(f"Failed to parse safetensors metadata as JSON: {e}")
  249. data_start_offset = f.tell()
  250. tensors: dict[str, LocalTensor] = {}
  251. for name, meta in metadata.items():
  252. if name == "__metadata__":
  253. # ignore metadata, it's not a tensor
  254. continue
  255. tensors[name] = LocalTensor(
  256. dtype=meta["dtype"],
  257. shape=tuple(meta["shape"]),
  258. data_range=LocalTensorRange(
  259. filename,
  260. data_start_offset + meta["data_offsets"][0],
  261. meta["data_offsets"][1] - meta["data_offsets"][0],
  262. ),
  263. )
  264. # order by name (same as default safetensors behavior)
  265. # ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
  266. self.tensors = dict(sorted(tensors.items(), key=lambda t: t[0]))
  267. def __enter__(self, *args, **kwargs):
  268. del args, kwargs # unused
  269. return self.tensors
  270. def __exit__(self, *args, **kwargs):
  271. del args, kwargs # unused