# vocab.py

from __future__ import annotations

from enum import Enum
import re
import logging
import json
import os
from pathlib import Path
from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVar, runtime_checkable

try:
    from sentencepiece import SentencePieceProcessor
except ImportError:
    SentencePieceProcessor = None

try:
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer  # pyright: ignore[reportMissingImports]
    from mistral_common.tokens.tokenizers.tekken import Tekkenizer  # pyright: ignore[reportMissingImports]
    from mistral_common.tokens.tokenizers.utils import (  # pyright: ignore[reportMissingImports]
        _filter_valid_tokenizer_files,
    )
    from mistral_common.tokens.tokenizers.sentencepiece import (  # pyright: ignore[reportMissingImports]
        SentencePieceTokenizer,
    )
except ImportError:
    _mistral_common_installed = False
    MistralTokenizer = None
    Tekkenizer = None
    SentencePieceTokenizer = None
    _filter_valid_tokenizer_files = None
else:
    _mistral_common_installed = True
    try:
        from mistral_common.tokens.tokenizers.utils import (  # pyright: ignore[reportMissingImports]
            get_one_valid_tokenizer_file,
        )
    except ImportError:
        # We still want the conversion to work with older mistral-common versions.
        get_one_valid_tokenizer_file = None

import gguf

from .gguf_writer import GGUFWriter

logger = logging.getLogger(__name__)


class SpecialVocab:
    merges: list[str]
    add_special_token: dict[str, bool]
    special_token_ids: dict[str, int]
    chat_template: str | Sequence[Mapping[str, str]] | None

    def __init__(
        self, path: str | os.PathLike[str], load_merges: bool = False,
        special_token_types: Iterable[str] | None = None,
        n_vocab: int | None = None,
    ):
        self.special_token_ids = {}
        self.add_special_token = {}
        self.n_vocab = n_vocab
        self.load_merges = load_merges
        self.merges = []
        self.chat_template = None
        if special_token_types is not None:
            self.special_token_types = special_token_types
        else:
            self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad', 'cls', 'mask')
        self._load(Path(path))

    def __repr__(self) -> str:
        return '<SpecialVocab with {} merges, special tokens {}, add special tokens {}>'.format(
            len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset",
        )

    def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
        if self.merges:
            if not quiet:
                logger.info(f'Adding {len(self.merges)} merge(s).')
            gw.add_token_merges(self.merges)
        elif self.load_merges:
            logger.warning('Adding merges requested but no merges found, output may be non-functional.')
        for typ, tokid in self.special_token_ids.items():
            id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
            if id_handler is None:
                logger.warning(f'No handler for special token type {typ} with id {tokid} - skipping')
                continue
            if not quiet:
                logger.info(f'Setting special token type {typ} to {tokid}')
            id_handler(tokid)
        for typ, value in self.add_special_token.items():
            add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None)
            if add_handler is None:
                logger.warning(f'No handler for add_{typ}_token with value {value} - skipping')
                continue
            if not quiet:
                logger.info(f'Setting add_{typ}_token to {value}')
            add_handler(value)
        if self.chat_template is not None:
            if not quiet:
                logger.info(f'Setting chat_template to {self.chat_template}')
            gw.add_chat_template(self.chat_template)
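    # Note: add_to_gguf() resolves writer methods dynamically, so a special token
    # type 'bos' ends up calling gw.add_bos_token_id(id) and the corresponding
    # add_bos_token flag calls gw.add_add_bos_token(value); token types without a
    # matching GGUFWriter setter are skipped with a warning rather than raising.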
    def _load(self, path: Path) -> None:
        self._try_load_from_tokenizer_json(path)
        self._try_load_from_config_json(path)
        if self.load_merges and not self.merges:
            self._try_load_merges_txt(path)

    def _try_load_merges_txt(self, path: Path) -> bool:
        merges_file = path / 'merges.txt'
        if not merges_file.is_file():
            return False
        with open(merges_file, 'r', encoding = 'utf-8') as fp:
            first_line = next(fp, '').strip()
            if not first_line.startswith('#'):
                fp.seek(0)
                line_num = 0
            else:
                line_num = 1
            merges = []
            for line in fp:
                line_num += 1
                line = line.strip()
                if not line:
                    continue
                parts = line.split(None, 3)
                if len(parts) != 2:
                    logger.warning(f'{merges_file.name}: Line {line_num}: Entry malformed, ignoring')
                    continue
                merges.append(f'{parts[0]} {parts[1]}')
        self.merges = merges
        return True
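    # Illustrative only (not taken from a specific model): a GPT-2 style merges.txt
    # usually looks like
    #     #version: 0.2
    #     Ġ t
    #     h e
    # i.e. an optional '#'-prefixed header followed by one space-separated merge pair
    # per line, which _try_load_merges_txt() re-joins into the 'left right' strings
    # stored in self.merges.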
    def _set_special_token(self, typ: str, tid: Any) -> None:
        if not isinstance(tid, int):
            return
        if tid < 0:
            raise ValueError(f'invalid value for special token type {typ}: {tid}')
        if self.n_vocab is None or tid < self.n_vocab:
            if typ in self.special_token_ids:
                return
            self.special_token_ids[typ] = tid
            return
        logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')

    def _try_load_from_tokenizer_json(self, path: Path) -> bool:
        tokenizer = None
        tokenizer_file = path / 'tokenizer.json'
        if tokenizer_file.is_file():
            with open(tokenizer_file, encoding = 'utf-8') as f:
                tokenizer = json.load(f)
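            # tokenizer.json stores merges either as flat strings ("Ġ t") or, since
            # transformers 4.45, as two-element lists (["Ġ", "t"]); both shapes are
            # handled below. The example values here are illustrative only.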
            if self.load_merges:
                merges = tokenizer.get('model', {}).get('merges')
                if isinstance(merges, list) and merges:
                    if isinstance(merges[0], str):
                        self.merges = merges
                    elif isinstance(merges[0], list) and len(merges[0]) == 2 and isinstance(merges[0][0], str):
                        # New format since transformers 4.45 to support spaces in merges
                        # ref: https://github.com/ggml-org/llama.cpp/issues/9692
                        # TODO: internally store as the new format instead of converting to old
                        if any(' ' in s for pair in merges for s in pair):
                            logger.warning(f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}')
                        self.merges = [
                            ' '.join(
                                [
                                    # ensure the spaces are properly encoded
                                    ''.join(
                                        chr(ord(c) + 256) if c == ' ' else c
                                        for c in part
                                    )
                                    for part in pair
                                ]
                            )
                            for pair in merges
                        ]
                    else:
                        raise ValueError("Unknown tokenizer merges format")
            added_tokens = tokenizer.get('added_tokens', {})
        else:
            added_tokens = {}
        tokenizer_config = None
        tokenizer_config_file = path / 'tokenizer_config.json'
        if tokenizer_config_file.is_file():
            with open(tokenizer_config_file, encoding = 'utf-8') as f:
                tokenizer_config = json.load(f)
        if tokenizer:
            special_bos = (tokenizer_config or {}).get('bos_token')
            special_cls = (tokenizer_config or {}).get('cls_token')
            special_eos = (tokenizer_config or {}).get('eos_token')
            special_sep = (tokenizer_config or {}).get('sep_token')
            if not special_bos and special_cls and tokenizer_config:
                tokenizer_config['bos_token'] = special_bos = special_cls
            if not special_eos and special_sep and tokenizer_config:
                tokenizer_config['eos_token'] = special_eos = special_sep
            if post_processor := tokenizer.get('post_processor'):
                for processor in post_processor.get('processors', [post_processor]):
                    if processor.get('type') == 'RobertaProcessing':
                        self.add_special_token['bos'] = True
                        self.add_special_token['eos'] = True
                        self.add_special_token['sep'] = True
                        if not special_cls and tokenizer_config:
                            special_cls = processor.get('cls', [special_bos])[0]
                            tokenizer_config['cls_token'] = special_cls
                        if not special_sep and tokenizer_config:
                            special_sep = processor.get('sep', [special_eos])[0]
                            tokenizer_config['sep_token'] = special_sep
                        continue
                    # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
                    # Only works with simple templates, **will** get it wrong on unusual sequences
                    if processor.get('type') == 'TemplateProcessing':
                        tmpl_single = processor.get('single', [])
                        tmpl_pair = processor.get('pair', [])
                        special_first = None
                        special_last = None
                        if len(tmpl_single) > 1:
                            if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
                                if not tokenizer_config:
                                    special_bos = special_first
                                self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
                                if special_first not in (special_bos, special_cls):
                                    logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
                            if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
                                if not tokenizer_config:
                                    special_eos = special_last
                                elif special_last != special_eos:
                                    if 'eot' not in self.special_token_types:
                                        self.special_token_types = tuple(self.special_token_types) + ('eot', )
                                        tokenizer_config['eot_token'] = special_eos
                                    elif 'eom' not in self.special_token_types:
                                        self.special_token_types = tuple(self.special_token_types) + ('eom', )
                                        tokenizer_config['eom_token'] = special_eos
                                    else:
                                        logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
                                    tokenizer_config['eos_token'] = special_eos = special_last
                                self.add_special_token['eos'] = True if special_last == special_eos else False
                                if special_last != special_eos:
                                    logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
                        if tmpl_pair:
                            seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
                            seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
                            if (special_first and seq_start == 0) or (special_last and seq_stop is None):
                                logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
                            if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
                                tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
                                tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
                                if tmpl_a != 'A' or tmpl_b != 'B':
                                    logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
                                # A [sep] [eos] B
                                if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
                                    add_sep = False
                                    if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
                                        if special_entry in (special_sep, special_eos) and not special_last:
                                            add_sep = True
                                        if special_entry not in (special_sep, special_eos):
                                            logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
                                    else:
                                        logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
                                    if len(tmpl_pair) == 2:
                                        if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
                                            if special_entry in (special_sep, special_eos):
                                                add_sep = True
                                            if special_entry not in (special_sep, special_eos):
                                                logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
                                        else:
                                            logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
                                    self.add_special_token['sep'] = add_sep
                                    if add_sep and not special_sep and tokenizer_config:
                                        tokenizer_config['sep_token'] = special_eos
                        continue
        if not tokenizer_config:
            return True
        chat_template_alt = None
        chat_template_json = path / 'chat_template.json'
        chat_template_jinja = path / 'chat_template.jinja'
        if chat_template_jinja.is_file():
            with open(chat_template_jinja, encoding = 'utf-8') as f:
                chat_template_alt = f.read()
            if additional_templates := list((path / 'additional_chat_templates').glob('*.jinja')):
                chat_template_alt = [{'name': 'default', 'template': chat_template_alt}]
                for template_path in additional_templates:
                    with open(template_path, encoding = 'utf-8') as fp:
                        chat_template_alt.append({'name': template_path.stem, 'template': fp.read()})
        elif chat_template_json.is_file():
            with open(chat_template_json, encoding = 'utf-8') as f:
                chat_template_alt = json.load(f).get('chat_template')
        chat_template = tokenizer_config.get('chat_template', chat_template_alt)
        if chat_template is None or isinstance(chat_template, (str, list)):
            self.chat_template = chat_template
        else:
            logger.warning(f'Bad type for chat_template field in {tokenizer_config_file!r} - ignoring')
        for typ in self.special_token_types:
            add_entry = tokenizer_config.get(f'add_{typ}_token')
            if isinstance(add_entry, bool):
                self.add_special_token[typ] = add_entry
            entry = tokenizer_config.get(f'{typ}_token')
            if isinstance(entry, str):
                tc_content = entry
            elif isinstance(entry, dict):
                entry_content = entry.get('content')
                if not isinstance(entry_content, str):
                    continue
                tc_content = entry_content
            else:
                continue
            # We only need the first match here.
            maybe_token_id = next(
                (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content),
                None,
            )
            self._set_special_token(typ, maybe_token_id)
        return True
    def _try_load_from_config_json(self, path: Path) -> bool:
        config_file = path / 'config.json'
        if not config_file.is_file():
            return False
        with open(config_file, encoding = 'utf-8') as f:
            config = json.load(f)
        for typ in self.special_token_types:
            token_id = config.get(f'{typ}_token_id')
            # If not found at root, check in text_config (for multimodal models like Kimi-VL)
            if token_id is None and 'text_config' in config:
                token_id = config['text_config'].get(f'{typ}_token_id')
            self._set_special_token(typ, token_id)
        return True
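
# Rough usage sketch (names like `writer` and the model directory are hypothetical,
# not part of this module): SpecialVocab is typically built from a converted model
# directory and then serialized into the same GGUF file as the token list, e.g.
#
#   special_vocab = SpecialVocab("path/to/model", load_merges=True, n_vocab=32000)
#   special_vocab.add_to_gguf(writer)  # writer: gguf.GGUFWriter
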
@runtime_checkable
class BaseVocab(Protocol):
    tokenizer_model: ClassVar[str]
    name: ClassVar[str]


@runtime_checkable
class Vocab(BaseVocab, Protocol):
    vocab_size: int
    added_tokens_dict: dict[str, int]
    added_tokens_list: list[str]
    fname_tokenizer: Path

    def __init__(self, base_path: Path): ...
    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ...


class NoVocab(BaseVocab):
    tokenizer_model = "no_vocab"
    name = "no_vocab"

    def __repr__(self) -> str:
        return "<NoVocab for a model without integrated vocabulary>"


class BpeVocab(Vocab):
    tokenizer_model = "gpt2"
    name = "bpe"

    def __init__(self, base_path: Path):
        added_tokens: dict[str, int] = {}

        if (fname_tokenizer := base_path / 'vocab.json').exists():
            # "slow" tokenizer
            with open(fname_tokenizer, encoding="utf-8") as f:
                self.vocab = json.load(f)

            try:
                # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
                with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
                    added_tokens = json.load(f)
            except FileNotFoundError:
                pass
        else:
            # "fast" tokenizer
            fname_tokenizer = base_path / 'tokenizer.json'

            # if this fails, FileNotFoundError propagates to caller
            with open(fname_tokenizer, encoding="utf-8") as f:
                tokenizer_json = json.load(f)

            tokenizer_model: dict[str, Any] = tokenizer_json['model']
            if (
                tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False)
                or tokenizer_json['decoder']['type'] != 'ByteLevel'
            ):
                raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer')

            self.vocab = tokenizer_model["vocab"]

            if (added := tokenizer_json.get('added_tokens')) is not None:
                # Added tokens here can be duplicates of the main vocabulary.
                added_tokens = {item['content']: item['id']
                                for item in added
                                if item['content'] not in self.vocab}

        vocab_size = len(self.vocab)
        expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
        actual_ids = sorted(added_tokens.values())
        if expected_ids != actual_ids:
            expected_end_id = vocab_size + len(actual_ids) - 1
            raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range "
                             f"{vocab_size} - {expected_end_id}; got {actual_ids}")

        items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
        self.added_tokens_dict = added_tokens
        self.added_tokens_list = [text for (text, idx) in items]
        self.vocab_size_base = vocab_size
        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
        self.fname_tokenizer = fname_tokenizer

    def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()}

        for i, _ in enumerate(self.vocab):
            yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        for text in self.added_tokens_list:
            score = -1000.0
            yield text.encode("utf-8"), score, gguf.TokenType.CONTROL

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        yield from self.bpe_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"


class SentencePieceVocab(Vocab):
    tokenizer_model = "llama"
    name = "spm"

    def __init__(self, base_path: Path):
        if SentencePieceProcessor is None:
            raise RuntimeError("sentencepiece is not installed")

        added_tokens: dict[str, int] = {}
        if (fname_tokenizer := base_path / 'tokenizer.model').exists():
            # normal location
            try:
                with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
                    added_tokens = json.load(f)
            except FileNotFoundError:
                pass
        elif not (fname_tokenizer := base_path.parent / 'tokenizer.model').exists():
            # not found in alternate location either
            raise FileNotFoundError('Cannot find tokenizer.model')

        self.sentencepiece_tokenizer = SentencePieceProcessor()
        self.sentencepiece_tokenizer.LoadFromFile(str(fname_tokenizer))
        vocab_size = self.sentencepiece_tokenizer.vocab_size()

        new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
        expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
        actual_new_ids = sorted(new_tokens.keys())

        if expected_new_ids != actual_new_ids:
            raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")

        # Token pieces that were added to the base vocabulary.
        self.added_tokens_dict = added_tokens
        self.added_tokens_list = [new_tokens[id] for id in actual_new_ids]
        self.vocab_size_base = vocab_size
        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
        self.fname_tokenizer = fname_tokenizer

    def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        tokenizer = self.sentencepiece_tokenizer
        for i in range(tokenizer.vocab_size()):
            piece = tokenizer.IdToPiece(i)
            text = piece.encode("utf-8")
            score: float = tokenizer.GetScore(i)

            toktype = gguf.TokenType.NORMAL
            if tokenizer.IsUnknown(i):
                toktype = gguf.TokenType.UNKNOWN
            if tokenizer.IsControl(i):
                toktype = gguf.TokenType.CONTROL

            # NOTE: I think added_tokens are user defined.
            # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
            # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED

            if tokenizer.IsUnused(i):
                toktype = gguf.TokenType.UNUSED
            if tokenizer.IsByte(i):
                toktype = gguf.TokenType.BYTE

            yield text, score, toktype

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        for text in self.added_tokens_list:
            score = -1000.0
            yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        yield from self.sentencepiece_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"


class LlamaHfVocab(Vocab):
    tokenizer_model = "llama"
    name = "hfft"

    def __init__(self, base_path: Path):
        fname_tokenizer = base_path / 'tokenizer.json'
        # if this fails, FileNotFoundError propagates to caller
        with open(fname_tokenizer, encoding='utf-8') as f:
            tokenizer_json = json.load(f)

        # pre-check so we know if we need transformers
        tokenizer_model: dict[str, Any] = tokenizer_json['model']
        is_llama3 = (
            tokenizer_model['type'] == 'BPE' and tokenizer_model.get('ignore_merges', False)
            and not tokenizer_model.get('byte_fallback', True)
        )
        if is_llama3:
            raise TypeError('Llama 3 must be converted with BpeVocab')

        if not is_llama3 and (
            tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
            or tokenizer_json['decoder']['type'] != 'Sequence'
        ):
            raise FileNotFoundError('Cannot find Llama BPE tokenizer')

        try:
            from transformers import AutoTokenizer
        except ImportError as e:
            raise ImportError(
                "To use LlamaHfVocab, please install the `transformers` package. "
                "You can install it with `pip install transformers`."
            ) from e

        # Allow the tokenizer to default to slow or fast versions.
        # Explicitly set tokenizer to use local paths.
        self.tokenizer = AutoTokenizer.from_pretrained(
            base_path,
            cache_dir=base_path,
            local_files_only=True,
        )
        assert self.tokenizer.is_fast  # assume tokenizer.json is used

        # Initialize lists and dictionaries for added tokens
        self.added_tokens_list = []
        self.added_tokens_dict = dict()
        self.added_tokens_ids = set()

        # Process added tokens
        for tok, tokidx in sorted(
            self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
        ):
            # Only consider added tokens that are not in the base vocabulary
            if tokidx >= self.tokenizer.vocab_size:
                self.added_tokens_list.append(tok)
                self.added_tokens_dict[tok] = tokidx
                self.added_tokens_ids.add(tokidx)

        # Store special tokens and their IDs
        self.specials = {
            tok: self.tokenizer.get_vocab()[tok]
            for tok in self.tokenizer.all_special_tokens
        }
        self.special_ids = set(self.tokenizer.all_special_ids)

        # Set vocabulary sizes
        self.vocab_size_base = self.tokenizer.vocab_size
        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)

        self.fname_tokenizer = fname_tokenizer

    def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        reverse_vocab = {
            id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
        }

        for token_id in range(self.vocab_size_base):
            # Skip processing added tokens here
            if token_id in self.added_tokens_ids:
                continue

            # Convert token text to bytes
            token_text = reverse_vocab[token_id].encode("utf-8")

            # Yield token text, score, and type
            yield token_text, self.get_token_score(token_id), self.get_token_type(
                token_id, token_text, self.special_ids  # Reuse already stored special IDs
            )

    def get_token_type(self, token_id: int, token_text: bytes, special_ids: set[int]) -> gguf.TokenType:
        # Special case for byte tokens
        if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
            return gguf.TokenType.BYTE

        # Determine token type based on whether it's a special token
        return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL

    def get_token_score(self, token_id: int) -> float:
        # Placeholder for actual logic to determine the token's score
        # This needs to be implemented based on specific requirements
        return -1000.0  # Default score

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        for text in self.added_tokens_list:
            if text in self.specials:
                toktype = self.get_token_type(self.specials[text], b'', self.special_ids)
                score = self.get_token_score(self.specials[text])
            else:
                toktype = gguf.TokenType.USER_DEFINED
                score = -1000.0

            yield text.encode("utf-8"), score, toktype

    def has_newline_token(self):
        return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        yield from self.hf_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"


class MistralTokenizerType(str, Enum):
    spm = "spm"
    tekken = "tekken"


# Copied from Transformers (Apache 2.0)
# https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py#L1544
def bytes_to_unicode() -> dict[int, str]:
    """
    Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
    characters the bpe code barfs on.

    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
    tables between utf-8 bytes and unicode strings.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs_str = [chr(n) for n in cs]
    return dict(zip(bs, cs_str))
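
# For reference (these values follow directly from the mapping above): printable
# characters map to themselves, while bytes outside the printable ranges are shifted
# into U+0100 and beyond, e.g. bytes_to_unicode()[ord(" ")] == "Ġ" and
# bytes_to_unicode()[ord("\n")] == "Ċ".
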
class MistralVocab(Vocab):
    tokenizer_model = "mistral"
    name = "mistral"

    added_tokens_dict: dict[str, int] = {}
    added_tokens_list: list[str] = []

    def __init__(self, base_path: Path):
        if not _mistral_common_installed:
            raise ImportError(
                "To use MistralVocab, please install the `mistral-common` package. "
                "You can install it with `pip install mistral-common`."
            )
        assert _filter_valid_tokenizer_files is not None, "mistral_common is not installed"
        assert MistralTokenizer is not None, "mistral_common is not installed"
        assert Tekkenizer is not None, "mistral_common is not installed"

        logger.info(f"Loading Mistral tokenizer from {base_path}")

        # Find the tokenizer files
        all_files = [f.as_posix() for f in base_path.glob("**/*") if f.is_file()]
        if get_one_valid_tokenizer_file is not None:
            tokenizer_file_path = get_one_valid_tokenizer_file(all_files)
        else:
            valid_tokenizer_files = _filter_valid_tokenizer_files(all_files)
            if len(valid_tokenizer_files) == 0:
                raise ValueError(f"No tokenizer file found in the directory: {base_path}")
            # If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one.
            if len(valid_tokenizer_files) > 1:
                if "tekken.json" in valid_tokenizer_files:
                    tokenizer_file = "tekken.json"
                else:
                    tokenizer_file = sorted(valid_tokenizer_files)[-1]
                logger.warning(
                    f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}"
                )
            else:
                tokenizer_file = valid_tokenizer_files[0]
            tokenizer_file_path = base_path / tokenizer_file

        self.tokenizer = MistralTokenizer.from_file(
            tokenizer_file_path
        ).instruct_tokenizer.tokenizer
        self.tokenizer_type = (
            MistralTokenizerType.tekken
            if isinstance(self.tokenizer, Tekkenizer)
            else MistralTokenizerType.spm
        )
        self.vocab_size = self.tokenizer.n_words
        self.fname_tokenizer = tokenizer_file_path
        self._name = (
            "mistral-" + self.tokenizer_type.value + "-" + self.tokenizer.version
        )

    @property
    def tokenizer_name(self) -> str:
        return self._name

    @property
    def gguf_tokenizer_model(self) -> str:
        return "llama" if self.tokenizer_type == MistralTokenizerType.spm else "gpt2"

    def _sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        assert SentencePieceTokenizer is not None, "mistral_common is not installed"
        assert isinstance(self.tokenizer, SentencePieceTokenizer), (
            f"Expected SentencePieceTokenizer, got {type(self.tokenizer)}"
        )
        for i in range(self.tokenizer._model.vocab_size()):
            piece = self.tokenizer._model.IdToPiece(i)
            text = piece.encode("utf-8")
            score: float = self.tokenizer._model.GetScore(i)

            toktype = gguf.TokenType.NORMAL
            if self.tokenizer._model.IsUnknown(i):
                toktype = gguf.TokenType.UNKNOWN
            if self.tokenizer._model.IsControl(i):
                toktype = gguf.TokenType.CONTROL

            if self.tokenizer._model.IsUnused(i):
                toktype = gguf.TokenType.UNUSED
            if self.tokenizer._model.IsByte(i):
                toktype = gguf.TokenType.BYTE

            yield text, score, toktype

    def _tekken_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        assert Tekkenizer is not None, "mistral_common is not installed"
        assert isinstance(self.tokenizer, Tekkenizer), (
            f"Expected Tekkenizer, got {type(self.tokenizer)}"
        )
        byte_encoder = bytes_to_unicode()
        for token_id in range(self.tokenizer.num_special_tokens):
            yield (
                self.tokenizer.id_to_piece(token_id).encode("utf-8"),
                0,
                gguf.TokenType.CONTROL
            )
        for token in self.tokenizer._tekken_token2id_nospecial:
            yield (
                self.token_bytes_to_string(token, byte_encoder).encode("utf-8"),
                0,
                gguf.TokenType.NORMAL,
            )

    def get_token_id(self, token: str) -> int:
        assert SentencePieceTokenizer is not None and Tekkenizer is not None, "mistral_common is not installed"
        if self.tokenizer_type == MistralTokenizerType.spm:
            assert isinstance(self.tokenizer, SentencePieceTokenizer)
            return self.tokenizer._vocab.index(token)
        elif self.tokenizer_type == MistralTokenizerType.tekken:
            assert isinstance(self.tokenizer, Tekkenizer)
            return (
                self.tokenizer._vocab.index(token) + self.tokenizer.num_special_tokens
            )
        else:
            raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")

    @property
    def bos_id(self) -> int:
        return self.tokenizer.bos_id

    @property
    def eos_id(self) -> int:
        return self.tokenizer.eos_id

    @property
    def pad_id(self) -> int:
        if self.tokenizer.pad_id == -1:
            return self.eos_id
        return self.tokenizer.pad_id

    @property
    def unk_id(self) -> int:
        return self.tokenizer.unk_id

    @property
    def bos_token(self) -> str:
        return self.tokenizer.id_to_piece(self.tokenizer.bos_id)

    @property
    def eos_token(self) -> str:
        return self.tokenizer.id_to_piece(self.tokenizer.eos_id)

    @property
    def pad_token(self) -> str:
        return self.tokenizer.id_to_piece(self.tokenizer.pad_id)

    @property
    def unk_token(self) -> str:
        return self.tokenizer.id_to_piece(self.tokenizer.unk_id)

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        if self.tokenizer_type == MistralTokenizerType.spm:
            yield from self._sentencepiece_tokens()
        elif self.tokenizer_type == MistralTokenizerType.tekken:
            yield from self._tekken_tokens()
        else:
            raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")

    @staticmethod
    def token_bytes_to_string(b, byte_encoder):
        return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
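    # e.g. token_bytes_to_string(b" hi", bytes_to_unicode()) == "Ġhi"; raw token
    # bytes are decoded as latin-1 so every byte value maps through the byte encoder.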
    def extract_vocab_merges_from_model(self):
        # Adapted from Transformers (Apache 2.0)
        # https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py
        assert Tekkenizer is not None and isinstance(self.tokenizer, Tekkenizer), (
            f"Expected Tekkenizer, got {type(self.tokenizer)}"
        )
        mergeable_ranks = self.tokenizer._model._mergeable_ranks
        token_bytes_map = {
            rank: token_bytes for token_bytes, rank in mergeable_ranks.items()
        }
        merge_pairs = []

        # Sort vocab by rank to ensure correct merge order
        for i in range(256, self.vocab_size - self.tokenizer.num_special_tokens):
            merged_token = token_bytes_map[i]
            local = []
            for j in range(1, len(merged_token)):
                left = merged_token[:j]
                right = merged_token[j:]
                if (
                    left in mergeable_ranks
                    and right in mergeable_ranks
                    and (left + right) in mergeable_ranks
                ):
                    local.append((left, right, i))
            if not local:
                raise ValueError(
                    f"Could not find valid merge for token at rank {i}: {merged_token.decode('latin-1')}"
                )
            local = sorted(
                local,
                key=lambda x: (mergeable_ranks[x[0]], mergeable_ranks[x[1]]),
                reverse=False,
            )
            merge_pairs.extend(local)
        merge_pairs = sorted(merge_pairs, key=lambda val: val[2], reverse=False)

        byte_encoder = bytes_to_unicode()

        decoded_merge_pairs = [
            [
                self.token_bytes_to_string(val[0], byte_encoder),
                self.token_bytes_to_string(val[1], byte_encoder),
            ]
            for val in merge_pairs
        ]

        merges = [
            " ".join(
                [
                    # ensure the spaces are properly encoded
                    "".join(chr(ord(c) + 256) if c == " " else c for c in part)
                    for part in pair
                ]
            )
            for pair in decoded_merge_pairs
        ]

        return merges
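
# Rough usage sketch (paths and the writer object are hypothetical, for illustration
# only):
#
#   vocab = MistralVocab(Path("path/to/mistral-model"))
#   for text, score, toktype in vocab.all_tokens():
#       ...  # feed into a gguf.GGUFWriter
#   if vocab.tokenizer_type == MistralTokenizerType.tekken:
#       merges = vocab.extract_vocab_merges_from_model()
#
# extract_vocab_merges_from_model() only applies to Tekken (BPE) tokenizers; for
# SentencePiece-based Mistral models the per-token scores yielded by
# _sentencepiece_tokens() carry the equivalent information.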