vocab.py 6.6 KB

from __future__ import annotations
import logging
import json
import os
from pathlib import Path
from typing import Any, Callable

from .gguf_writer import GGUFWriter

logger = logging.getLogger(__name__)


class SpecialVocab:
    # Collects special-token metadata (token ids, add_*_token flags, chat template,
    # BPE merges) from a Hugging Face style model directory and writes it to a GGUFWriter.
    merges: list[str]
    add_special_token: dict[str, bool]
    special_token_ids: dict[str, int]
    chat_template: str | list[Any] | None

    def __init__(
        self, path: str | os.PathLike[str], load_merges: bool = False,
        special_token_types: tuple[str, ...] | None = None,
        n_vocab: int | None = None,
    ):
        self.special_token_ids = {}
        self.add_special_token = {}
        self.n_vocab = n_vocab
        self.load_merges = load_merges
        self.merges = []
        self.chat_template = None
        if special_token_types is not None:
            self.special_token_types = special_token_types
        else:
            self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad', 'cls', 'mask')
        self._load(Path(path))

    def __repr__(self) -> str:
        return '<SpecialVocab with {} merges, special tokens {}, add special tokens {}>'.format(
            len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset",
        )

    def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
        if self.merges:
            if not quiet:
                logger.info(f'Adding {len(self.merges)} merge(s).')
            gw.add_token_merges(self.merges)
        elif self.load_merges:
            logger.warning('Adding merges requested but no merges found, output may be non-functional.')
        for typ, tokid in self.special_token_ids.items():
            id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
            if id_handler is None:
                logger.warning(f'No handler for special token type {typ} with id {tokid} - skipping')
                continue
            if not quiet:
                logger.info(f'Setting special token type {typ} to {tokid}')
            id_handler(tokid)
        for typ, value in self.add_special_token.items():
            add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None)
            if add_handler is None:
                logger.warning(f'No handler for add_{typ}_token with value {value} - skipping')
                continue
            if not quiet:
                logger.info(f'Setting add_{typ}_token to {value}')
            add_handler(value)
        if self.chat_template is not None:
            if not quiet:
                logger.info(f'Setting chat_template to {self.chat_template}')
            gw.add_chat_template(self.chat_template)
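
    # Note (illustrative, inferred only from the handler names built above): the id
    # handlers resolved via getattr correspond to GGUFWriter methods such as
    # add_bos_token_id / add_eos_token_id, and the add_* flags to methods such as
    # add_add_bos_token. The exact set available depends on the GGUFWriter in use,
    # which is why a missing handler is skipped with a warning rather than raising.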

    def _load(self, path: Path) -> None:
        self._try_load_from_tokenizer_json(path)
        self._try_load_from_config_json(path)
        if self.load_merges and not self.merges:
            self._try_load_merges_txt(path)

    def _try_load_merges_txt(self, path: Path) -> bool:
        merges_file = path / 'merges.txt'
        if not merges_file.is_file():
            return False
        with open(merges_file, 'r', encoding = 'utf-8') as fp:
            first_line = next(fp, '').strip()
            if not first_line.startswith('#'):
                fp.seek(0)
                line_num = 0
            else:
                line_num = 1
            merges = []
            for line in fp:
                line_num += 1
                line = line.strip()
                if not line:
                    continue
                parts = line.split(None, 3)
                if len(parts) != 2:
                    logger.warning(f'{merges_file.name}: Line {line_num}: Entry malformed, ignoring')
                    continue
                merges.append(f'{parts[0]} {parts[1]}')
        self.merges = merges
        return True
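
    # Illustrative merges.txt layout handled by the parser above (example content,
    # not taken from any particular model): an optional comment header followed by
    # one space-separated merge pair per line, e.g.
    #
    #   #version: 0.2
    #   h e
    #   Ġ t
    #   he llo
    #
    # Blank lines are skipped, and any line that does not split into exactly two
    # fields is reported as malformed and ignored.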

    def _set_special_token(self, typ: str, tid: Any) -> None:
        if not isinstance(tid, int):
            return
        if tid < 0:
            raise ValueError(f'invalid value for special token type {typ}: {tid}')
        if self.n_vocab is None or tid < self.n_vocab:
            if typ in self.special_token_ids:
                return
            self.special_token_ids[typ] = tid
            return
        logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')

    def _try_load_from_tokenizer_json(self, path: Path) -> bool:
        tokenizer_file = path / 'tokenizer.json'
        if tokenizer_file.is_file():
            with open(tokenizer_file, encoding = 'utf-8') as f:
                tokenizer = json.load(f)
            if self.load_merges:
                merges = tokenizer.get('model', {}).get('merges')
                if isinstance(merges, list) and merges and isinstance(merges[0], str):
                    self.merges = merges
            added_tokens = tokenizer.get('added_tokens', {})
        else:
            added_tokens = {}
        tokenizer_config_file = path / 'tokenizer_config.json'
        if not tokenizer_config_file.is_file():
            return True
        with open(tokenizer_config_file, encoding = 'utf-8') as f:
            tokenizer_config = json.load(f)
        chat_template = tokenizer_config.get('chat_template')
        if chat_template is None or isinstance(chat_template, (str, list)):
            self.chat_template = chat_template
        else:
            logger.warning(f'Bad type for chat_template field in {tokenizer_config_file!r} - ignoring')
        for typ in self.special_token_types:
            add_entry = tokenizer_config.get(f'add_{typ}_token')
            if isinstance(add_entry, bool):
                self.add_special_token[typ] = add_entry
            entry = tokenizer_config.get(f'{typ}_token')
            if isinstance(entry, str):
                tc_content = entry
            elif isinstance(entry, dict):
                entry_content = entry.get('content')
                if not isinstance(entry_content, str):
                    continue
                tc_content = entry_content
            else:
                continue
            # We only need the first match here.
            maybe_token_id = next(
                (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content),
                None,
            )
            self._set_special_token(typ, maybe_token_id)
        return True
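
    # Illustrative tokenizer_config.json fragment covering both token shapes accepted
    # above (values are made-up examples, not from any specific model):
    #
    #   "add_bos_token": true,
    #   "bos_token": "<s>",
    #   "eos_token": {"content": "</s>", "lstrip": false, "rstrip": false},
    #
    # A plain string is used directly; a dict is reduced to its 'content' string, which
    # is then matched against tokenizer.json's 'added_tokens' to recover the token id.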

    def _try_load_from_config_json(self, path: Path) -> bool:
        config_file = path / 'config.json'
        if not config_file.is_file():
            return False
        with open(config_file, encoding = 'utf-8') as f:
            config = json.load(f)
        for typ in self.special_token_types:
            self._set_special_token(typ, config.get(f'{typ}_token_id'))
        return True
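

# Illustrative usage sketch (assumptions: the output path 'out.gguf', the architecture
# name 'llama', the model directory 'model_dir' and the n_vocab value are made up for
# the example, and the gguf package is assumed to re-export GGUFWriter and SpecialVocab;
# this mirrors how conversion scripts typically drive this class, not a definitive recipe):
#
#   import gguf
#
#   writer = gguf.GGUFWriter('out.gguf', 'llama')
#   special_vocab = gguf.SpecialVocab('model_dir', load_merges=True, n_vocab=32000)
#   special_vocab.add_to_gguf(writer)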