Procházet zdrojové kódy

convert : workaround for AutoConfig dummy labels (#13881)

Sigbjørn Skjæret před 7 měsíci
rodič
revize
5ca82fc1d7
1 změnil soubory, kde provedl 9 přidání a 3 odebrání
  1. 9 3
      convert_hf_to_gguf.py

+ 9 - 3
convert_hf_to_gguf.py

@@ -3690,14 +3690,20 @@ class BertModel(TextModel):
         super().__init__(*args, **kwargs)
         super().__init__(*args, **kwargs)
         self.vocab_size = None
         self.vocab_size = None
 
 
+        if cls_out_labels := self.hparams.get("id2label"):
+            if len(cls_out_labels) == 2 and cls_out_labels[0] == "LABEL_0":
+                # Remove dummy labels added by AutoConfig
+                cls_out_labels = None
+        self.cls_out_labels = cls_out_labels
+
     def set_gguf_parameters(self):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         super().set_gguf_parameters()
         self.gguf_writer.add_causal_attention(False)
         self.gguf_writer.add_causal_attention(False)
         self._try_set_pooling_type()
         self._try_set_pooling_type()
 
 
-        if cls_out_labels := self.hparams.get("id2label"):
+        if self.cls_out_labels:
             key_name = gguf.Keys.Classifier.OUTPUT_LABELS.format(arch = gguf.MODEL_ARCH_NAMES[self.model_arch])
             key_name = gguf.Keys.Classifier.OUTPUT_LABELS.format(arch = gguf.MODEL_ARCH_NAMES[self.model_arch])
-            self.gguf_writer.add_array(key_name, [v for k, v in sorted(cls_out_labels.items())])
+            self.gguf_writer.add_array(key_name, [v for k, v in sorted(self.cls_out_labels.items())])
 
 
     def set_vocab(self):
     def set_vocab(self):
         tokens, toktypes, tokpre = self.get_vocab_base()
         tokens, toktypes, tokpre = self.get_vocab_base()
@@ -3749,7 +3755,7 @@ class BertModel(TextModel):
         if name.startswith("cls.seq_relationship"):
         if name.startswith("cls.seq_relationship"):
             return []
             return []
 
 
-        if self.hparams.get("id2label"):
+        if self.cls_out_labels:
             # For BertForSequenceClassification (direct projection layer)
             # For BertForSequenceClassification (direct projection layer)
             if name == "classifier.weight":
             if name == "classifier.weight":
                 name = "classifier.out_proj.weight"
                 name = "classifier.out_proj.weight"