@@ -1,5 +1,6 @@
 from __future__ import annotations

+from enum import Enum
 import re
 import logging
 import json
@@ -12,6 +13,25 @@ try:
 except ImportError:
     SentencePieceProcessor = None

+try:
+    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+    from mistral_common.tokens.tokenizers.tekken import Tekkenizer
+    from mistral_common.tokens.tokenizers.utils import (
+        _filter_valid_tokenizer_files,
+    )
+    from mistral_common.tokens.tokenizers.sentencepiece import (
+        SentencePieceTokenizer,
+    )
+except ImportError:
+    _mistral_common_installed = False
+    MistralTokenizer = None
+    Tekkenizer = None
+    SentencePieceTokenizer = None
+    _filter_valid_tokenizer_files = None
+else:
+    _mistral_common_installed = True
+
+
 import gguf

 from .gguf_writer import GGUFWriter
@@ -592,3 +612,278 @@ class LlamaHfVocab(Vocab):

     def __repr__(self) -> str:
         return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
+
+
+class MistralTokenizerType(str, Enum):
+    spm = "spm"
+    tekken = "tekken"
+
+
+# Copied from Transformers (Apache 2.0)
+# https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py#L1544
+
+def bytes_to_unicode() -> dict[int, str]:
+    """
+    Returns a mapping from utf-8 bytes to unicode strings. We specifically avoid mapping to whitespace/control
+    characters that the bpe code barfs on.
+
+    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+    tables between utf-8 bytes and unicode strings.
+    """
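+    # e.g. byte 0x20 (space) maps to "Ġ" (U+0120) and 0x0A (newline) to "Ċ" (U+010A),
+    # while printable ASCII bytes map to themselves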
+    bs = (
+        list(range(ord("!"), ord("~") + 1))
+        + list(range(ord("¡"), ord("¬") + 1))
+        + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs_str = [chr(n) for n in cs]
+    return dict(zip(bs, cs_str))
+
+
+class MistralVocab(Vocab):
+    tokenizer_model = "mistral"
+    name = "mistral"
+
+    added_tokens_dict: dict[str, int] = {}
+    added_tokens_list: list[str] = []
+
+    def __init__(self, base_path: Path):
+        if not _mistral_common_installed:
+            raise ImportError(
+                "To use MistralVocab, please install the `mistral-common` package. "
+                "You can install it with `pip install mistral-common`."
+            )
+        assert _filter_valid_tokenizer_files is not None, "mistral_common is not installed"
+        assert MistralTokenizer is not None, "mistral_common is not installed"
+        assert Tekkenizer is not None, "mistral_common is not installed"
+
+        logger.info(f"Loading Mistral tokenizer from {base_path}")
+
+        # Find the tokenizer files
+        all_files = [f.as_posix() for f in base_path.glob("**/*") if f.is_file()]
+        valid_tokenizer_files = _filter_valid_tokenizer_files(all_files)
+
+        if len(valid_tokenizer_files) == 0:
+            raise ValueError(f"No tokenizer file found in the directory: {base_path}")
+        # If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one.
+        if len(valid_tokenizer_files) > 1:
+            if "tekken.json" in valid_tokenizer_files:
+                tokenizer_file = "tekken.json"
+            else:
+                tokenizer_file = sorted(valid_tokenizer_files)[-1]
+            logger.warning(
+                f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}"
+            )
+        else:
+            tokenizer_file = valid_tokenizer_files[0]
+
+        self.tokenizer = MistralTokenizer.from_file(
+            base_path / tokenizer_file
+        ).instruct_tokenizer.tokenizer
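+        # MistralTokenizer wraps an instruct tokenizer; unwrap it to the
+        # underlying Tekkenizer or SentencePieceTokenizer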
+        self.tokenizer_type = (
+            MistralTokenizerType.tekken
+            if isinstance(self.tokenizer, Tekkenizer)
+            else MistralTokenizerType.spm
+        )
+        self.vocab_size = self.tokenizer.n_words
+        self.fname_tokenizer = base_path / tokenizer_file
+        self._name = (
+            "mistral-" + self.tokenizer_type.value + "-" + self.tokenizer.version
+        )
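+        # e.g. "mistral-tekken-v7" (the exact suffix depends on the tokenizer
+        # file version; illustrative only)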
+
+    @property
+    def tokenizer_name(self) -> str:
+        return self._name
+
+    @property
+    def gguf_tokenizer_model(self) -> str:
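+        # GGUF uses "llama" for SentencePiece-style vocabs and "gpt2" for byte-level BPE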
+        return "llama" if self.tokenizer_type == MistralTokenizerType.spm else "gpt2"
+
+    def _sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+        assert SentencePieceTokenizer is not None, "mistral_common is not installed"
+        assert isinstance(self.tokenizer, SentencePieceTokenizer), (
+            f"Expected SentencePieceTokenizer, got {type(self.tokenizer)}"
+        )
+
+        for i in range(self.tokenizer._model.vocab_size()):
+            piece = self.tokenizer._model.IdToPiece(i)
+            text = piece.encode("utf-8")
+            score: float = self.tokenizer._model.GetScore(i)
+
+            toktype = gguf.TokenType.NORMAL
+            if self.tokenizer._model.IsUnknown(i):
+                toktype = gguf.TokenType.UNKNOWN
+            if self.tokenizer._model.IsControl(i):
+                toktype = gguf.TokenType.CONTROL
+
+            if self.tokenizer._model.IsUnused(i):
+                toktype = gguf.TokenType.UNUSED
+            if self.tokenizer._model.IsByte(i):
+                toktype = gguf.TokenType.BYTE
+
+            yield text, score, toktype
+
+    def _tekken_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+        assert Tekkenizer is not None, "mistral_common is not installed"
+        assert isinstance(self.tokenizer, Tekkenizer), (
+            f"Expected Tekkenizer, got {type(self.tokenizer)}"
+        )
+
+        byte_encoder = bytes_to_unicode()
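+        # special tokens occupy the first `num_special_tokens` ids and are
+        # exported as CONTROL tokens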
+        for token_id in range(self.tokenizer.num_special_tokens):
+            yield (
+                self.tokenizer.id_to_piece(token_id).encode("utf-8"),
+                0,
+                gguf.TokenType.CONTROL,
+            )
+        for token in self.tokenizer._tekken_token2id_nospecial:
+            yield (
+                self.token_bytes_to_string(token, byte_encoder).encode("utf-8"),
+                0,
+                gguf.TokenType.NORMAL,
+            )
+
+    def get_token_id(self, token: str) -> int:
+        assert SentencePieceTokenizer is not None and Tekkenizer is not None, "mistral_common is not installed"
+        if self.tokenizer_type == MistralTokenizerType.spm:
+            assert isinstance(self.tokenizer, SentencePieceTokenizer)
+            return self.tokenizer._vocab.index(token)
+        elif self.tokenizer_type == MistralTokenizerType.tekken:
+            assert isinstance(self.tokenizer, Tekkenizer)
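+            # Tekken's _vocab excludes the special tokens, which occupy the
+            # first ids, so offset the index accordingly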
+            return (
+                self.tokenizer._vocab.index(token) + self.tokenizer.num_special_tokens
+            )
+        else:
+            raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")
+
+    @property
+    def bos_id(self) -> int:
+        return self.tokenizer.bos_id
+
+    @property
+    def eos_id(self) -> int:
+        return self.tokenizer.eos_id
+
+    @property
+    def pad_id(self) -> int:
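+        # a pad_id of -1 means the tokenizer defines no pad token; fall back to EOS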
+        if self.tokenizer.pad_id == -1:
+            return self.eos_id
+        return self.tokenizer.pad_id
+
+    @property
+    def unk_id(self) -> int:
+        return self.tokenizer.unk_id
+
+    @property
+    def bos_token(self) -> str:
+        return self.tokenizer.id_to_piece(self.tokenizer.bos_id)
+
+    @property
+    def eos_token(self) -> str:
+        return self.tokenizer.id_to_piece(self.tokenizer.eos_id)
+
+    @property
+    def pad_token(self) -> str:
+        return self.tokenizer.id_to_piece(self.tokenizer.pad_id)
+
+    @property
+    def unk_token(self) -> str:
+        return self.tokenizer.id_to_piece(self.tokenizer.unk_id)
+
+    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+        if self.tokenizer_type == MistralTokenizerType.spm:
+            yield from self._sentencepiece_tokens()
+
+        elif self.tokenizer_type == MistralTokenizerType.tekken:
+            yield from self._tekken_tokens()
+
+        else:
+            raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")
+
+    @staticmethod
+    def token_bytes_to_string(b, byte_encoder):
+        return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
+
+    def extract_vocab_merges_from_model(self):
+        # Adapted from Transformers (Apache 2.0)
+        # https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py
+        assert Tekkenizer is not None and isinstance(self.tokenizer, Tekkenizer), (
+            f"Expected Tekkenizer, got {type(self.tokenizer)}"
+        )
+        mergeable_ranks = self.tokenizer._model._mergeable_ranks
+        token_bytes_map = {
+            rank: token_bytes for token_bytes, rank in mergeable_ranks.items()
+        }
+        merge_pairs = []
+
+        # Recover merges by walking the vocab in rank order; ranks 0-255 are
+        # the raw byte tokens, and every later token must split into two
+        # lower-ranked tokens
+        for i in range(256, self.vocab_size - self.tokenizer.num_special_tokens):
+            merged_token = token_bytes_map[i]
+            local = []
+            for j in range(1, len(merged_token)):
+                left = merged_token[:j]
+                right = merged_token[j:]
+                if (
+                    left in mergeable_ranks
+                    and right in mergeable_ranks
+                    and (left + right) in mergeable_ranks
+                ):
+                    local.append((left, right, i))
+            if not local:
+                raise ValueError(
+                    f"Could not find valid merge for token at rank {i}: {merged_token.decode('latin-1')}"
+                )
+            local = sorted(
+                local,
+                key=lambda x: (mergeable_ranks[x[0]], mergeable_ranks[x[1]]),
+                reverse=False,
+            )
+            merge_pairs.extend(local)
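+        # lower-ranked (more frequent) merges must come first in the merge list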
+        merge_pairs = sorted(merge_pairs, key=lambda val: val[2], reverse=False)
+
+        byte_encoder = bytes_to_unicode()
+
+        decoded_merge_pairs = [
+            [
+                self.token_bytes_to_string(val[0], byte_encoder),
+                self.token_bytes_to_string(val[1], byte_encoder),
+            ]
+            for val in merge_pairs
+        ]
+
+        merges = [
+            " ".join(
+                [
+                    # merges are stored space-separated, so encode a literal
+                    # " " as "Ġ" (chr(0x20 + 256)) to keep the pair intact
+                    "".join(chr(ord(c) + 256) if c == " " else c for c in part)
+                    for part in pair
+                ]
+            )
+            for pair in decoded_merge_pairs
+        ]
+
+        return merges