|
|
@@ -1,7 +1,11 @@
|
|
|
from __future__ import annotations
|
|
|
|
|
|
+from dataclasses import dataclass
|
|
|
from typing import Literal
|
|
|
|
|
|
+import os
|
|
|
+import json
|
|
|
+
|
|
|
|
|
|
def fill_templated_filename(filename: str, output_type: str | None) -> str:
|
|
|
# Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
|
|
|
@@ -67,3 +71,194 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
|
|
|
kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
|
|
|
|
|
|
return f"{name}{parameters}{finetune}{version}{encoding}{kind}"
|
|
|
+
|
|
|
+
|
|
|
@dataclass
class RemoteTensor:
    """Metadata describing one tensor stored inside a remote safetensor file."""

    # dtype string as found in the safetensor header entry
    dtype: str
    # tensor dimensions
    shape: tuple[int, ...]
    # offset of the tensor's first byte, passed as `start` when fetching
    offset_start: int
    # number of bytes occupied by the tensor data
    size: int
    # URL of the safetensor file that holds this tensor
    url: str

    def data(self) -> bytearray:
        """Download this tensor's bytes and return them as a mutable buffer."""
        # TODO: handle request errors (maybe with limited retries?)
        raw = SafetensorRemote.get_data_by_range(
            url=self.url,
            start=self.offset_start,
            size=self.size,
        )
        # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
        return bytearray(raw)
