Kaynağa Gözat

py : handle byte tokens in `get_token_type` (#5341)

* py : handle byte tokens in `get_token_type`

* py : fix empty bytes arg
Georgi Gerganov 1 yıl önce
ebeveyn
işleme
906cff55c2
1 değiştirilmiş dosya ile 7 ekleme ve 3 silme
  1. 7 3
      convert.py

+ 7 - 3
convert.py

@@ -515,10 +515,14 @@ class HfVocab:
 
 
             # Yield token text, score, and type
             # Yield token text, score, and type
             yield token_text, self.get_token_score(token_id), self.get_token_type(
             yield token_text, self.get_token_score(token_id), self.get_token_type(
-                token_id, self.special_ids  # Reuse already stored special IDs
+                token_id, token_text, self.special_ids  # Reuse already stored special IDs
             )
             )
 
 
-    def get_token_type(self, token_id: int, special_ids: set[int]) -> gguf.TokenType:
+    def get_token_type(self, token_id: int, token_text: bytes, special_ids: set[int]) -> gguf.TokenType:
+        # Special case for byte tokens
+        if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
+            return gguf.TokenType.BYTE
+
         # Determine token type based on whether it's a special token
         # Determine token type based on whether it's a special token
         return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL
         return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL
 
 
@@ -530,7 +534,7 @@ class HfVocab:
     def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
     def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         for text in self.added_tokens_list:
         for text in self.added_tokens_list:
             if text in self.specials:
             if text in self.specials:
-                toktype = self.get_token_type(self.specials[text], self.special_ids)
+                toktype = self.get_token_type(self.specials[text], b'', self.special_ids)
                 score = self.get_token_score(self.specials[text])
                 score = self.get_token_score(self.specials[text])
             else:
             else:
                 toktype = gguf.TokenType.USER_DEFINED
                 toktype = gguf.TokenType.USER_DEFINED