utility.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from typing import Literal
  4. import os
  5. import json
  6. def fill_templated_filename(filename: str, output_type: str | None) -> str:
  7. # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
  8. ftype_lowercase: str = output_type.lower() if output_type is not None else ""
  9. ftype_uppercase: str = output_type.upper() if output_type is not None else ""
  10. return filename.format(ftype_lowercase,
  11. outtype=ftype_lowercase, ftype=ftype_lowercase,
  12. OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
  13. def model_weight_count_rounded_notation(model_params_count: int, min_digits: int = 2) -> str:
  14. if model_params_count > 1e12 :
  15. # Trillions Of Parameters
  16. scaled_model_params = model_params_count * 1e-12
  17. scale_suffix = "T"
  18. elif model_params_count > 1e9 :
  19. # Billions Of Parameters
  20. scaled_model_params = model_params_count * 1e-9
  21. scale_suffix = "B"
  22. elif model_params_count > 1e6 :
  23. # Millions Of Parameters
  24. scaled_model_params = model_params_count * 1e-6
  25. scale_suffix = "M"
  26. else:
  27. # Thousands Of Parameters
  28. scaled_model_params = model_params_count * 1e-3
  29. scale_suffix = "K"
  30. fix = max(min_digits - len(str(round(scaled_model_params)).lstrip('0')), 0)
  31. return f"{scaled_model_params:.{fix}f}{scale_suffix}"
  32. def size_label(total_params: int, shared_params: int, expert_params: int, expert_count: int) -> str:
  33. if expert_count > 0:
  34. pretty_size = model_weight_count_rounded_notation(abs(shared_params) + abs(expert_params), min_digits=2)
  35. size_class = f"{expert_count}x{pretty_size}"
  36. else:
  37. size_class = model_weight_count_rounded_notation(abs(total_params), min_digits=2)
  38. return size_class
  39. def naming_convention(model_name: str | None, base_name: str | None, finetune_string: str | None, version_string: str | None, size_label: str | None, output_type: str | None, model_type: Literal['vocab', 'LoRA'] | None = None) -> str:
  40. # Reference: https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#gguf-naming-convention
  41. if base_name is not None:
  42. name = base_name.strip().replace(' ', '-').replace('/', '-')
  43. elif model_name is not None:
  44. name = model_name.strip().replace(' ', '-').replace('/', '-')
  45. else:
  46. name = "ggml-model"
  47. parameters = f"-{size_label}" if size_label is not None else ""
  48. finetune = f"-{finetune_string.strip().replace(' ', '-')}" if finetune_string is not None else ""
  49. version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else ""
  50. encoding = f"-{output_type.strip().replace(' ', '-').upper()}" if output_type is not None else ""
  51. kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
  52. return f"{name}{parameters}{finetune}{version}{encoding}{kind}"
  53. @dataclass
  54. class RemoteTensor:
  55. dtype: str
  56. shape: tuple[int, ...]
  57. offset_start: int
  58. size: int
  59. url: str
  60. def data(self) -> bytearray:
  61. # TODO: handle request errors (maybe with limited retries?)
  62. # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
  63. data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
  64. return data
  65. class SafetensorRemote:
  66. """
  67. Uility class to handle remote safetensor files.
  68. This class is designed to work with Hugging Face model repositories.
  69. Example (one model has single safetensor file, the other has multiple):
  70. for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]:
  71. tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
  72. print(tensors)
  73. Example reading tensor data:
  74. tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
  75. for name, meta in tensors.items():
  76. dtype, shape, offset_start, size, remote_safetensor_url = meta
  77. # read the tensor data
  78. data = SafetensorRemote.get_data_by_range(remote_safetensor_url, offset_start, size)
  79. print(data)
  80. """
  81. BASE_DOMAIN = "https://huggingface.co"
  82. ALIGNMENT = 8 # bytes
  83. @classmethod
  84. def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
  85. """
  86. Get list of tensors from a Hugging Face model repository.
  87. Returns a dictionary of tensor names and their metadata.
  88. Each tensor is represented as a tuple of (dtype, shape, offset_start, size, remote_safetensor_url)
  89. """
  90. # case 1: model has only one single model.safetensor file
  91. is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors")
  92. if is_single_file:
  93. url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors"
  94. return cls.get_list_tensors(url)
  95. # case 2: model has multiple files
  96. index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json"
  97. is_multiple_files = cls.check_file_exist(index_url)
  98. if is_multiple_files:
  99. # read the index file
  100. index_data = cls.get_data_by_range(index_url, 0)
  101. index_str = index_data.decode('utf-8')
  102. index_json = json.loads(index_str)
  103. assert index_json.get("weight_map") is not None, "weight_map not found in index file"
  104. weight_map = index_json["weight_map"]
  105. # get the list of files
  106. all_files = list(set(weight_map.values()))
  107. all_files.sort() # make sure we load shard files in order
  108. # get the list of tensors
  109. tensors: dict[str, RemoteTensor] = {}
  110. for file in all_files:
  111. url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}"
  112. for key, val in cls.get_list_tensors(url).items():
  113. tensors[key] = val
  114. return tensors
  115. raise ValueError(
  116. f"No safetensor file has been found for model {model_id}."
  117. "If the repo has safetensor files, make sure the model is public or you have a "
  118. "valid Hugging Face token set in the environment variable HF_TOKEN."
  119. )
  120. @classmethod
  121. def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
  122. """
  123. Get list of tensors from a remote safetensor file.
  124. Returns a dictionary of tensor names and their metadata.
  125. Each tensor is represented as a tuple of (dtype, shape, offset_start, size)
  126. """
  127. metadata, data_start_offset = cls.get_metadata(url)
  128. res: dict[str, RemoteTensor] = {}
  129. for name, meta in metadata.items():
  130. if name == "__metadata__":
  131. continue
  132. if not isinstance(meta, dict):
  133. raise ValueError(f"Invalid metadata for tensor '{name}': {meta}")
  134. try:
  135. dtype = meta["dtype"]
  136. shape = meta["shape"]
  137. offset_start_relative, offset_end_relative = meta["data_offsets"]
  138. size = offset_end_relative - offset_start_relative
  139. offset_start = data_start_offset + offset_start_relative
  140. res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
  141. except KeyError as e:
  142. raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
  143. return res
  144. @classmethod
  145. def get_metadata(cls, url: str) -> tuple[dict, int]:
  146. """
  147. Get JSON metadata from a remote safetensor file.
  148. Returns tuple of (metadata, data_start_offset)
  149. """
  150. # Request first 5MB of the file (hopefully enough for metadata)
  151. read_size = 5 * 1024 * 1024
  152. raw_data = cls.get_data_by_range(url, 0, read_size)
  153. # Parse header
  154. # First 8 bytes contain the metadata length as u64 little-endian
  155. if len(raw_data) < 8:
  156. raise ValueError("Not enough data to read metadata size")
  157. metadata_length = int.from_bytes(raw_data[:8], byteorder='little')
  158. # Calculate the data start offset
  159. data_start_offset = 8 + metadata_length
  160. alignment = SafetensorRemote.ALIGNMENT
  161. if data_start_offset % alignment != 0:
  162. data_start_offset += alignment - (data_start_offset % alignment)
  163. # Check if we have enough data to read the metadata
  164. if len(raw_data) < 8 + metadata_length:
  165. raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}")
  166. # Extract metadata bytes and parse as JSON
  167. metadata_bytes = raw_data[8:8 + metadata_length]
  168. metadata_str = metadata_bytes.decode('utf-8')
  169. try:
  170. metadata = json.loads(metadata_str)
  171. return metadata, data_start_offset
  172. except json.JSONDecodeError as e:
  173. raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}")
  174. @classmethod
  175. def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
  176. """
  177. Get raw byte data from a remote file by range.
  178. If size is not specified, it will read the entire file.
  179. """
  180. import requests
  181. from urllib.parse import urlparse
  182. parsed_url = urlparse(url)
  183. if not parsed_url.scheme or not parsed_url.netloc:
  184. raise ValueError(f"Invalid URL: {url}")
  185. headers = cls._get_request_headers()
  186. if size > -1:
  187. headers["Range"] = f"bytes={start}-{start + size}"
  188. response = requests.get(url, allow_redirects=True, headers=headers)
  189. response.raise_for_status()
  190. # Get raw byte data
  191. return response.content[slice(size if size > -1 else None)]
  192. @classmethod
  193. def check_file_exist(cls, url: str) -> bool:
  194. """
  195. Check if a file exists at the given URL.
  196. Returns True if the file exists, False otherwise.
  197. """
  198. import requests
  199. from urllib.parse import urlparse
  200. parsed_url = urlparse(url)
  201. if not parsed_url.scheme or not parsed_url.netloc:
  202. raise ValueError(f"Invalid URL: {url}")
  203. try:
  204. headers = cls._get_request_headers()
  205. headers["Range"] = "bytes=0-0"
  206. response = requests.head(url, allow_redirects=True, headers=headers)
  207. # Success (2xx) or redirect (3xx)
  208. return 200 <= response.status_code < 400
  209. except requests.RequestException:
  210. return False
  211. @classmethod
  212. def _get_request_headers(cls) -> dict[str, str]:
  213. """Prepare common headers for requests."""
  214. headers = {"User-Agent": "convert_hf_to_gguf"}
  215. if os.environ.get("HF_TOKEN"):
  216. headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
  217. return headers