|
|
@@ -1,10 +1,12 @@
|
|
|
from __future__ import annotations
|
|
|
|
|
|
from dataclasses import dataclass
|
|
|
+from pathlib import Path
|
|
|
from typing import Literal
|
|
|
|
|
|
import os
|
|
|
import json
|
|
|
+import numpy as np
|
|
|
|
|
|
|
|
|
def fill_templated_filename(filename: str, output_type: str | None) -> str:
|
|
|
@@ -177,6 +179,10 @@ class SafetensorRemote:
|
|
|
except KeyError as e:
|
|
|
raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
|
|
|
|
|
|
+ # order by name (same as default safetensors behavior)
|
|
|
+ # ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
|
|
|
+ res = dict(sorted(res.items(), key=lambda t: t[0]))
|
|
|
+
|
|
|
return res
|
|
|
|
|
|
@classmethod
|
|
|
@@ -266,3 +272,77 @@ class SafetensorRemote:
|
|
|
if os.environ.get("HF_TOKEN"):
|
|
|
headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
|
|
|
return headers
|
|
|
+
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class LocalTensorRange:
|
|
|
+ filename: Path
|
|
|
+ offset: int
|
|
|
+ size: int
|
|
|
+
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class LocalTensor:
|
|
|
+ dtype: str
|
|
|
+ shape: tuple[int, ...]
|
|
|
+ data_range: LocalTensorRange
|
|
|
+
|
|
|
+ def mmap_bytes(self) -> np.ndarray:
|
|
|
+ return np.memmap(self.data_range.filename, offset=self.data_range.offset, shape=self.data_range.size)
|
|
|
+
|
|
|
+
|
|
|
+class SafetensorsLocal:
|
|
|
+ """
|
|
|
+ Read a safetensors file from the local filesystem.
|
|
|
+
|
|
|
+ Custom parsing gives a bit more control over the memory usage.
|
|
|
+ The official safetensors library doesn't expose file ranges.
|
|
|
+ """
|
|
|
+ ALIGNMENT = 8 # bytes
|
|
|
+
|
|
|
+ tensors: dict[str, LocalTensor]
|
|
|
+
|
|
|
+ def __init__(self, filename: Path):
|
|
|
+ with open(filename, "rb") as f:
|
|
|
+ metadata_length = int.from_bytes(f.read(8), byteorder='little')
|
|
|
+ file_size = os.stat(filename).st_size
|
|
|
+ if file_size < 8 + metadata_length:
|
|
|
+ raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {file_size}")
|
|
|
+
|
|
|
+ metadata_str = f.read(metadata_length).decode('utf-8')
|
|
|
+ try:
|
|
|
+ metadata = json.loads(metadata_str)
|
|
|
+ except json.JSONDecodeError as e:
|
|
|
+ raise ValueError(f"Failed to parse safetensors metadata as JSON: {e}")
|
|
|
+
|
|
|
+ data_start_offset = f.tell()
|
|
|
+ alignment = self.ALIGNMENT
|
|
|
+ if data_start_offset % alignment != 0:
|
|
|
+ data_start_offset += alignment - (data_start_offset % alignment)
|
|
|
+
|
|
|
+ tensors: dict[str, LocalTensor] = {}
|
|
|
+ for name, meta in metadata.items():
|
|
|
+ if name == "__metadata__":
|
|
|
+ # ignore metadata, it's not a tensor
|
|
|
+ continue
|
|
|
+
|
|
|
+ tensors[name] = LocalTensor(
|
|
|
+ dtype=meta["dtype"],
|
|
|
+ shape=tuple(meta["shape"]),
|
|
|
+ data_range=LocalTensorRange(
|
|
|
+ filename,
|
|
|
+ data_start_offset + meta["data_offsets"][0],
|
|
|
+ meta["data_offsets"][1] - meta["data_offsets"][0],
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ # order by name (same as default safetensors behavior)
|
|
|
+ # ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
|
|
|
+ self.tensors = dict(sorted(tensors.items(), key=lambda t: t[0]))
|
|
|
+
|
|
|
+ def __enter__(self, *args, **kwargs):
|
|
|
+ del args, kwargs # unused
|
|
|
+ return self.tensors
|
|
|
+
|
|
|
+ def __exit__(self, *args, **kwargs):
|
|
|
+ del args, kwargs # unused
|