vocab.py

from __future__ import annotations

import json
import os
import sys
from pathlib import Path
from typing import Any, Callable

from .gguf_writer import GGUFWriter


class SpecialVocab:
    """Loads special token ids, merges and chat template metadata from a
    tokenizer directory and writes them to a GGUF file via GGUFWriter."""
    merges: list[str]
    add_special_token: dict[str, bool]
    special_token_ids: dict[str, int]
    chat_template: str | None

    def __init__(
        self, path: str | os.PathLike[str], load_merges: bool = False,
        special_token_types: tuple[str, ...] | None = None,
        n_vocab: int | None = None,
    ):
        self.special_token_ids = {}
        self.add_special_token = {}
        self.n_vocab = n_vocab
        self.load_merges = load_merges
        self.merges = []
        self.chat_template = None
        if special_token_types is not None:
            self.special_token_types = special_token_types
        else:
            self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad', 'cls', 'mask')
        self._load(Path(path))

    def __repr__(self) -> str:
        return '<SpecialVocab with {} merges, special tokens {}, add special tokens {}>'.format(
            len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset",
        )

    def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
        if self.merges:
            if not quiet:
                print(f'gguf: Adding {len(self.merges)} merge(s).')
            gw.add_token_merges(self.merges)
        elif self.load_merges:
            print(
                'gguf: WARNING: Adding merges requested but no merges found, output may be non-functional.',
                file = sys.stderr,
            )
        # Token ids and add_*_token flags are written through GGUFWriter methods
        # looked up by name (add_<type>_token_id / add_add_<type>_token).
        for typ, tokid in self.special_token_ids.items():
            id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
            if id_handler is None:
                print(
                    f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping',
                    file = sys.stderr,
                )
                continue
            if not quiet:
                print(f'gguf: Setting special token type {typ} to {tokid}')
            id_handler(tokid)
        for typ, value in self.add_special_token.items():
            add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None)
            if add_handler is None:
                print(
                    f'gguf: WARNING: No handler for add_{typ}_token with value {value} - skipping',
                    file = sys.stderr,
                )
                continue
            if not quiet:
                print(f'gguf: Setting add_{typ}_token to {value}')
            add_handler(value)
        if self.chat_template is not None:
            if not quiet:
                print(f'gguf: Setting chat_template to {self.chat_template}')
            gw.add_chat_template(self.chat_template)

    def _load(self, path: Path) -> None:
        self._try_load_from_tokenizer_json(path)
        self._try_load_from_config_json(path)
        if self.load_merges and not self.merges:
            self._try_load_merges_txt(path)
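
    # For reference, _load probes a Hugging Face style model directory; the file
    # names are the ones checked by the loaders below, the contents are illustrative:
    #
    #   model-dir/
    #     tokenizer.json         # merges and added_tokens (fast tokenizer data)
    #     tokenizer_config.json  # <type>_token, add_<type>_token, chat_template
    #     config.json            # <type>_token_id fallbacks
    #     merges.txt             # plain-text BPE merges, used when load_merges is set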

    def _try_load_merges_txt(self, path: Path) -> bool:
        merges_file = path / 'merges.txt'
        if not merges_file.is_file():
            return False
        with open(merges_file, 'r', encoding = 'utf-8') as fp:
            first_line = next(fp, '').strip()
            # Skip an optional '#version'-style header line.
            if not first_line.startswith('#'):
                fp.seek(0)
                line_num = 0
            else:
                line_num = 1
            merges = []
            for line in fp:
                line_num += 1
                line = line.strip()
                if not line:
                    continue
                parts = line.split(None, 3)
                if len(parts) != 2:
                    print(
                        f'gguf: WARNING: {merges_file.name}: Line {line_num}: Entry malformed, ignoring',
                        file = sys.stderr,
                    )
                    continue
                merges.append(f'{parts[0]} {parts[1]}')
        self.merges = merges
        return True
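
    # Illustrative merges.txt excerpt in the format parsed above: an optional
    # '#version' header line, then one space-separated merge pair per line.
    #
    #   #version: 0.2
    #   Ġ t
    #   h e
    #   i n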

    def _set_special_token(self, typ: str, tid: Any) -> None:
        if not isinstance(tid, int):
            return
        if tid < 0:
            raise ValueError(f'invalid value for special token type {typ}: {tid}')
        if self.n_vocab is None or tid < self.n_vocab:
            if typ in self.special_token_ids:
                return
            self.special_token_ids[typ] = tid
            return
        print(
            f'gguf: WARNING: Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping',
            file = sys.stderr,
        )

    def _try_load_from_tokenizer_json(self, path: Path) -> bool:
        tokenizer_file = path / 'tokenizer.json'
        if tokenizer_file.is_file():
            with open(tokenizer_file, encoding = 'utf-8') as f:
                tokenizer = json.load(f)
            if self.load_merges:
                merges = tokenizer.get('model', {}).get('merges')
                if isinstance(merges, list) and merges and isinstance(merges[0], str):
                    self.merges = merges
            added_tokens = tokenizer.get('added_tokens', {})
        else:
            added_tokens = {}
        tokenizer_config_file = path / 'tokenizer_config.json'
        if not tokenizer_config_file.is_file():
            return True
        with open(tokenizer_config_file, encoding = 'utf-8') as f:
            tokenizer_config = json.load(f)
        chat_template = tokenizer_config.get('chat_template')
        if chat_template is None or isinstance(chat_template, (str, list)):
            self.chat_template = chat_template
        else:
            print(
                f'gguf: WARNING: Bad type for chat_template field in {tokenizer_config_file!r} - ignoring',
                file = sys.stderr
            )
        for typ in self.special_token_types:
            add_entry = tokenizer_config.get(f'add_{typ}_token')
            if isinstance(add_entry, bool):
                self.add_special_token[typ] = add_entry
            entry = tokenizer_config.get(f'{typ}_token')
            if isinstance(entry, str):
                tc_content = entry
            elif isinstance(entry, dict):
                entry_content = entry.get('content')
                if not isinstance(entry_content, str):
                    continue
                tc_content = entry_content
            else:
                continue
            # We only need the first match here.
            maybe_token_id = next(
                (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content),
                None,
            )
            self._set_special_token(typ, maybe_token_id)
        return True
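
    # Illustrative tokenizer_config.json excerpt with the fields read above; token
    # entries may be plain strings or dicts with a 'content' key, and their ids are
    # resolved through the added_tokens list in tokenizer.json.
    #
    #   {
    #     "add_bos_token": true,
    #     "add_eos_token": false,
    #     "bos_token": "<s>",
    #     "eos_token": { "content": "</s>" },
    #     "chat_template": "{% for message in messages %}..."
    #   }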

    def _try_load_from_config_json(self, path: Path) -> bool:
        config_file = path / 'config.json'
        if not config_file.is_file():
            return False
        with open(config_file, encoding = 'utf-8') as f:
            config = json.load(f)
        for typ in self.special_token_types:
            self._set_special_token(typ, config.get(f'{typ}_token_id'))
        return True
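

# Usage sketch (illustrative only, not part of this module): SpecialVocab is
# typically paired with a GGUFWriter during model conversion. The output path,
# architecture name, model directory and vocab size below are hypothetical.
#
#   from gguf import GGUFWriter, SpecialVocab
#
#   writer = GGUFWriter('model.gguf', 'llama')
#   special_vocab = SpecialVocab('path/to/hf-model', load_merges = True, n_vocab = 32000)
#   special_vocab.add_to_gguf(writer)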