Răsfoiți Sursa

convert : use all parts in safetensors index (#17286)

Sigbjørn Skjæret 2 luni în urmă
părinte
comite
9a8860cf5d
1 a modificat fișierele cu 3 adăugiri și 2 ștergeri
  1. 3 2
      convert_hf_to_gguf.py

+ 3 - 2
convert_hf_to_gguf.py

@@ -189,10 +189,10 @@ class ModelBase:
             return tensors
             return tensors
 
 
         prefix = "model" if not self.is_mistral_format else "consolidated"
         prefix = "model" if not self.is_mistral_format else "consolidated"
-        part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
+        part_names: set[str] = set(ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors"))
         is_safetensors: bool = len(part_names) > 0
         is_safetensors: bool = len(part_names) > 0
         if not is_safetensors:
         if not is_safetensors:
-            part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
+            part_names = set(ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin"))
 
 
         tensor_names_from_index: set[str] = set()
         tensor_names_from_index: set[str] = set()
 
 
@@ -209,6 +209,7 @@ class ModelBase:
                     if weight_map is None or not isinstance(weight_map, dict):
                     if weight_map is None or not isinstance(weight_map, dict):
                         raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
                         raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
                     tensor_names_from_index.update(weight_map.keys())
                     tensor_names_from_index.update(weight_map.keys())
+                    part_names |= set(weight_map.values())
             else:
             else:
                 weight_map = {}
                 weight_map = {}
         else:
         else: