소스 검색

convert : handle mmproj filename/path properly (#16760)

* convert: handle mmproj model output filename properly

* remove redundant commits

* Add model_type to gguf utility

* Use mmproj- prefix instead of suffix

* Apply CISC suggestion

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Galunid 2 달 전
부모
커밋
5d195f17bc
1개의 변경된 파일11개의 추가작업 그리고 4개의 파일을 삭제
  1. 11 4
      convert_hf_to_gguf.py

+ 11 - 4
convert_hf_to_gguf.py

@@ -1497,6 +1497,17 @@ class MmprojModel(ModelBase):
     def set_type(self):
         self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)
 
+    def prepare_metadata(self, vocab_only: bool):
+        super().prepare_metadata(vocab_only=vocab_only)
+
+        output_type: str = self.ftype.name.partition("_")[2]
+
+        if self.fname_out.is_dir():
+            fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, size_label=None, output_type=output_type, model_type=None)
+            self.fname_out = self.fname_out / f"mmproj-{fname_default}.gguf"
+        else:
+            self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)
+
     def set_gguf_parameters(self):
         self.gguf_writer.add_file_type(self.ftype)
 
@@ -9729,10 +9740,6 @@ def main() -> None:
 
     logger.info(f"Loading model: {dir_model.name}")
 
-    if args.mmproj:
-        if "mmproj" not in fname_out.name:
-            fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")
-
     is_mistral_format = args.mistral_format
     if is_mistral_format and not _mistral_common_installed:
         raise ImportError(_mistral_import_error_msg)