vocab.py

from __future__ import annotations

import re
import logging
import json
import os
from pathlib import Path
from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVar, runtime_checkable

from sentencepiece import SentencePieceProcessor

import gguf

from .gguf_writer import GGUFWriter

logger = logging.getLogger(__name__)


class SpecialVocab:
    merges: list[str]
    add_special_token: dict[str, bool]
    special_token_ids: dict[str, int]
    chat_template: str | Sequence[Mapping[str, str]] | None

    def __init__(
        self, path: str | os.PathLike[str], load_merges: bool = False,
        special_token_types: Iterable[str] | None = None,
        n_vocab: int | None = None,
    ):
        self.special_token_ids = {}
        self.add_special_token = {}
        self.n_vocab = n_vocab
        self.load_merges = load_merges
        self.merges = []
        self.chat_template = None
        if special_token_types is not None:
            self.special_token_types = special_token_types
        else:
            self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad', 'cls', 'mask')
        self._load(Path(path))

    def __repr__(self) -> str:
        return '<SpecialVocab with {} merges, special tokens {}, add special tokens {}>'.format(
            len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset",
        )

    def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
        if self.merges:
            if not quiet:
                logger.info(f'Adding {len(self.merges)} merge(s).')
            gw.add_token_merges(self.merges)
        elif self.load_merges:
            logger.warning('Adding merges requested but no merges found, output may be non-functional.')
        for typ, tokid in self.special_token_ids.items():
            id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
            if id_handler is None:
                logger.warning(f'No handler for special token type {typ} with id {tokid} - skipping')
                continue
            if not quiet:
                logger.info(f'Setting special token type {typ} to {tokid}')
            id_handler(tokid)
        for typ, value in self.add_special_token.items():
            add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None)
            if add_handler is None:
                logger.warning(f'No handler for add_{typ}_token with value {value} - skipping')
                continue
            if not quiet:
                logger.info(f'Setting add_{typ}_token to {value}')
            add_handler(value)
        if self.chat_template is not None:
            if not quiet:
                logger.info(f'Setting chat_template to {self.chat_template}')
            gw.add_chat_template(self.chat_template)

    def _load(self, path: Path) -> None:
        self._try_load_from_tokenizer_json(path)
        self._try_load_from_config_json(path)
        if self.load_merges and not self.merges:
            self._try_load_merges_txt(path)

    def _try_load_merges_txt(self, path: Path) -> bool:
        merges_file = path / 'merges.txt'
        if not merges_file.is_file():
            return False
        with open(merges_file, 'r', encoding = 'utf-8') as fp:
            first_line = next(fp, '').strip()
            if not first_line.startswith('#'):
                fp.seek(0)
                line_num = 0
            else:
                line_num = 1
            merges = []
            for line in fp:
                line_num += 1
                line = line.strip()
                if not line:
                    continue
                parts = line.split(None, 3)
                if len(parts) != 2:
                    logger.warning(f'{merges_file.name}: Line {line_num}: Entry malformed, ignoring')
                    continue
                merges.append(f'{parts[0]} {parts[1]}')
        self.merges = merges
        return True

    def _set_special_token(self, typ: str, tid: Any) -> None:
        if not isinstance(tid, int):
            return
        if tid < 0:
            raise ValueError(f'invalid value for special token type {typ}: {tid}')
        if self.n_vocab is None or tid < self.n_vocab:
            if typ in self.special_token_ids:
                return
            self.special_token_ids[typ] = tid
            return
        logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')

    def _try_load_from_tokenizer_json(self, path: Path) -> bool:
        tokenizer_file = path / 'tokenizer.json'
        if tokenizer_file.is_file():
            with open(tokenizer_file, encoding = 'utf-8') as f:
                tokenizer = json.load(f)
            if self.load_merges:
                merges = tokenizer.get('model', {}).get('merges')
                if isinstance(merges, list) and merges:
                    if isinstance(merges[0], str):
                        self.merges = merges
                    elif isinstance(merges[0], list) and len(merges[0]) == 2 and isinstance(merges[0][0], str):
                        # New format since transformers 4.45 to support spaces in merges
                        # ref: https://github.com/ggml-org/llama.cpp/issues/9692
                        # TODO: internally store as the new format instead of converting to old
                        if any(' ' in s for pair in merges for s in pair):
                            logger.warning(f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}')
                        self.merges = [
                            ' '.join(
                                [
                                    # ensure the spaces are properly encoded
                                    ''.join(
                                        chr(ord(c) + 256) if c == ' ' else c
                                        for c in part
                                    )
                                    for part in pair
                                ]
                            )
                            for pair in merges
                        ]
                    else:
                        raise ValueError("Unknown tokenizer merges format")
            added_tokens = tokenizer.get('added_tokens', {})
        else:
            added_tokens = {}
        tokenizer_config_file = path / 'tokenizer_config.json'
        if not tokenizer_config_file.is_file():
            return True
        with open(tokenizer_config_file, encoding = 'utf-8') as f:
            tokenizer_config = json.load(f)
        chat_template = tokenizer_config.get('chat_template')
        if chat_template is None or isinstance(chat_template, (str, list)):
            self.chat_template = chat_template
        else:
            logger.warning(f'Bad type for chat_template field in {tokenizer_config_file!r} - ignoring')
        for typ in self.special_token_types:
            add_entry = tokenizer_config.get(f'add_{typ}_token')
            if isinstance(add_entry, bool):
                self.add_special_token[typ] = add_entry
            entry = tokenizer_config.get(f'{typ}_token')
            if isinstance(entry, str):
                tc_content = entry
            elif isinstance(entry, dict):
                entry_content = entry.get('content')
                if not isinstance(entry_content, str):
                    continue
                tc_content = entry_content
            else:
                continue
            # We only need the first match here.
            maybe_token_id = next(
                (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content),
                None,
            )
            self._set_special_token(typ, maybe_token_id)
        return True

    def _try_load_from_config_json(self, path: Path) -> bool:
        config_file = path / 'config.json'
        if not config_file.is_file():
            return False
        with open(config_file, encoding = 'utf-8') as f:
            config = json.load(f)
        for typ in self.special_token_types:
            self._set_special_token(typ, config.get(f'{typ}_token_id'))
        return True
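

# Hedged usage sketch (added for illustration; not part of the original module):
# given a GGUFWriter the caller has already configured, load the special-token
# metadata from a Hugging Face style model directory and record it in the output.
# The helper name and the directory layout are assumptions.
def _example_add_special_vocab(gw: GGUFWriter, model_dir: str | os.PathLike[str],
                               n_vocab: int | None = None) -> None:
    sv = SpecialVocab(model_dir, load_merges=True, n_vocab=n_vocab)
    # Writes merges, special token ids, add_*_token flags and the chat template (if any).
    sv.add_to_gguf(gw)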


@runtime_checkable
class BaseVocab(Protocol):
    tokenizer_model: ClassVar[str]
    name: ClassVar[str]


@runtime_checkable
class Vocab(BaseVocab, Protocol):
    vocab_size: int
    added_tokens_dict: dict[str, int]
    added_tokens_list: list[str]
    fname_tokenizer: Path

    def __init__(self, base_path: Path): ...
    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ...


class NoVocab(BaseVocab):
    tokenizer_model = "no_vocab"
    name = "no_vocab"

    def __repr__(self) -> str:
        return "<NoVocab for a model without integrated vocabulary>"
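

# Hedged illustration (added; not in the original file): because BaseVocab and Vocab
# are runtime_checkable Protocols, converter code can branch on a structural
# isinstance() check instead of naming concrete classes. The helper below is a
# hypothetical example of that pattern.
def _example_describe_vocab(vocab: BaseVocab) -> str:
    if isinstance(vocab, Vocab):
        # Note: isinstance() against a runtime_checkable Protocol only checks that the
        # members exist; it does not validate their types or signatures.
        return f"{vocab.name}: {vocab.vocab_size} tokens from {vocab.fname_tokenizer}"
    return f"{vocab.name}: model has no integrated vocabulary"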


class BpeVocab(Vocab):
    tokenizer_model = "gpt2"
    name = "bpe"

    def __init__(self, base_path: Path):
        added_tokens: dict[str, int] = {}

        if (fname_tokenizer := base_path / 'vocab.json').exists():
            # "slow" tokenizer
            with open(fname_tokenizer, encoding="utf-8") as f:
                self.vocab = json.load(f)

            try:
                # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
                with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
                    added_tokens = json.load(f)
            except FileNotFoundError:
                pass
        else:
            # "fast" tokenizer
            fname_tokenizer = base_path / 'tokenizer.json'

            # if this fails, FileNotFoundError propagates to caller
            with open(fname_tokenizer, encoding="utf-8") as f:
                tokenizer_json = json.load(f)

            tokenizer_model: dict[str, Any] = tokenizer_json['model']
            if (
                tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False)
                or tokenizer_json['decoder']['type'] != 'ByteLevel'
            ):
                raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer')

            self.vocab = tokenizer_model["vocab"]

            if (added := tokenizer_json.get('added_tokens')) is not None:
                # Added tokens here can be duplicates of the main vocabulary.
                added_tokens = {item['content']: item['id']
                                for item in added
                                if item['content'] not in self.vocab}

        vocab_size = len(self.vocab)
        expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
        actual_ids = sorted(added_tokens.values())
        if expected_ids != actual_ids:
            expected_end_id = vocab_size + len(actual_ids) - 1
            raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range "
                             f"{vocab_size} - {expected_end_id}; got {actual_ids}")

        items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
        self.added_tokens_dict = added_tokens
        self.added_tokens_list = [text for (text, idx) in items]
        self.vocab_size_base = vocab_size
        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
        self.fname_tokenizer = fname_tokenizer

    def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()}

        for i, _ in enumerate(self.vocab):
            yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        for text in self.added_tokens_list:
            score = -1000.0
            yield text.encode("utf-8"), score, gguf.TokenType.CONTROL

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        yield from self.bpe_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
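

# Hedged usage sketch (added; not part of the original file): enumerating a GPT-2
# style BPE vocabulary loaded from a local model directory. The helper name is an
# assumption; all_tokens() yields (token, score, token type) triples in id order.
def _example_dump_bpe_vocab(model_dir: Path) -> None:
    vocab = BpeVocab(model_dir)
    for i, (tok, score, toktype) in enumerate(vocab.all_tokens()):
        logger.debug('token %d: %r score=%.1f type=%s', i, tok, score, toktype.name)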


class SentencePieceVocab(Vocab):
    tokenizer_model = "llama"
    name = "spm"

    def __init__(self, base_path: Path):
        added_tokens: dict[str, int] = {}
        if (fname_tokenizer := base_path / 'tokenizer.model').exists():
            # normal location
            try:
                with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
                    added_tokens = json.load(f)
            except FileNotFoundError:
                pass
        elif not (fname_tokenizer := base_path.parent / 'tokenizer.model').exists():
            # not found in alternate location either
            raise FileNotFoundError('Cannot find tokenizer.model')

        self.sentencepiece_tokenizer = SentencePieceProcessor()
        self.sentencepiece_tokenizer.LoadFromFile(str(fname_tokenizer))
        vocab_size = self.sentencepiece_tokenizer.vocab_size()

        new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
        expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
        actual_new_ids = sorted(new_tokens.keys())

        if expected_new_ids != actual_new_ids:
            raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")

        # Token pieces that were added to the base vocabulary.
        self.added_tokens_dict = added_tokens
        self.added_tokens_list = [new_tokens[id] for id in actual_new_ids]
        self.vocab_size_base = vocab_size
        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
        self.fname_tokenizer = fname_tokenizer

    def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        tokenizer = self.sentencepiece_tokenizer
        for i in range(tokenizer.vocab_size()):
            piece = tokenizer.IdToPiece(i)
            text = piece.encode("utf-8")
            score: float = tokenizer.GetScore(i)

            toktype = gguf.TokenType.NORMAL
            if tokenizer.IsUnknown(i):
                toktype = gguf.TokenType.UNKNOWN
            if tokenizer.IsControl(i):
                toktype = gguf.TokenType.CONTROL

            # NOTE: I think added_tokens are user defined.
            # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
            # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED

            if tokenizer.IsUnused(i):
                toktype = gguf.TokenType.UNUSED
            if tokenizer.IsByte(i):
                toktype = gguf.TokenType.BYTE

            yield text, score, toktype

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        for text in self.added_tokens_list:
            score = -1000.0
            yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        yield from self.sentencepiece_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
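

# Hedged usage sketch (added; not part of the original file): loading a SentencePiece
# model and summarizing how many tokens fall into each gguf.TokenType. The helper
# name and the idea of reporting counts are assumptions for illustration.
def _example_spm_token_type_counts(model_dir: Path) -> dict[str, int]:
    vocab = SentencePieceVocab(model_dir)
    counts: dict[str, int] = {}
    for _text, _score, toktype in vocab.all_tokens():
        counts[toktype.name] = counts.get(toktype.name, 0) + 1
    return counts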


class LlamaHfVocab(Vocab):
    tokenizer_model = "llama"
    name = "hfft"

    def __init__(self, base_path: Path):
        fname_tokenizer = base_path / 'tokenizer.json'
        # if this fails, FileNotFoundError propagates to caller
        with open(fname_tokenizer, encoding='utf-8') as f:
            tokenizer_json = json.load(f)

        # pre-check so we know if we need transformers
        tokenizer_model: dict[str, Any] = tokenizer_json['model']
        is_llama3 = (
            tokenizer_model['type'] == 'BPE' and tokenizer_model.get('ignore_merges', False)
            and not tokenizer_model.get('byte_fallback', True)
        )
        if is_llama3:
            raise TypeError('Llama 3 must be converted with BpeVocab')

        if not is_llama3 and (
            tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
            or tokenizer_json['decoder']['type'] != 'Sequence'
        ):
            raise FileNotFoundError('Cannot find Llama BPE tokenizer')

        try:
            from transformers import AutoTokenizer
        except ImportError as e:
            raise ImportError(
                "To use LlamaHfVocab, please install the `transformers` package. "
                "You can install it with `pip install transformers`."
            ) from e

        # Allow the tokenizer to default to slow or fast versions.
        # Explicitly set tokenizer to use local paths.
        self.tokenizer = AutoTokenizer.from_pretrained(
            base_path,
            cache_dir=base_path,
            local_files_only=True,
        )
        assert self.tokenizer.is_fast  # assume tokenizer.json is used

        # Initialize lists and dictionaries for added tokens
        self.added_tokens_list = []
        self.added_tokens_dict = dict()
        self.added_tokens_ids = set()

        # Process added tokens
        for tok, tokidx in sorted(
            self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
        ):
            # Only consider added tokens that are not in the base vocabulary
            if tokidx >= self.tokenizer.vocab_size:
                self.added_tokens_list.append(tok)
                self.added_tokens_dict[tok] = tokidx
                self.added_tokens_ids.add(tokidx)

        # Store special tokens and their IDs
        self.specials = {
            tok: self.tokenizer.get_vocab()[tok]
            for tok in self.tokenizer.all_special_tokens
        }
        self.special_ids = set(self.tokenizer.all_special_ids)

        # Set vocabulary sizes
        self.vocab_size_base = self.tokenizer.vocab_size
        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)

        self.fname_tokenizer = fname_tokenizer

    def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        reverse_vocab = {
            id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
        }

        for token_id in range(self.vocab_size_base):
            # Skip processing added tokens here
            if token_id in self.added_tokens_ids:
                continue

            # Convert token text to bytes
            token_text = reverse_vocab[token_id].encode("utf-8")

            # Yield token text, score, and type
            yield token_text, self.get_token_score(token_id), self.get_token_type(
                token_id, token_text, self.special_ids  # Reuse already stored special IDs
            )

    def get_token_type(self, token_id: int, token_text: bytes, special_ids: set[int]) -> gguf.TokenType:
        # Special case for byte tokens
        if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
            return gguf.TokenType.BYTE

        # Determine token type based on whether it's a special token
        return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL

    def get_token_score(self, token_id: int) -> float:
        # Placeholder for actual logic to determine the token's score
        # This needs to be implemented based on specific requirements
        return -1000.0  # Default score

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        for text in self.added_tokens_list:
            if text in self.specials:
                toktype = self.get_token_type(self.specials[text], b'', self.special_ids)
                score = self.get_token_score(self.specials[text])
            else:
                toktype = gguf.TokenType.USER_DEFINED
                score = -1000.0

            yield text.encode("utf-8"), score, toktype

    def has_newline_token(self):
        return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        yield from self.hf_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
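

# Hedged sketch (added; not part of the original module): a simple loader that tries
# the concrete vocabulary classes in a fixed order until one accepts the model
# directory, falling back to NoVocab. The precedence shown is an assumption for
# illustration, not a rule taken from this file.
def _example_load_any_vocab(model_dir: Path) -> BaseVocab:
    for cls in (SentencePieceVocab, LlamaHfVocab, BpeVocab):
        try:
            return cls(model_dir)
        except (FileNotFoundError, TypeError):
            # FileNotFoundError: the expected tokenizer files are missing or do not match;
            # TypeError: LlamaHfVocab refuses Llama 3 style tokenizers.
            continue
    return NoVocab()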