from __future__ import annotations

import re
import logging
import json
import os
from pathlib import Path
from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVar, runtime_checkable

from sentencepiece import SentencePieceProcessor

import gguf

from .gguf_writer import GGUFWriter

logger = logging.getLogger(__name__)


class SpecialVocab:
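    """Collects tokenizer metadata (BPE merges, special token IDs, add_*_token
    flags and the chat template) from the files of a Hugging Face model
    directory so it can be written into a GGUF file via add_to_gguf()."""
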
    merges: list[str]
    add_special_token: dict[str, bool]
    special_token_ids: dict[str, int]
    chat_template: str | Sequence[Mapping[str, str]] | None

    def __init__(
        self, path: str | os.PathLike[str], load_merges: bool = False,
        special_token_types: Iterable[str] | None = None,
        n_vocab: int | None = None,
    ):
        self.special_token_ids = {}
        self.add_special_token = {}
        self.n_vocab = n_vocab
        self.load_merges = load_merges
        self.merges = []
        self.chat_template = None
        if special_token_types is not None:
            self.special_token_types = special_token_types
        else:
            self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad', 'cls', 'mask')
        self._load(Path(path))

    def __repr__(self) -> str:
        return '<SpecialVocab with {} merges, special tokens {}, add special tokens {}>'.format(
            len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset",
        )

    def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
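        """Write the collected merges, special token IDs, add_*_token flags and
        chat template to the given GGUFWriter, dispatching to the writer's
        add_<type>_token_id / add_add_<type>_token methods when they exist."""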
        if self.merges:
            if not quiet:
                logger.info(f'Adding {len(self.merges)} merge(s).')
            gw.add_token_merges(self.merges)
        elif self.load_merges:
            logger.warning('Adding merges requested but no merges found, output may be non-functional.')
        for typ, tokid in self.special_token_ids.items():
            id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
            if id_handler is None:
                logger.warning(f'No handler for special token type {typ} with id {tokid} - skipping')
                continue
            if not quiet:
                logger.info(f'Setting special token type {typ} to {tokid}')
            id_handler(tokid)
        for typ, value in self.add_special_token.items():
            add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None)
            if add_handler is None:
                logger.warning(f'No handler for add_{typ}_token with value {value} - skipping')
                continue
            if not quiet:
                logger.info(f'Setting add_{typ}_token to {value}')
            add_handler(value)
        if self.chat_template is not None:
            if not quiet:
                logger.info(f'Setting chat_template to {self.chat_template}')
            gw.add_chat_template(self.chat_template)

    def _load(self, path: Path) -> None:
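        """Load from tokenizer.json/tokenizer_config.json and config.json, then
        fall back to merges.txt only if merges were requested but not found."""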
        self._try_load_from_tokenizer_json(path)
        self._try_load_from_config_json(path)
        if self.load_merges and not self.merges:
            self._try_load_merges_txt(path)

    def _try_load_merges_txt(self, path: Path) -> bool:
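        """Parse a GPT-2 style merges.txt: an optional leading '#' comment line,
        then one merge per line given as exactly two whitespace-separated token
        pieces. Returns True if the file was found, False otherwise."""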
        merges_file = path / 'merges.txt'
        if not merges_file.is_file():
            return False
        with open(merges_file, 'r', encoding = 'utf-8') as fp:
            first_line = next(fp, '').strip()
            if not first_line.startswith('#'):
                fp.seek(0)
                line_num = 0
            else:
                line_num = 1
            merges = []
            for line in fp:
                line_num += 1
                line = line.strip()
                if not line:
                    continue
                parts = line.split(None, 3)
                if len(parts) != 2:
                    logger.warning(f'{merges_file.name}: Line {line_num}: Entry malformed, ignoring')
                    continue
                merges.append(f'{parts[0]} {parts[1]}')
        self.merges = merges
        return True

    def _set_special_token(self, typ: str, tid: Any) -> None:
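        """Record a special token ID: non-int values are ignored, negative IDs
        raise ValueError, IDs at or above n_vocab (when known) are skipped with
        a warning, and the first ID seen for a given type wins."""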
        if not isinstance(tid, int):
            return
        if tid < 0:
            raise ValueError(f'invalid value for special token type {typ}: {tid}')
        if self.n_vocab is None or tid < self.n_vocab:
            if typ in self.special_token_ids:
                return
            self.special_token_ids[typ] = tid
            return
        logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')

    def _try_load_from_tokenizer_json(self, path: Path) -> bool:
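        """Read merges and added tokens from tokenizer.json (when present), then
        the chat template, add_*_token flags and special token contents from
        tokenizer_config.json, resolving contents to IDs via added_tokens."""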
        tokenizer_file = path / 'tokenizer.json'
        if tokenizer_file.is_file():
            with open(tokenizer_file, encoding = 'utf-8') as f:
                tokenizer = json.load(f)
            if self.load_merges:
                merges = tokenizer.get('model', {}).get('merges')
                if isinstance(merges, list) and merges:
                    if isinstance(merges[0], str):
                        self.merges = merges
                    elif isinstance(merges[0], list) and len(merges[0]) == 2 and isinstance(merges[0][0], str):
                        # New format since transformers 4.45 to support spaces in merges
                        # ref: https://github.com/ggml-org/llama.cpp/issues/9692
                        # TODO: internally store as the new format instead of converting to old
                        if any(' ' in s for pair in merges for s in pair):
                            logger.warning(f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}')
                        self.merges = [
                            ' '.join(
                                [
                                    # ensure the spaces are properly encoded
                                    ''.join(
                                        chr(ord(c) + 256) if c == ' ' else c
                                        for c in part
                                    )
                                    for part in pair
                                ]
                            )
                            for pair in merges
                        ]
                    else:
                        raise ValueError("Unknown tokenizer merges format")
            added_tokens = tokenizer.get('added_tokens', {})
        else:
            added_tokens = {}

        tokenizer_config_file = path / 'tokenizer_config.json'
        if not tokenizer_config_file.is_file():
            return True

        with open(tokenizer_config_file, encoding = 'utf-8') as f:
            tokenizer_config = json.load(f)

        chat_template = tokenizer_config.get('chat_template')
        if chat_template is None or isinstance(chat_template, (str, list)):
            self.chat_template = chat_template
        else:
            logger.warning(f'Bad type for chat_template field in {tokenizer_config_file!r} - ignoring')

        for typ in self.special_token_types:
            add_entry = tokenizer_config.get(f'add_{typ}_token')
            if isinstance(add_entry, bool):
                self.add_special_token[typ] = add_entry

            entry = tokenizer_config.get(f'{typ}_token')
            if isinstance(entry, str):
                tc_content = entry
            elif isinstance(entry, dict):
                entry_content = entry.get('content')
                if not isinstance(entry_content, str):
                    continue
                tc_content = entry_content
            else:
                continue

            # We only need the first match here.
            maybe_token_id = next(
                (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content),
                None,
            )
            self._set_special_token(typ, maybe_token_id)
        return True

    def _try_load_from_config_json(self, path: Path) -> bool:
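        """Fall back to config.json, which stores special tokens directly as
        <type>_token_id entries (e.g. bos_token_id). Returns False if the file
        is missing, True otherwise."""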
        config_file = path / 'config.json'
        if not config_file.is_file():
            return False
        with open(config_file, encoding = 'utf-8') as f:
            config = json.load(f)
        for typ in self.special_token_types:
            self._set_special_token(typ, config.get(f'{typ}_token_id'))
        return True
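
# Minimal usage sketch (illustrative; the model path, n_vocab value and the
# `writer` name are assumptions, with `writer` standing for an already
# populated GGUFWriter):
#
#     svocab = SpecialVocab('path/to/hf-model', load_merges=True, n_vocab=32000)
#     svocab.add_to_gguf(writer)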


@runtime_checkable
class BaseVocab(Protocol):
    tokenizer_model: ClassVar[str]
    name: ClassVar[str]


@runtime_checkable
class Vocab(BaseVocab, Protocol):
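    """Protocol for a concrete vocabulary: exposes its size, any added tokens,
    the tokenizer file it was loaded from, and an all_tokens() iterator that
    yields (token, score, token type) tuples."""
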
    vocab_size: int
    added_tokens_dict: dict[str, int]
    added_tokens_list: list[str]
    fname_tokenizer: Path

    def __init__(self, base_path: Path): ...
    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ...


class NoVocab(BaseVocab):
    tokenizer_model = "no_vocab"
    name = "no_vocab"

    def __repr__(self) -> str:
        return "<NoVocab for a model without integrated vocabulary>"


class BpeVocab(Vocab):
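    """GPT-2 style byte-level BPE vocabulary, loaded either from vocab.json
    plus added_tokens.json ("slow" tokenizer) or from tokenizer.json ("fast"
    tokenizer)."""
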
- tokenizer_model = "gpt2"
- name = "bpe"
- def __init__(self, base_path: Path):
- added_tokens: dict[str, int] = {}
- if (fname_tokenizer := base_path / 'vocab.json').exists():
- # "slow" tokenizer
- with open(fname_tokenizer, encoding="utf-8") as f:
- self.vocab = json.load(f)
- try:
- # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
- with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
- added_tokens = json.load(f)
- except FileNotFoundError:
- pass
- else:
- # "fast" tokenizer
- fname_tokenizer = base_path / 'tokenizer.json'
- # if this fails, FileNotFoundError propagates to caller
- with open(fname_tokenizer, encoding="utf-8") as f:
- tokenizer_json = json.load(f)
- tokenizer_model: dict[str, Any] = tokenizer_json['model']
- if (
- tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False)
- or tokenizer_json['decoder']['type'] != 'ByteLevel'
- ):
- raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer')
- self.vocab = tokenizer_model["vocab"]
- if (added := tokenizer_json.get('added_tokens')) is not None:
- # Added tokens here can be duplicates of the main vocabulary.
- added_tokens = {item['content']: item['id']
- for item in added
- if item['content'] not in self.vocab}
- vocab_size = len(self.vocab)
- expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
- actual_ids = sorted(added_tokens.values())
- if expected_ids != actual_ids:
- expected_end_id = vocab_size + len(actual_ids) - 1
- raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range "
- f"{vocab_size} - {expected_end_id}; got {actual_ids}")
- items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
- self.added_tokens_dict = added_tokens
- self.added_tokens_list = [text for (text, idx) in items]
- self.vocab_size_base = vocab_size
- self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
- self.fname_tokenizer = fname_tokenizer
- def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
- reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()}
- for i, _ in enumerate(self.vocab):
- yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL
- def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
- for text in self.added_tokens_list:
- score = -1000.0
- yield text.encode("utf-8"), score, gguf.TokenType.CONTROL
- def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
- yield from self.bpe_tokens()
- yield from self.added_tokens()
- def __repr__(self) -> str:
- return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
- class SentencePieceVocab(Vocab):
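    """SentencePiece (LLaMA-style) vocabulary loaded from tokenizer.model, with
    optional extra pieces from added_tokens.json appended after the base
    vocabulary."""
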
- tokenizer_model = "llama"
- name = "spm"
- def __init__(self, base_path: Path):
- added_tokens: dict[str, int] = {}
- if (fname_tokenizer := base_path / 'tokenizer.model').exists():
- # normal location
- try:
- with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
- added_tokens = json.load(f)
- except FileNotFoundError:
- pass
- elif not (fname_tokenizer := base_path.parent / 'tokenizer.model').exists():
- # not found in alternate location either
- raise FileNotFoundError('Cannot find tokenizer.model')
- self.sentencepiece_tokenizer = SentencePieceProcessor()
- self.sentencepiece_tokenizer.LoadFromFile(str(fname_tokenizer))
- vocab_size = self.sentencepiece_tokenizer.vocab_size()
- new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
- expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
- actual_new_ids = sorted(new_tokens.keys())
- if expected_new_ids != actual_new_ids:
- raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")
- # Token pieces that were added to the base vocabulary.
- self.added_tokens_dict = added_tokens
- self.added_tokens_list = [new_tokens[id] for id in actual_new_ids]
- self.vocab_size_base = vocab_size
- self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
- self.fname_tokenizer = fname_tokenizer
- def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
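        """Yield every base-vocabulary piece as UTF-8 bytes with its score and a
        GGUF token type derived from the SentencePiece piece flags
        (unknown/control/unused/byte, defaulting to NORMAL)."""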
        tokenizer = self.sentencepiece_tokenizer
        for i in range(tokenizer.vocab_size()):
            piece = tokenizer.IdToPiece(i)
            text = piece.encode("utf-8")
            score: float = tokenizer.GetScore(i)

            toktype = gguf.TokenType.NORMAL
            if tokenizer.IsUnknown(i):
                toktype = gguf.TokenType.UNKNOWN
            if tokenizer.IsControl(i):
                toktype = gguf.TokenType.CONTROL

            # NOTE: I think added_tokens are user defined.
            # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
            # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED

            if tokenizer.IsUnused(i):
                toktype = gguf.TokenType.UNUSED
            if tokenizer.IsByte(i):
                toktype = gguf.TokenType.BYTE

            yield text, score, toktype

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        for text in self.added_tokens_list:
            score = -1000.0
            yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        yield from self.sentencepiece_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"


class LlamaHfVocab(Vocab):
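    """LLaMA-style byte-fallback BPE vocabulary loaded through a Hugging Face
    fast tokenizer (tokenizer.json); Llama 3 models are rejected and must be
    converted with BpeVocab instead."""
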
- tokenizer_model = "llama"
- name = "hfft"
- def __init__(self, base_path: Path):
- fname_tokenizer = base_path / 'tokenizer.json'
- # if this fails, FileNotFoundError propagates to caller
- with open(fname_tokenizer, encoding='utf-8') as f:
- tokenizer_json = json.load(f)
- # pre-check so we know if we need transformers
- tokenizer_model: dict[str, Any] = tokenizer_json['model']
- is_llama3 = (
- tokenizer_model['type'] == 'BPE' and tokenizer_model.get('ignore_merges', False)
- and not tokenizer_model.get('byte_fallback', True)
- )
- if is_llama3:
- raise TypeError('Llama 3 must be converted with BpeVocab')
- if not is_llama3 and (
- tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
- or tokenizer_json['decoder']['type'] != 'Sequence'
- ):
- raise FileNotFoundError('Cannot find Llama BPE tokenizer')
- try:
- from transformers import AutoTokenizer
- except ImportError as e:
- raise ImportError(
- "To use LlamaHfVocab, please install the `transformers` package. "
- "You can install it with `pip install transformers`."
- ) from e
- # Allow the tokenizer to default to slow or fast versions.
- # Explicitly set tokenizer to use local paths.
- self.tokenizer = AutoTokenizer.from_pretrained(
- base_path,
- cache_dir=base_path,
- local_files_only=True,
- )
- assert self.tokenizer.is_fast # assume tokenizer.json is used
- # Initialize lists and dictionaries for added tokens
- self.added_tokens_list = []
- self.added_tokens_dict = dict()
- self.added_tokens_ids = set()
- # Process added tokens
- for tok, tokidx in sorted(
- self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
- ):
- # Only consider added tokens that are not in the base vocabulary
- if tokidx >= self.tokenizer.vocab_size:
- self.added_tokens_list.append(tok)
- self.added_tokens_dict[tok] = tokidx
- self.added_tokens_ids.add(tokidx)
- # Store special tokens and their IDs
- self.specials = {
- tok: self.tokenizer.get_vocab()[tok]
- for tok in self.tokenizer.all_special_tokens
- }
- self.special_ids = set(self.tokenizer.all_special_ids)
- # Set vocabulary sizes
- self.vocab_size_base = self.tokenizer.vocab_size
- self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
- self.fname_tokenizer = fname_tokenizer
- def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
- reverse_vocab = {
- id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
- }
- for token_id in range(self.vocab_size_base):
- # Skip processing added tokens here
- if token_id in self.added_tokens_ids:
- continue
- # Convert token text to bytes
- token_text = reverse_vocab[token_id].encode("utf-8")
- # Yield token text, score, and type
- yield token_text, self.get_token_score(token_id), self.get_token_type(
- token_id, token_text, self.special_ids # Reuse already stored special IDs
- )
- def get_token_type(self, token_id: int, token_text: bytes, special_ids: set[int]) -> gguf.TokenType:
- # Special case for byte tokens
- if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
- return gguf.TokenType.BYTE
- # Determine token type based on whether it's a special token
- return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL
- def get_token_score(self, token_id: int) -> float:
- # Placeholder for actual logic to determine the token's score
- # This needs to be implemented based on specific requirements
- return -1000.0 # Default score
- def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
- for text in self.added_tokens_list:
- if text in self.specials:
- toktype = self.get_token_type(self.specials[text], b'', self.special_ids)
- score = self.get_token_score(self.specials[text])
- else:
- toktype = gguf.TokenType.USER_DEFINED
- score = -1000.0
- yield text.encode("utf-8"), score, toktype
- def has_newline_token(self):
- return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
- def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
- yield from self.hf_tokens()
- yield from self.added_tokens()
- def __repr__(self) -> str:
- return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"