# vocab.py

from __future__ import annotations

import re
import logging
import json
import os
from pathlib import Path
from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVar, runtime_checkable

try:
    from sentencepiece import SentencePieceProcessor
except ImportError:
    SentencePieceProcessor = None

import gguf

from .gguf_writer import GGUFWriter

logger = logging.getLogger(__name__)


class SpecialVocab:
    merges: list[str]
    add_special_token: dict[str, bool]
    special_token_ids: dict[str, int]
    chat_template: str | Sequence[Mapping[str, str]] | None

    def __init__(
        self, path: str | os.PathLike[str], load_merges: bool = False,
        special_token_types: Iterable[str] | None = None,
        n_vocab: int | None = None,
    ):
        self.special_token_ids = {}
        self.add_special_token = {}
        self.n_vocab = n_vocab
        self.load_merges = load_merges
        self.merges = []
        self.chat_template = None
        if special_token_types is not None:
            self.special_token_types = special_token_types
        else:
            self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad', 'cls', 'mask')
        self._load(Path(path))

    def __repr__(self) -> str:
        return '<SpecialVocab with {} merges, special tokens {}, add special tokens {}>'.format(
            len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset",
        )
    def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
        if self.merges:
            if not quiet:
                logger.info(f'Adding {len(self.merges)} merge(s).')
            gw.add_token_merges(self.merges)
        elif self.load_merges:
            logger.warning('Adding merges requested but no merges found, output may be non-functional.')
        for typ, tokid in self.special_token_ids.items():
            id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
            if id_handler is None:
                logger.warning(f'No handler for special token type {typ} with id {tokid} - skipping')
                continue
            if not quiet:
                logger.info(f'Setting special token type {typ} to {tokid}')
            id_handler(tokid)
        for typ, value in self.add_special_token.items():
            add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None)
            if add_handler is None:
                logger.warning(f'No handler for add_{typ}_token with value {value} - skipping')
                continue
            if not quiet:
                logger.info(f'Setting add_{typ}_token to {value}')
            add_handler(value)
        if self.chat_template is not None:
            if not quiet:
                logger.info(f'Setting chat_template to {self.chat_template}')
            gw.add_chat_template(self.chat_template)
    def _load(self, path: Path) -> None:
        self._try_load_from_tokenizer_json(path)
        self._try_load_from_config_json(path)
        if self.load_merges and not self.merges:
            self._try_load_merges_txt(path)

    def _try_load_merges_txt(self, path: Path) -> bool:
        merges_file = path / 'merges.txt'
        if not merges_file.is_file():
            return False
        with open(merges_file, 'r', encoding = 'utf-8') as fp:
            first_line = next(fp, '').strip()
            if not first_line.startswith('#'):
                fp.seek(0)
                line_num = 0
            else:
                line_num = 1
            merges = []
            for line in fp:
                line_num += 1
                line = line.strip()
                if not line:
                    continue
                parts = line.split(None, 3)
                if len(parts) != 2:
                    logger.warning(f'{merges_file.name}: Line {line_num}: Entry malformed, ignoring')
                    continue
                merges.append(f'{parts[0]} {parts[1]}')
        self.merges = merges
        return True
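
    # Illustrative sketch of the merges.txt layout the parser above expects:
    # an optional '#...' header line, then one whitespace-separated token pair
    # per line. The header string and pairs below are hypothetical examples,
    # not taken from any specific model:
    #
    #   #version: 0.2
    #   Ġ t
    #   h e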
    def _set_special_token(self, typ: str, tid: Any) -> None:
        if not isinstance(tid, int):
            return
        if tid < 0:
            raise ValueError(f'invalid value for special token type {typ}: {tid}')
        if self.n_vocab is None or tid < self.n_vocab:
            if typ in self.special_token_ids:
                return
            self.special_token_ids[typ] = tid
            return
        logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')
    def _try_load_from_tokenizer_json(self, path: Path) -> bool:
        tokenizer_file = path / 'tokenizer.json'
        if tokenizer_file.is_file():
            with open(tokenizer_file, encoding = 'utf-8') as f:
                tokenizer = json.load(f)
            if self.load_merges:
                merges = tokenizer.get('model', {}).get('merges')
                if isinstance(merges, list) and merges:
                    if isinstance(merges[0], str):
                        self.merges = merges
                    elif isinstance(merges[0], list) and len(merges[0]) == 2 and isinstance(merges[0][0], str):
                        # New format since transformers 4.45 to support spaces in merges
                        # ref: https://github.com/ggml-org/llama.cpp/issues/9692
                        # TODO: internally store as the new format instead of converting to old
                        if any(' ' in s for pair in merges for s in pair):
                            logger.warning(f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}')
                        self.merges = [
                            ' '.join(
                                [
                                    # ensure the spaces are properly encoded
                                    ''.join(
                                        chr(ord(c) + 256) if c == ' ' else c
                                        for c in part
                                    )
                                    for part in pair
                                ]
                            )
                            for pair in merges
                        ]
                    else:
                        raise ValueError("Unknown tokenizer merges format")
            added_tokens = tokenizer.get('added_tokens', {})
        else:
            added_tokens = {}
        tokenizer_config_file = path / 'tokenizer_config.json'
        if not tokenizer_config_file.is_file():
            return True
        with open(tokenizer_config_file, encoding = 'utf-8') as f:
            tokenizer_config = json.load(f)
        chat_template_alt = None
        chat_template_file = path / 'chat_template.json'
        if chat_template_file.is_file():
            with open(chat_template_file, encoding = 'utf-8') as f:
                chat_template_alt = json.load(f).get('chat_template')
        chat_template = tokenizer_config.get('chat_template', chat_template_alt)
        if chat_template is None or isinstance(chat_template, (str, list)):
            self.chat_template = chat_template
        else:
            logger.warning(f'Bad type for chat_template field in {tokenizer_config_file!r} - ignoring')
        for typ in self.special_token_types:
            add_entry = tokenizer_config.get(f'add_{typ}_token')
            if isinstance(add_entry, bool):
                self.add_special_token[typ] = add_entry
            entry = tokenizer_config.get(f'{typ}_token')
            if isinstance(entry, str):
                tc_content = entry
            elif isinstance(entry, dict):
                entry_content = entry.get('content')
                if not isinstance(entry_content, str):
                    continue
                tc_content = entry_content
            else:
                continue
            # We only need the first match here.
            maybe_token_id = next(
                (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content),
                None,
            )
            self._set_special_token(typ, maybe_token_id)
        return True
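
    # Rough sketch of the tokenizer_config.json fields consumed above; the
    # concrete values are made up purely for illustration:
    #
    #   {
    #     "add_bos_token": true,
    #     "bos_token": {"content": "<s>"},
    #     "eos_token": "</s>",
    #     "chat_template": "{% for message in messages %}...{% endfor %}"
    #   }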
    def _try_load_from_config_json(self, path: Path) -> bool:
        config_file = path / 'config.json'
        if not config_file.is_file():
            return False
        with open(config_file, encoding = 'utf-8') as f:
            config = json.load(f)
        for typ in self.special_token_types:
            self._set_special_token(typ, config.get(f'{typ}_token_id'))
        return True


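# Minimal usage sketch for SpecialVocab, assuming `model_dir` points at a local
# Hugging Face model directory and `writer` is an already-configured GGUFWriter;
# the n_vocab value is a placeholder:
#
#   sv = SpecialVocab(model_dir, load_merges=True, n_vocab=32000)
#   sv.add_to_gguf(writer)

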
@runtime_checkable
class BaseVocab(Protocol):
    tokenizer_model: ClassVar[str]
    name: ClassVar[str]


@runtime_checkable
class Vocab(BaseVocab, Protocol):
    vocab_size: int
    added_tokens_dict: dict[str, int]
    added_tokens_list: list[str]
    fname_tokenizer: Path

    def __init__(self, base_path: Path): ...
    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ...


class NoVocab(BaseVocab):
    tokenizer_model = "no_vocab"
    name = "no_vocab"

    def __repr__(self) -> str:
        return "<NoVocab for a model without integrated vocabulary>"


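# Because BaseVocab and Vocab are runtime_checkable Protocols, converter code
# can dispatch on them structurally. A hedged sketch, not something this module
# itself does:
#
#   if isinstance(vocab, Vocab):
#       tokens = list(vocab.all_tokens())

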
class BpeVocab(Vocab):
    tokenizer_model = "gpt2"
    name = "bpe"

    def __init__(self, base_path: Path):
        added_tokens: dict[str, int] = {}

        if (fname_tokenizer := base_path / 'vocab.json').exists():
            # "slow" tokenizer
            with open(fname_tokenizer, encoding="utf-8") as f:
                self.vocab = json.load(f)

            try:
                # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
                with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
                    added_tokens = json.load(f)
            except FileNotFoundError:
                pass
        else:
            # "fast" tokenizer
            fname_tokenizer = base_path / 'tokenizer.json'

            # if this fails, FileNotFoundError propagates to caller
            with open(fname_tokenizer, encoding="utf-8") as f:
                tokenizer_json = json.load(f)

            tokenizer_model: dict[str, Any] = tokenizer_json['model']
            if (
                tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False)
                or tokenizer_json['decoder']['type'] != 'ByteLevel'
            ):
                raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer')

            self.vocab = tokenizer_model["vocab"]

            if (added := tokenizer_json.get('added_tokens')) is not None:
                # Added tokens here can be duplicates of the main vocabulary.
                added_tokens = {item['content']: item['id']
                                for item in added
                                if item['content'] not in self.vocab}

        vocab_size = len(self.vocab)
        expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
        actual_ids = sorted(added_tokens.values())
        if expected_ids != actual_ids:
            expected_end_id = vocab_size + len(actual_ids) - 1
            raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range "
                             f"{vocab_size} - {expected_end_id}; got {actual_ids}")

        items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
        self.added_tokens_dict = added_tokens
        self.added_tokens_list = [text for (text, idx) in items]
        self.vocab_size_base = vocab_size
        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
        self.fname_tokenizer = fname_tokenizer

    def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()}

        for i, _ in enumerate(self.vocab):
            yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        for text in self.added_tokens_list:
            score = -1000.0
            yield text.encode("utf-8"), score, gguf.TokenType.CONTROL

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        yield from self.bpe_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"


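# Usage sketch for BpeVocab, assuming `model_dir` contains either vocab.json
# ("slow" tokenizer) or tokenizer.json ("fast" tokenizer):
#
#   bpe = BpeVocab(Path(model_dir))
#   for text, score, toktype in bpe.all_tokens():
#       ...

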
class SentencePieceVocab(Vocab):
    tokenizer_model = "llama"
    name = "spm"

    def __init__(self, base_path: Path):
        if SentencePieceProcessor is None:
            raise RuntimeError("sentencepiece is not installed")
        added_tokens: dict[str, int] = {}
        if (fname_tokenizer := base_path / 'tokenizer.model').exists():
            # normal location
            try:
                with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
                    added_tokens = json.load(f)
            except FileNotFoundError:
                pass
        elif not (fname_tokenizer := base_path.parent / 'tokenizer.model').exists():
            # not found in alternate location either
            raise FileNotFoundError('Cannot find tokenizer.model')

        self.sentencepiece_tokenizer = SentencePieceProcessor()
        self.sentencepiece_tokenizer.LoadFromFile(str(fname_tokenizer))
        vocab_size = self.sentencepiece_tokenizer.vocab_size()

        new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
        expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
        actual_new_ids = sorted(new_tokens.keys())

        if expected_new_ids != actual_new_ids:
            raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")

        # Token pieces that were added to the base vocabulary.
        self.added_tokens_dict = added_tokens
        self.added_tokens_list = [new_tokens[id] for id in actual_new_ids]
        self.vocab_size_base = vocab_size
        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
        self.fname_tokenizer = fname_tokenizer

    def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        tokenizer = self.sentencepiece_tokenizer
        for i in range(tokenizer.vocab_size()):
            piece = tokenizer.IdToPiece(i)
            text = piece.encode("utf-8")
            score: float = tokenizer.GetScore(i)

            toktype = gguf.TokenType.NORMAL
            if tokenizer.IsUnknown(i):
                toktype = gguf.TokenType.UNKNOWN
            if tokenizer.IsControl(i):
                toktype = gguf.TokenType.CONTROL

            # NOTE: I think added_tokens are user defined.
            # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
            # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED

            if tokenizer.IsUnused(i):
                toktype = gguf.TokenType.UNUSED
            if tokenizer.IsByte(i):
                toktype = gguf.TokenType.BYTE

            yield text, score, toktype

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        for text in self.added_tokens_list:
            score = -1000.0
            yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        yield from self.sentencepiece_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"


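# Usage sketch for SentencePieceVocab, assuming `model_dir` (or its parent)
# contains tokenizer.model and the sentencepiece package is installed:
#
#   spm = SentencePieceVocab(Path(model_dir))
#   print(spm.vocab_size, len(spm.added_tokens_list))

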
class LlamaHfVocab(Vocab):
    tokenizer_model = "llama"
    name = "hfft"

    def __init__(self, base_path: Path):
        fname_tokenizer = base_path / 'tokenizer.json'
        # if this fails, FileNotFoundError propagates to caller
        with open(fname_tokenizer, encoding='utf-8') as f:
            tokenizer_json = json.load(f)

        # pre-check so we know if we need transformers
        tokenizer_model: dict[str, Any] = tokenizer_json['model']
        is_llama3 = (
            tokenizer_model['type'] == 'BPE' and tokenizer_model.get('ignore_merges', False)
            and not tokenizer_model.get('byte_fallback', True)
        )
        if is_llama3:
            raise TypeError('Llama 3 must be converted with BpeVocab')

        if not is_llama3 and (
            tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
            or tokenizer_json['decoder']['type'] != 'Sequence'
        ):
            raise FileNotFoundError('Cannot find Llama BPE tokenizer')

        try:
            from transformers import AutoTokenizer
        except ImportError as e:
            raise ImportError(
                "To use LlamaHfVocab, please install the `transformers` package. "
                "You can install it with `pip install transformers`."
            ) from e

        # Allow the tokenizer to default to slow or fast versions.
        # Explicitly set tokenizer to use local paths.
        self.tokenizer = AutoTokenizer.from_pretrained(
            base_path,
            cache_dir=base_path,
            local_files_only=True,
        )
        assert self.tokenizer.is_fast  # assume tokenizer.json is used

        # Initialize lists and dictionaries for added tokens
        self.added_tokens_list = []
        self.added_tokens_dict = dict()
        self.added_tokens_ids = set()

        # Process added tokens
        for tok, tokidx in sorted(
            self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
        ):
            # Only consider added tokens that are not in the base vocabulary
            if tokidx >= self.tokenizer.vocab_size:
                self.added_tokens_list.append(tok)
                self.added_tokens_dict[tok] = tokidx
                self.added_tokens_ids.add(tokidx)

        # Store special tokens and their IDs
        self.specials = {
            tok: self.tokenizer.get_vocab()[tok]
            for tok in self.tokenizer.all_special_tokens
        }
        self.special_ids = set(self.tokenizer.all_special_ids)

        # Set vocabulary sizes
        self.vocab_size_base = self.tokenizer.vocab_size
        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)

        self.fname_tokenizer = fname_tokenizer

    def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        reverse_vocab = {
            id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
        }

        for token_id in range(self.vocab_size_base):
            # Skip processing added tokens here
            if token_id in self.added_tokens_ids:
                continue

            # Convert token text to bytes
            token_text = reverse_vocab[token_id].encode("utf-8")

            # Yield token text, score, and type
            yield token_text, self.get_token_score(token_id), self.get_token_type(
                token_id, token_text, self.special_ids  # Reuse already stored special IDs
            )

    def get_token_type(self, token_id: int, token_text: bytes, special_ids: set[int]) -> gguf.TokenType:
        # Special case for byte tokens
        if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
            return gguf.TokenType.BYTE

        # Determine token type based on whether it's a special token
        return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL

    def get_token_score(self, token_id: int) -> float:
        # Placeholder for actual logic to determine the token's score
        # This needs to be implemented based on specific requirements
        return -1000.0  # Default score

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        for text in self.added_tokens_list:
            if text in self.specials:
                toktype = self.get_token_type(self.specials[text], b'', self.special_ids)
                score = self.get_token_score(self.specials[text])
            else:
                toktype = gguf.TokenType.USER_DEFINED
                score = -1000.0

            yield text.encode("utf-8"), score, toktype

    def has_newline_token(self):
        return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        yield from self.hf_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
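

# The three concrete vocab classes all satisfy the Vocab protocol, so converter
# code can treat them interchangeably. Hedged sketch: the fallback order below
# is an assumption for illustration, not something this module prescribes:
#
#   for cls in (SentencePieceVocab, BpeVocab, LlamaHfVocab):
#       try:
#           vocab = cls(Path(model_dir))
#           break
#       except (FileNotFoundError, RuntimeError, TypeError):
#           continue
#   for text, score, toktype in vocab.all_tokens():
#       ...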