Sfoglia il codice sorgente

convert: add ability to convert safetensors files (#1276)

* when loading a safetensors file, ignore the metadata header
* check for safetensors files first, and only use PyTorch versions when safetensors aren't available
ubik2 2 anni fa
parent
commit
95078cc554
1 ha cambiato i file con 7 aggiunte e 3 eliminazioni
  1. 7 3
      convert.py

+ 7 - 3
convert.py

@@ -766,7 +766,7 @@ def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
             return UnquantizedTensor(np.frombuffer(buf, dtype=numpy_dtype).reshape(shape))
         description = f'safetensors begin={begin} end={end} type={data_type} path={path}'
         return LazyTensor(load, shape, data_type, description)
-    model = {name: convert(info) for (name, info) in header.items()}
+    model = {name: convert(info) for (name, info) in header.items() if name != '__metadata__'}
     return ModelPlus(model=model, paths=[path], format='safetensors', vocab=None)
 
 
@@ -1051,8 +1051,12 @@ def load_some_model(path: Path) -> ModelPlus:
     '''Load a model of any supported format.'''
     # Be extra-friendly and accept either a file or a directory:
     if path.is_dir():
-        globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"]
-        files = [file for glob in globs for file in path.glob(glob)]
+        # Check if it's a set of safetensors files first
+        files = list(path.glob("model-00001-of-*.safetensors"))
+        if not files:
+            # Try the PyTorch patterns too, with lower priority
+            globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"]
+            files = [file for glob in globs for file in path.glob(glob)]
         if not files:
             # Try GGML too, but with lower priority, since if both a non-GGML
             # model and a GGML model exist in the same directory, we assume the