|
|
@@ -13,6 +13,7 @@ class SpecialVocab:
|
|
|
merges: list[str]
|
|
|
add_special_token: dict[str, bool]
|
|
|
special_token_ids: dict[str, int]
|
|
|
+ chat_template: str | None
|
|
|
|
|
|
def __init__(
|
|
|
self, path: str | os.PathLike[str], load_merges: bool = False,
|
|
|
@@ -24,6 +25,7 @@ class SpecialVocab:
|
|
|
self.n_vocab = n_vocab
|
|
|
self.load_merges = load_merges
|
|
|
self.merges = []
|
|
|
+ self.chat_template = None
|
|
|
if special_token_types is not None:
|
|
|
self.special_token_types = special_token_types
|
|
|
else:
|
|
|
@@ -67,6 +69,10 @@ class SpecialVocab:
|
|
|
if not quiet:
|
|
|
print(f'gguf: Setting add_{typ}_token to {value}')
|
|
|
add_handler(value)
|
|
|
+ if self.chat_template is not None:
|
|
|
+ if not quiet:
|
|
|
+ print(f'gguf: Setting chat_template to {self.chat_template}')
|
|
|
+ gw.add_chat_template(self.chat_template)
|
|
|
|
|
|
def _load(self, path: Path) -> None:
|
|
|
self._try_load_from_tokenizer_json(path)
|
|
|
@@ -132,6 +138,14 @@ class SpecialVocab:
|
|
|
return True
|
|
|
with open(tokenizer_config_file, encoding = 'utf-8') as f:
|
|
|
tokenizer_config = json.load(f)
|
|
|
+ chat_template = tokenizer_config.get('chat_template')
|
|
|
+ if chat_template is None or isinstance(chat_template, str):
|
|
|
+ self.chat_template = chat_template
|
|
|
+ else:
|
|
|
+ print(
|
|
|
+ f'gguf: WARNING: Bad type for chat_template field in {tokenizer_config_file!r} - ignoring',
|
|
|
+ file = sys.stderr
|
|
|
+ )
|
|
|
for typ in self.special_token_types:
|
|
|
add_entry = tokenizer_config.get(f'add_{typ}_token')
|
|
|
if isinstance(add_entry, bool):
|