|
|
|
+
|
|
|
+
|
|
|
class SafetensorRemote:
    """
    Utility class to handle remote safetensor files.
    This class is designed to work with Hugging Face model repositories.

    Example (one model has single safetensor file, the other has multiple):
        for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]:
            tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
            print(tensors)

    Example reading tensor data:
        tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
        for name, meta in tensors.items():
            # read the tensor data
            data = SafetensorRemote.get_data_by_range(meta.url, meta.offset_start, meta.size)
            print(data)
    """

    BASE_DOMAIN = "https://huggingface.co"
    ALIGNMENT = 8  # bytes

    @classmethod
    def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
        """
        Get list of tensors from a Hugging Face model repository.

        Returns a dictionary mapping tensor names to RemoteTensor objects
        (dtype, shape, offset_start, size, url).

        Raises ValueError if the repository has neither a model.safetensors
        file nor a model.safetensors.index.json file.
        """
        repo_root = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main"

        # case 1: model has only one single model.safetensor file
        single_url = f"{repo_root}/model.safetensors"
        if cls.check_file_exist(single_url):
            return cls.get_list_tensors(single_url)

        # case 2: model has multiple files
        index_url = f"{repo_root}/model.safetensors.index.json"
        if cls.check_file_exist(index_url):
            # read the index file
            index_data = cls.get_data_by_range(index_url, 0)
            index_json = json.loads(index_data.decode('utf-8'))
            assert index_json.get("weight_map") is not None, "weight_map not found in index file"
            weight_map = index_json["weight_map"]
            # de-duplicate the shard names and load shard files in order
            tensors: dict[str, RemoteTensor] = {}
            for file in sorted(set(weight_map.values())):
                url = f"{repo_root}/{file}"
                tensors.update(cls.get_list_tensors(url))
            return tensors

        raise ValueError(f"Model {model_id} does not have any safetensor files")

    @classmethod
    def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
        """
        Get list of tensors from a remote safetensor file.

        Returns a dictionary mapping tensor names to RemoteTensor objects.
        The "__metadata__" header entry is skipped.

        Raises ValueError if a header entry is not a dict or lacks one of
        the "dtype", "shape" or "data_offsets" keys.
        """
        metadata, data_start_offset = cls.get_metadata(url)
        res: dict[str, RemoteTensor] = {}

        for name, meta in metadata.items():
            if name == "__metadata__":
                continue
            if not isinstance(meta, dict):
                raise ValueError(f"Invalid metadata for tensor '{name}': {meta}")
            try:
                dtype = meta["dtype"]
                shape = meta["shape"]
                offset_start_relative, offset_end_relative = meta["data_offsets"]
            except KeyError as e:
                raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}") from e
            # header offsets are relative to the data section; make them relative to the file
            res[name] = RemoteTensor(
                dtype=dtype,
                shape=tuple(shape),
                offset_start=data_start_offset + offset_start_relative,
                size=offset_end_relative - offset_start_relative,
                url=url,
            )

        return res

    @classmethod
    def get_metadata(cls, url: str) -> tuple[dict, int]:
        """
        Get JSON metadata from a remote safetensor file.

        Returns tuple of (metadata, data_start_offset)

        Raises ValueError if the header cannot be read completely or is not
        valid JSON.
        """
        # Request first 5MB of the file (usually enough for metadata)
        read_size = 5 * 1024 * 1024
        # Sanity cap so that a corrupt length prefix cannot trigger a huge download
        max_header_size = 100 * 1024 * 1024
        raw_data = cls.get_data_by_range(url, 0, read_size)

        # Parse header
        # First 8 bytes contain the metadata length as u64 little-endian
        if len(raw_data) < 8:
            raise ValueError("Not enough data to read metadata size")
        metadata_length = int.from_bytes(raw_data[:8], byteorder='little')
        header_end = 8 + metadata_length

        # Calculate the data start offset (rounded up to the next multiple of ALIGNMENT)
        data_start_offset = header_end
        alignment = cls.ALIGNMENT
        if data_start_offset % alignment != 0:
            data_start_offset += alignment - (data_start_offset % alignment)

        # The header did not fit in the first read: fetch exactly what is needed.
        # Only retry when the first read was not cut short by the end of the file.
        if read_size <= len(raw_data) < header_end <= max_header_size:
            raw_data = cls.get_data_by_range(url, 0, header_end)

        # Check if we have enough data to read the metadata
        if len(raw_data) < header_end:
            raise ValueError(f"Could not read complete metadata. Need {header_end} bytes, got {len(raw_data)}")

        # Extract metadata bytes and parse as JSON
        metadata_str = raw_data[8:header_end].decode('utf-8')
        try:
            metadata = json.loads(metadata_str)
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}") from e
        return metadata, data_start_offset

    @staticmethod
    def _build_range_header(start: int, size: int) -> str | None:
        """
        Build the value of an HTTP Range header for `size` bytes from `start`.

        HTTP byte ranges are inclusive on both ends, hence the `- 1`.
        A negative size means "until the end of the file"; in that case None
        is returned when start is 0, because no Range header is needed.
        size must not be 0 (callers short-circuit empty reads).
        """
        if size > -1:
            return f"bytes={start}-{start + size - 1}"
        if start > 0:
            return f"bytes={start}-"
        return None

    @classmethod
    def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
        """
        Get raw byte data from a remote file by range.
        If size is not specified, it will read from start to the end of the file.

        Raises ValueError for a URL without scheme or host, and
        requests.HTTPError for an unsuccessful response.
        """
        import requests
        from urllib.parse import urlparse

        parsed_url = urlparse(url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise ValueError(f"Invalid URL: {url}")

        # An empty range cannot be expressed in a Range header; nothing to download.
        if size == 0:
            return b""

        headers = cls._get_request_headers()
        range_header = cls._build_range_header(start, size)
        if range_header is not None:
            headers["Range"] = range_header
        response = requests.get(url, allow_redirects=True, headers=headers)
        response.raise_for_status()

        # Get raw byte data
        content = response.content
        if range_header is not None and response.status_code == 200:
            # The server ignored the Range header and sent the whole file
            content = content[start:]
        # Never slice with a negative size: content[:-1] would drop the last byte
        return content if size < 0 else content[:size]

    @classmethod
    def check_file_exist(cls, url: str) -> bool:
        """
        Check if a file exists at the given URL.
        Returns True if the file exists, False otherwise.

        Raises ValueError for a URL without scheme or host. Request failures
        are reported as False.
        """
        import requests
        from urllib.parse import urlparse

        parsed_url = urlparse(url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise ValueError(f"Invalid URL: {url}")

        try:
            headers = cls._get_request_headers()
            headers["Range"] = "bytes=0-0"
            response = requests.head(url, allow_redirects=True, headers=headers)
            # Success (2xx) or redirect (3xx)
            return 200 <= response.status_code < 400
        except requests.RequestException:
            return False

    @classmethod
    def _get_request_headers(cls) -> dict[str, str]:
        """Prepare common headers for requests (adds a bearer token when HF_TOKEN is set)."""
        headers = {"User-Agent": "convert_hf_to_gguf"}
        token = os.environ.get("HF_TOKEN")
        if token:
            headers["Authorization"] = f"Bearer {token}"
        return headers
|