from __future__ import annotations

from enum import Enum
import re
import logging
import json
import os
from pathlib import Path
from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVar, runtime_checkable

try:
    from sentencepiece import SentencePieceProcessor
except ImportError:
    SentencePieceProcessor = None

try:
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer  # pyright: ignore[reportMissingImports]
    from mistral_common.tokens.tokenizers.tekken import Tekkenizer  # pyright: ignore[reportMissingImports]
    from mistral_common.tokens.tokenizers.utils import (  # pyright: ignore[reportMissingImports]
        _filter_valid_tokenizer_files,
    )
    from mistral_common.tokens.tokenizers.sentencepiece import (  # pyright: ignore[reportMissingImports]
        SentencePieceTokenizer,
    )
except ImportError:
    _mistral_common_installed = False
    MistralTokenizer = None
    Tekkenizer = None
    SentencePieceTokenizer = None
    _filter_valid_tokenizer_files = None
else:
    _mistral_common_installed = True

import gguf

from .gguf_writer import GGUFWriter

logger = logging.getLogger(__name__)


class SpecialVocab:
    merges: list[str]
    add_special_token: dict[str, bool]
    special_token_ids: dict[str, int]
    chat_template: str | Sequence[Mapping[str, str]] | None

    def __init__(
        self, path: str | os.PathLike[str], load_merges: bool = False,
        special_token_types: Iterable[str] | None = None,
        n_vocab: int | None = None,
    ):
        self.special_token_ids = {}
        self.add_special_token = {}
        self.n_vocab = n_vocab
        self.load_merges = load_merges
        self.merges = []
        self.chat_template = None
        if special_token_types is not None:
            self.special_token_types = special_token_types
        else:
            self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad', 'cls', 'mask')
        self._load(Path(path))

    def __repr__(self) -> str:
        return '<SpecialVocab with {} merges, special tokens {}, add special tokens {}>'.format(
            len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset",
        )

    def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
        if self.merges:
            if not quiet:
                logger.info(f'Adding {len(self.merges)} merge(s).')
            gw.add_token_merges(self.merges)
        elif self.load_merges:
            logger.warning('Adding merges requested but no merges found, output may be non-functional.')
        for typ, tokid in self.special_token_ids.items():
            id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
            if id_handler is None:
                logger.warning(f'No handler for special token type {typ} with id {tokid} - skipping')
                continue
            if not quiet:
                logger.info(f'Setting special token type {typ} to {tokid}')
            id_handler(tokid)
        for typ, value in self.add_special_token.items():
            add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None)
            if add_handler is None:
                logger.warning(f'No handler for add_{typ}_token with value {value} - skipping')
                continue
            if not quiet:
                logger.info(f'Setting add_{typ}_token to {value}')
            add_handler(value)
        if self.chat_template is not None:
            if not quiet:
                logger.info(f'Setting chat_template to {self.chat_template}')
            gw.add_chat_template(self.chat_template)

    def _load(self, path: Path) -> None:
        self._try_load_from_tokenizer_json(path)
        self._try_load_from_config_json(path)
        if self.load_merges and not self.merges:
            self._try_load_merges_txt(path)

    def _try_load_merges_txt(self, path: Path) -> bool:
        merges_file = path / 'merges.txt'
        if not merges_file.is_file():
            return False
        with open(merges_file, 'r', encoding = 'utf-8') as fp:
            first_line = next(fp, '').strip()
            if not first_line.startswith('#'):
                fp.seek(0)
                line_num = 0
            else:
                line_num = 1
            merges = []
            for line in fp:
                line_num += 1
                line = line.strip()
                if not line:
                    continue
                parts = line.split(None, 3)
                if len(parts) != 2:
                    logger.warning(f'{merges_file.name}: Line {line_num}: Entry malformed, ignoring')
                    continue
                merges.append(f'{parts[0]} {parts[1]}')
        self.merges = merges
        return True
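
    # For reference, merges.txt as consumed above is the GPT-2/RoBERTa "slow tokenizer" format:
    # an optional "#version: ..." header line, then one merge per line written as two
    # whitespace-separated parts. A hypothetical excerpt:
    #
    #   #version: 0.2
    #   Ġ t
    #   h e
    #   Ġ a
    #
    # Each entry is re-joined as '<left> <right>' and collected into self.merges.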

    def _set_special_token(self, typ: str, tid: Any) -> None:
        if not isinstance(tid, int):
            return
        if tid < 0:
            raise ValueError(f'invalid value for special token type {typ}: {tid}')
        if self.n_vocab is None or tid < self.n_vocab:
            if typ in self.special_token_ids:
                return
            self.special_token_ids[typ] = tid
            return
        logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')

    def _try_load_from_tokenizer_json(self, path: Path) -> bool:
        tokenizer = None
        tokenizer_file = path / 'tokenizer.json'
        if tokenizer_file.is_file():
            with open(tokenizer_file, encoding = 'utf-8') as f:
                tokenizer = json.load(f)
            if self.load_merges:
                merges = tokenizer.get('model', {}).get('merges')
                if isinstance(merges, list) and merges:
                    if isinstance(merges[0], str):
                        self.merges = merges
                    elif isinstance(merges[0], list) and len(merges[0]) == 2 and isinstance(merges[0][0], str):
                        # New format since transformers 4.45 to support spaces in merges
                        # ref: https://github.com/ggml-org/llama.cpp/issues/9692
                        # TODO: internally store as the new format instead of converting to old
                        if any(' ' in s for pair in merges for s in pair):
                            logger.warning(f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}')
                        self.merges = [
                            ' '.join(
                                [
                                    # ensure the spaces are properly encoded
                                    ''.join(
                                        chr(ord(c) + 256) if c == ' ' else c
                                        for c in part
                                    )
                                    for part in pair
                                ]
                            )
                            for pair in merges
                        ]
                    else:
                        raise ValueError("Unknown tokenizer merges format")
            added_tokens = tokenizer.get('added_tokens', {})
        else:
            added_tokens = {}

        tokenizer_config = None
        tokenizer_config_file = path / 'tokenizer_config.json'
        if tokenizer_config_file.is_file():
            with open(tokenizer_config_file, encoding = 'utf-8') as f:
                tokenizer_config = json.load(f)

        if tokenizer:
            special_bos = (tokenizer_config or {}).get('bos_token')
            special_cls = (tokenizer_config or {}).get('cls_token')
            special_eos = (tokenizer_config or {}).get('eos_token')
            special_sep = (tokenizer_config or {}).get('sep_token')
            if not special_bos and special_cls and tokenizer_config:
                tokenizer_config['bos_token'] = special_bos = special_cls
            if not special_eos and special_sep and tokenizer_config:
                tokenizer_config['eos_token'] = special_eos = special_sep
            if post_processor := tokenizer.get('post_processor'):
                for processor in post_processor.get('processors', [post_processor]):
                    if processor.get('type') == 'RobertaProcessing':
                        self.add_special_token['bos'] = True
                        self.add_special_token['eos'] = True
                        self.add_special_token['sep'] = True
                        if not special_cls and tokenizer_config:
                            special_cls = processor.get('cls', [special_bos])[0]
                            tokenizer_config['cls_token'] = special_cls
                        if not special_sep and tokenizer_config:
                            special_sep = processor.get('sep', [special_eos])[0]
                            tokenizer_config['sep_token'] = special_sep
                        continue
                    # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
                    # Only works with simple templates, **will** get it wrong on unusual sequences
                    if processor.get('type') == 'TemplateProcessing':
                        tmpl_single = processor.get('single', [])
                        tmpl_pair = processor.get('pair', [])
                        special_first = None
                        special_last = None
                        if len(tmpl_single) > 1:
                            if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
                                if not tokenizer_config:
                                    special_bos = special_first
                                self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
                                if special_first not in (special_bos, special_cls):
                                    logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
                            if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
                                if not tokenizer_config:
                                    special_eos = special_last
                                elif special_last != special_eos:
                                    if 'eot' not in self.special_token_types:
                                        self.special_token_types = tuple(self.special_token_types) + ('eot', )
                                        tokenizer_config['eot_token'] = special_eos
                                    elif 'eom' not in self.special_token_types:
                                        self.special_token_types = tuple(self.special_token_types) + ('eom', )
                                        tokenizer_config['eom_token'] = special_eos
                                    else:
                                        logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
                                    tokenizer_config['eos_token'] = special_eos = special_last
                                self.add_special_token['eos'] = True if special_last == special_eos else False
                                if special_last != special_eos:
                                    logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
                        if tmpl_pair:
                            seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
                            seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
                            if (special_first and seq_start == 0) or (special_last and seq_stop is None):
                                logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
                            if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
                                tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
                                tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
                                if tmpl_a != 'A' or tmpl_b != 'B':
                                    logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
                                # A [sep] [eos] B
                                if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
                                    add_sep = False
                                    if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
                                        if special_entry in (special_sep, special_eos) and not special_last:
                                            add_sep = True
                                        if special_entry not in (special_sep, special_eos):
                                            logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
                                    else:
                                        logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
                                    if len(tmpl_pair) == 2:
                                        if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
                                            if special_entry in (special_sep, special_eos):
                                                add_sep = True
                                            if special_entry not in (special_sep, special_eos):
                                                logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
                                        else:
                                            logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
                                    self.add_special_token['sep'] = add_sep
                                    if add_sep and not special_sep and tokenizer_config:
                                        tokenizer_config['sep_token'] = special_eos
                        continue

        if not tokenizer_config:
            return True

        chat_template_alt = None
        chat_template_json = path / 'chat_template.json'
        chat_template_jinja = path / 'chat_template.jinja'
        if chat_template_jinja.is_file():
            with open(chat_template_jinja, encoding = 'utf-8') as f:
                chat_template_alt = f.read()
            if additional_templates := list((path / 'additional_chat_templates').glob('*.jinja')):
                chat_template_alt = [{'name': 'default', 'template': chat_template_alt}]
                for template_path in additional_templates:
                    with open(template_path, encoding = 'utf-8') as fp:
                        chat_template_alt.append({'name': template_path.stem, 'template': fp.read()})
        elif chat_template_json.is_file():
            with open(chat_template_json, encoding = 'utf-8') as f:
                chat_template_alt = json.load(f).get('chat_template')
        chat_template = tokenizer_config.get('chat_template', chat_template_alt)
        if chat_template is None or isinstance(chat_template, (str, list)):
            self.chat_template = chat_template
        else:
            logger.warning(f'Bad type for chat_template field in {tokenizer_config_file!r} - ignoring')

        for typ in self.special_token_types:
            add_entry = tokenizer_config.get(f'add_{typ}_token')
            if isinstance(add_entry, bool):
                self.add_special_token[typ] = add_entry
            entry = tokenizer_config.get(f'{typ}_token')
            if isinstance(entry, str):
                tc_content = entry
            elif isinstance(entry, dict):
                entry_content = entry.get('content')
                if not isinstance(entry_content, str):
                    continue
                tc_content = entry_content
            else:
                continue
            # We only need the first match here.
            maybe_token_id = next(
                (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content),
                None,
            )
            self._set_special_token(typ, maybe_token_id)
        return True

    def _try_load_from_config_json(self, path: Path) -> bool:
        config_file = path / 'config.json'
        if not config_file.is_file():
            return False
        with open(config_file, encoding = 'utf-8') as f:
            config = json.load(f)
        for typ in self.special_token_types:
            token_id = config.get(f'{typ}_token_id')
            # If not found at root, check in text_config (for multimodal models like Kimi-VL)
            if token_id is None and 'text_config' in config:
                token_id = config['text_config'].get(f'{typ}_token_id')
            self._set_special_token(typ, token_id)
        return True
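

# A minimal usage sketch for SpecialVocab (the model directory and vocab size here are
# hypothetical; `gw` is a GGUFWriter being populated elsewhere):
#
#   sv = SpecialVocab('path/to/hf-model', load_merges=True, n_vocab=32000)
#   sv.add_to_gguf(gw)  # writes merges, special token IDs, add_*_token flags, chat template
#
# SpecialVocab only records tokenizer metadata; the token list itself comes from one of the
# Vocab implementations below.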


@runtime_checkable
class BaseVocab(Protocol):
    tokenizer_model: ClassVar[str]
    name: ClassVar[str]


@runtime_checkable
class Vocab(BaseVocab, Protocol):
    vocab_size: int
    added_tokens_dict: dict[str, int]
    added_tokens_list: list[str]
    fname_tokenizer: Path

    def __init__(self, base_path: Path): ...
    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ...
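

# Every concrete vocabulary below satisfies the Vocab protocol, so conversion code can stay
# generic over the tokenizer flavour. An illustrative (not prescriptive) consumer:
#
#   def write_vocab(gw: GGUFWriter, vocab: Vocab) -> None:
#       tokens, scores, toktypes = [], [], []
#       for text, score, toktype in vocab.all_tokens():
#           tokens.append(text)
#           scores.append(score)
#           toktypes.append(toktype)
#       gw.add_tokenizer_model(vocab.tokenizer_model)
#       gw.add_token_list(tokens)
#       gw.add_token_scores(scores)
#       gw.add_token_types(toktypes)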


class NoVocab(BaseVocab):
    tokenizer_model = "no_vocab"
    name = "no_vocab"

    def __repr__(self) -> str:
        return "<NoVocab for a model without integrated vocabulary>"


class BpeVocab(Vocab):
    tokenizer_model = "gpt2"
    name = "bpe"

    def __init__(self, base_path: Path):
        added_tokens: dict[str, int] = {}

        if (fname_tokenizer := base_path / 'vocab.json').exists():
            # "slow" tokenizer
            with open(fname_tokenizer, encoding="utf-8") as f:
                self.vocab = json.load(f)

            try:
                # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
                with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
                    added_tokens = json.load(f)
            except FileNotFoundError:
                pass
        else:
            # "fast" tokenizer
            fname_tokenizer = base_path / 'tokenizer.json'

            # if this fails, FileNotFoundError propagates to caller
            with open(fname_tokenizer, encoding="utf-8") as f:
                tokenizer_json = json.load(f)

            tokenizer_model: dict[str, Any] = tokenizer_json['model']
            if (
                tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False)
                or tokenizer_json['decoder']['type'] != 'ByteLevel'
            ):
                raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer')

            self.vocab = tokenizer_model["vocab"]

            if (added := tokenizer_json.get('added_tokens')) is not None:
                # Added tokens here can be duplicates of the main vocabulary.
                added_tokens = {item['content']: item['id']
                                for item in added
                                if item['content'] not in self.vocab}

        vocab_size = len(self.vocab)
        expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
        actual_ids = sorted(added_tokens.values())
        if expected_ids != actual_ids:
            expected_end_id = vocab_size + len(actual_ids) - 1
            raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range "
                             f"{vocab_size} - {expected_end_id}; got {actual_ids}")

        items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
        self.added_tokens_dict = added_tokens
        self.added_tokens_list = [text for (text, idx) in items]
        self.vocab_size_base = vocab_size
        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
        self.fname_tokenizer = fname_tokenizer

    def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()}

        for i, _ in enumerate(self.vocab):
            yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        for text in self.added_tokens_list:
            score = -1000.0
            yield text.encode("utf-8"), score, gguf.TokenType.CONTROL

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        yield from self.bpe_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
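

# Example use of BpeVocab (the path is hypothetical):
#
#   vocab = BpeVocab(Path('path/to/hf-model'))
#   print(vocab)  # <BpeVocab with N base tokens and M added tokens>
#   text, score, toktype = next(iter(vocab.all_tokens()))
#
# GPT-2 style BPE carries no per-token scores, which is why bpe_tokens() always yields 0.0
# and added tokens get the sentinel score of -1000.0.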


class SentencePieceVocab(Vocab):
    tokenizer_model = "llama"
    name = "spm"

    def __init__(self, base_path: Path):
        if SentencePieceProcessor is None:
            raise RuntimeError("sentencepiece is not installed")
        added_tokens: dict[str, int] = {}
        if (fname_tokenizer := base_path / 'tokenizer.model').exists():
            # normal location
            try:
                with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
                    added_tokens = json.load(f)
            except FileNotFoundError:
                pass
        elif not (fname_tokenizer := base_path.parent / 'tokenizer.model').exists():
            # not found in alternate location either
            raise FileNotFoundError('Cannot find tokenizer.model')

        self.sentencepiece_tokenizer = SentencePieceProcessor()
        self.sentencepiece_tokenizer.LoadFromFile(str(fname_tokenizer))
        vocab_size = self.sentencepiece_tokenizer.vocab_size()

        new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
        expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
        actual_new_ids = sorted(new_tokens.keys())

        if expected_new_ids != actual_new_ids:
            raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")

        # Token pieces that were added to the base vocabulary.
        self.added_tokens_dict = added_tokens
        self.added_tokens_list = [new_tokens[id] for id in actual_new_ids]
        self.vocab_size_base = vocab_size
        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
        self.fname_tokenizer = fname_tokenizer

    def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        tokenizer = self.sentencepiece_tokenizer
        for i in range(tokenizer.vocab_size()):
            piece = tokenizer.IdToPiece(i)
            text = piece.encode("utf-8")
            score: float = tokenizer.GetScore(i)

            toktype = gguf.TokenType.NORMAL
            if tokenizer.IsUnknown(i):
                toktype = gguf.TokenType.UNKNOWN
            if tokenizer.IsControl(i):
                toktype = gguf.TokenType.CONTROL

            # NOTE: I think added_tokens are user defined.
            # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
            # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED

            if tokenizer.IsUnused(i):
                toktype = gguf.TokenType.UNUSED
            if tokenizer.IsByte(i):
                toktype = gguf.TokenType.BYTE

            yield text, score, toktype

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        for text in self.added_tokens_list:
            score = -1000.0
            yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        yield from self.sentencepiece_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
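

# The added_tokens.json consumed by SentencePieceVocab is a flat piece-to-id mapping, e.g.
# (hypothetical contents):
#
#   {"<|im_start|>": 32000, "<|im_end|>": 32001}
#
# IDs must continue directly after the base SentencePiece vocabulary; anything else trips the
# sequential-ID check in __init__ and raises a ValueError.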


class LlamaHfVocab(Vocab):
    tokenizer_model = "llama"
    name = "hfft"

    def __init__(self, base_path: Path):
        fname_tokenizer = base_path / 'tokenizer.json'
        # if this fails, FileNotFoundError propagates to caller
        with open(fname_tokenizer, encoding='utf-8') as f:
            tokenizer_json = json.load(f)

        # pre-check so we know if we need transformers
        tokenizer_model: dict[str, Any] = tokenizer_json['model']
        is_llama3 = (
            tokenizer_model['type'] == 'BPE' and tokenizer_model.get('ignore_merges', False)
            and not tokenizer_model.get('byte_fallback', True)
        )
        if is_llama3:
            raise TypeError('Llama 3 must be converted with BpeVocab')

        if not is_llama3 and (
            tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
            or tokenizer_json['decoder']['type'] != 'Sequence'
        ):
            raise FileNotFoundError('Cannot find Llama BPE tokenizer')

        try:
            from transformers import AutoTokenizer
        except ImportError as e:
            raise ImportError(
                "To use LlamaHfVocab, please install the `transformers` package. "
                "You can install it with `pip install transformers`."
            ) from e

        # Allow the tokenizer to default to slow or fast versions.
        # Explicitly set tokenizer to use local paths.
        self.tokenizer = AutoTokenizer.from_pretrained(
            base_path,
            cache_dir=base_path,
            local_files_only=True,
        )
        assert self.tokenizer.is_fast  # assume tokenizer.json is used

        # Initialize lists and dictionaries for added tokens
        self.added_tokens_list = []
        self.added_tokens_dict = dict()
        self.added_tokens_ids = set()

        # Process added tokens
        for tok, tokidx in sorted(
            self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
        ):
            # Only consider added tokens that are not in the base vocabulary
            if tokidx >= self.tokenizer.vocab_size:
                self.added_tokens_list.append(tok)
                self.added_tokens_dict[tok] = tokidx
                self.added_tokens_ids.add(tokidx)

        # Store special tokens and their IDs
        self.specials = {
            tok: self.tokenizer.get_vocab()[tok]
            for tok in self.tokenizer.all_special_tokens
        }
        self.special_ids = set(self.tokenizer.all_special_ids)

        # Set vocabulary sizes
        self.vocab_size_base = self.tokenizer.vocab_size
        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)

        self.fname_tokenizer = fname_tokenizer

    def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        reverse_vocab = {
            id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
        }

        for token_id in range(self.vocab_size_base):
            # Skip processing added tokens here
            if token_id in self.added_tokens_ids:
                continue

            # Convert token text to bytes
            token_text = reverse_vocab[token_id].encode("utf-8")

            # Yield token text, score, and type
            yield token_text, self.get_token_score(token_id), self.get_token_type(
                token_id, token_text, self.special_ids  # Reuse already stored special IDs
            )

    def get_token_type(self, token_id: int, token_text: bytes, special_ids: set[int]) -> gguf.TokenType:
        # Special case for byte tokens
        if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
            return gguf.TokenType.BYTE

        # Determine token type based on whether it's a special token
        return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL

    def get_token_score(self, token_id: int) -> float:
        # Placeholder for actual logic to determine the token's score
        # This needs to be implemented based on specific requirements
        return -1000.0  # Default score

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        for text in self.added_tokens_list:
            if text in self.specials:
                toktype = self.get_token_type(self.specials[text], b'', self.special_ids)
                score = self.get_token_score(self.specials[text])
            else:
                toktype = gguf.TokenType.USER_DEFINED
                score = -1000.0

            yield text.encode("utf-8"), score, toktype

    def has_newline_token(self):
        return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        yield from self.hf_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
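

# Byte tokens in Llama-style vocabularies are spelled "<0xNN>", e.g. "<0x0A>" for the newline
# byte. get_token_type() above recognises that pattern with a regex and classifies such
# entries as gguf.TokenType.BYTE, and has_newline_token() probes for the same spelling.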


class MistralTokenizerType(str, Enum):
    spm = "spm"
    tekken = "tekken"


# Copied from Transformers (Apache 2.0)
# https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py#L1544
def bytes_to_unicode() -> dict[int, str]:
    """
    Returns a mapping from utf-8 bytes to unicode strings. We specifically avoid mapping to whitespace/control
    characters that the bpe code barfs on.

    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
    tables between utf-8 bytes and unicode strings.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs_str = [chr(n) for n in cs]
    return dict(zip(bs, cs_str))
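

# For example, printable ASCII maps to itself while bytes outside the ranges above are
# shifted into the U+0100 block:
#
#   bytes_to_unicode()[ord('a')] == 'a'
#   bytes_to_unicode()[ord(' ')] == '\u0120'  # 'Ġ', the familiar GPT-2 space marker
#   bytes_to_unicode()[0x0A]     == '\u010a'  # 'Ċ', the newline marker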


class MistralVocab(Vocab):
    tokenizer_model = "mistral"
    name = "mistral"

    added_tokens_dict: dict[str, int] = {}
    added_tokens_list: list[str] = []

    def __init__(self, base_path: Path):
        if not _mistral_common_installed:
            raise ImportError(
                "To use MistralVocab, please install the `mistral-common` package. "
                "You can install it with `pip install mistral-common`."
            )
        assert _filter_valid_tokenizer_files is not None, "mistral_common is not installed"
        assert MistralTokenizer is not None, "mistral_common is not installed"
        assert Tekkenizer is not None, "mistral_common is not installed"

        logger.info(f"Loading Mistral tokenizer from {base_path}")

        # Find the tokenizer files
        all_files = [f.as_posix() for f in base_path.glob("**/*") if f.is_file()]
        valid_tokenizer_files = _filter_valid_tokenizer_files(all_files)

        if len(valid_tokenizer_files) == 0:
            raise ValueError(f"No tokenizer file found in the directory: {base_path}")
        # If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one.
        if len(valid_tokenizer_files) > 1:
            if "tekken.json" in valid_tokenizer_files:
                tokenizer_file = "tekken.json"
            else:
                tokenizer_file = sorted(valid_tokenizer_files)[-1]
            logger.warning(
                f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}"
            )
        else:
            tokenizer_file = valid_tokenizer_files[0]

        self.tokenizer = MistralTokenizer.from_file(
            base_path / tokenizer_file
        ).instruct_tokenizer.tokenizer
        self.tokenizer_type = (
            MistralTokenizerType.tekken
            if isinstance(self.tokenizer, Tekkenizer)
            else MistralTokenizerType.spm
        )
        self.vocab_size = self.tokenizer.n_words
        self.fname_tokenizer = base_path / tokenizer_file
        self._name = (
            "mistral-" + self.tokenizer_type.value + "-" + self.tokenizer.version
        )

    @property
    def tokenizer_name(self) -> str:
        return self._name

    @property
    def gguf_tokenizer_model(self) -> str:
        return "llama" if self.tokenizer_type == MistralTokenizerType.spm else "gpt2"

    def _sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        assert SentencePieceTokenizer is not None, "mistral_common is not installed"
        assert isinstance(self.tokenizer, SentencePieceTokenizer), (
            f"Expected SentencePieceTokenizer, got {type(self.tokenizer)}"
        )
        for i in range(self.tokenizer._model.vocab_size()):
            piece = self.tokenizer._model.IdToPiece(i)
            text = piece.encode("utf-8")
            score: float = self.tokenizer._model.GetScore(i)

            toktype = gguf.TokenType.NORMAL
            if self.tokenizer._model.IsUnknown(i):
                toktype = gguf.TokenType.UNKNOWN
            if self.tokenizer._model.IsControl(i):
                toktype = gguf.TokenType.CONTROL

            if self.tokenizer._model.IsUnused(i):
                toktype = gguf.TokenType.UNUSED
            if self.tokenizer._model.IsByte(i):
                toktype = gguf.TokenType.BYTE

            yield text, score, toktype

    def _tekken_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        assert Tekkenizer is not None, "mistral_common is not installed"
        assert isinstance(self.tokenizer, Tekkenizer), (
            f"Expected Tekkenizer, got {type(self.tokenizer)}"
        )
        byte_encoder = bytes_to_unicode()
        for token_id in range(self.tokenizer.num_special_tokens):
            yield (
                self.tokenizer.id_to_piece(token_id).encode("utf-8"),
                0,
                gguf.TokenType.CONTROL
            )
        for token in self.tokenizer._tekken_token2id_nospecial:
            yield (
                self.token_bytes_to_string(token, byte_encoder).encode("utf-8"),
                0,
                gguf.TokenType.NORMAL,
            )

    def get_token_id(self, token: str) -> int:
        assert SentencePieceTokenizer is not None and Tekkenizer is not None, "mistral_common is not installed"
        if self.tokenizer_type == MistralTokenizerType.spm:
            assert isinstance(self.tokenizer, SentencePieceTokenizer)
            return self.tokenizer._vocab.index(token)
        elif self.tokenizer_type == MistralTokenizerType.tekken:
            assert isinstance(self.tokenizer, Tekkenizer)
            return (
                self.tokenizer._vocab.index(token) + self.tokenizer.num_special_tokens
            )
        else:
            raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")

    @property
    def bos_id(self) -> int:
        return self.tokenizer.bos_id

    @property
    def eos_id(self) -> int:
        return self.tokenizer.eos_id

    @property
    def pad_id(self) -> int:
        if self.tokenizer.pad_id == -1:
            return self.eos_id
        return self.tokenizer.pad_id

    @property
    def unk_id(self) -> int:
        return self.tokenizer.unk_id

    @property
    def bos_token(self) -> str:
        return self.tokenizer.id_to_piece(self.tokenizer.bos_id)

    @property
    def eos_token(self) -> str:
        return self.tokenizer.id_to_piece(self.tokenizer.eos_id)

    @property
    def pad_token(self) -> str:
        return self.tokenizer.id_to_piece(self.tokenizer.pad_id)

    @property
    def unk_token(self) -> str:
        return self.tokenizer.id_to_piece(self.tokenizer.unk_id)

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        if self.tokenizer_type == MistralTokenizerType.spm:
            yield from self._sentencepiece_tokens()
        elif self.tokenizer_type == MistralTokenizerType.tekken:
            yield from self._tekken_tokens()
        else:
            raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")

    @staticmethod
    def token_bytes_to_string(b, byte_encoder):
        return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])

    def extract_vocab_merges_from_model(self):
        # Adapted from Transformers (Apache 2.0)
        # https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py
        assert Tekkenizer is not None and isinstance(self.tokenizer, Tekkenizer), (
            f"Expected Tekkenizer, got {type(self.tokenizer)}"
        )
        mergeable_ranks = self.tokenizer._model._mergeable_ranks
        token_bytes_map = {
            rank: token_bytes for token_bytes, rank in mergeable_ranks.items()
        }
        merge_pairs = []

        # Sort vocab by rank to ensure correct merge order
        for i in range(256, self.vocab_size - self.tokenizer.num_special_tokens):
            merged_token = token_bytes_map[i]
            local = []
            for j in range(1, len(merged_token)):
                left = merged_token[:j]
                right = merged_token[j:]
                if (
                    left in mergeable_ranks
                    and right in mergeable_ranks
                    and (left + right) in mergeable_ranks
                ):
                    local.append((left, right, i))
            if not local:
                raise ValueError(
                    f"Could not find valid merge for token at rank {i}: {merged_token.decode('latin-1')}"
                )
            local = sorted(
                local,
                key=lambda x: (mergeable_ranks[x[0]], mergeable_ranks[x[1]]),
                reverse=False,
            )
            merge_pairs.extend(local)
        merge_pairs = sorted(merge_pairs, key=lambda val: val[2], reverse=False)

        byte_encoder = bytes_to_unicode()

        decoded_merge_pairs = [
            [
                self.token_bytes_to_string(val[0], byte_encoder),
                self.token_bytes_to_string(val[1], byte_encoder),
            ]
            for val in merge_pairs
        ]

        merges = [
            " ".join(
                [
                    # ensure the spaces are properly encoded
                    "".join(chr(ord(c) + 256) if c == " " else c for c in part)
                    for part in pair
                ]
            )
            for pair in decoded_merge_pairs
        ]

        return merges
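

# The merges returned above use the same encoding convention as SpecialVocab: the two halves
# of each pair are joined by a real space, and any literal space remaining inside a half is
# remapped to chr(ord(' ') + 256). A hypothetical entry merging 'Ġ' and 't' into 'Ġt' would
# therefore appear as the string 'Ġ t'.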