@@ -3695,6 +3695,10 @@ class BertModel(TextModel):
         self.gguf_writer.add_causal_attention(False)
         self._try_set_pooling_type()
 
+        if cls_out_labels := self.hparams.get("id2label"):
+            key_name = gguf.Keys.Classifier.OUTPUT_LABELS.format(arch = gguf.MODEL_ARCH_NAMES[self.model_arch])
+            self.gguf_writer.add_array(key_name, [v for k, v in sorted(cls_out_labels.items())])
+
     def set_vocab(self):
         tokens, toktypes, tokpre = self.get_vocab_base()
         self.vocab_size = len(tokens)
@@ -3745,12 +3749,13 @@ class BertModel(TextModel):
         if name.startswith("cls.seq_relationship"):
             return []
 
-        # For BertForSequenceClassification (direct projection layer)
-        if name == "classifier.weight":
-            name = "classifier.out_proj.weight"
+        if self.hparams.get("id2label"):
+            # For BertForSequenceClassification (direct projection layer)
+            if name == "classifier.weight":
+                name = "classifier.out_proj.weight"
 
-        if name == "classifier.bias":
-            name = "classifier.out_proj.bias"
+            if name == "classifier.bias":
+                name = "classifier.out_proj.bias"
 
         return [(self.map_tensor_name(name), data_torch)]
 
@@ -3846,7 +3851,7 @@ class BertModel(TextModel):
         self.gguf_writer.add_add_eos_token(True)
 
 
-@ModelBase.register("RobertaModel")
+@ModelBase.register("RobertaModel", "RobertaForSequenceClassification")
 class RobertaModel(BertModel):
     model_arch = gguf.MODEL_ARCH.BERT
 
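
For context (not part of the patch): the first hunk writes the classifier's label names as a GGUF string array under the `{arch}.classifier.output_labels` key. A minimal sketch of the derivation, using a hypothetical `id2label` from a Hugging Face config.json, where keys arrive as strings, entries are sorted by key, and only the values are kept:

    # Hypothetical config.json contents; a real model may differ.
    id2label = {"0": "NEGATIVE", "1": "POSITIVE"}

    # Same expression as in the hunk: sort by key, keep the label names.
    labels = [v for k, v in sorted(id2label.items())]
    assert labels == ["NEGATIVE", "POSITIVE"]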